Skip to content
Snippets Groups Projects
Commit 7b441ffa authored by Zheng Xu's avatar Zheng Xu Committed by tensorflow-copybara
Browse files

Update skip_test_for_gpu for parameterized tests.

PiperOrigin-RevId: 338076995
parent 7e5e415a
No related branches found
No related tags found
No related merge requests found
......@@ -54,11 +54,11 @@ def skip_test_for_gpu(test_fn):
"""
@functools.wraps(test_fn)
def wrapped_test_fn(self):
def wrapped_test_fn(self, *args, **kwargs):
gpu_devices = tf.config.list_logical_devices('GPU')
if gpu_devices:
self.skipTest('skip GPU test')
test_fn(self)
test_fn(self, *args, **kwargs)
return wrapped_test_fn
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment