diff --git a/tensorflow_federated/python/common_libs/test_utils.py b/tensorflow_federated/python/common_libs/test_utils.py
index bc1a231a435c4e1285dc132456f3b1f253fffec1..3193517ce4ec80f77eef9e10a1903cb0570b11c8 100644
--- a/tensorflow_federated/python/common_libs/test_utils.py
+++ b/tensorflow_federated/python/common_libs/test_utils.py
@@ -54,11 +54,11 @@ def skip_test_for_gpu(test_fn):
   """
 
   @functools.wraps(test_fn)
-  def wrapped_test_fn(self):
+  def wrapped_test_fn(self, *args, **kwargs):
     gpu_devices = tf.config.list_logical_devices('GPU')
     if gpu_devices:
       self.skipTest('skip GPU test')
-    test_fn(self)
+    test_fn(self, *args, **kwargs)
 
   return wrapped_test_fn