diff --git a/tensorflow_federated/python/common_libs/test_utils.py b/tensorflow_federated/python/common_libs/test_utils.py index bc1a231a435c4e1285dc132456f3b1f253fffec1..3193517ce4ec80f77eef9e10a1903cb0570b11c8 100644 --- a/tensorflow_federated/python/common_libs/test_utils.py +++ b/tensorflow_federated/python/common_libs/test_utils.py @@ -54,11 +54,11 @@ def skip_test_for_gpu(test_fn): """ @functools.wraps(test_fn) - def wrapped_test_fn(self): + def wrapped_test_fn(self, *args, **kwargs): gpu_devices = tf.config.list_logical_devices('GPU') if gpu_devices: self.skipTest('skip GPU test') - test_fn(self) + test_fn(self, *args, **kwargs) return wrapped_test_fn