diff --git a/tensorflow_federated/python/core/impl/BUILD b/tensorflow_federated/python/core/impl/BUILD
index 9ea79f34529b17a0e266cecc72c66d220ea787ab..f91cbaad975b75648ec7dbe3a736c37bab1da7c0 100644
--- a/tensorflow_federated/python/core/impl/BUILD
+++ b/tensorflow_federated/python/core/impl/BUILD
@@ -962,6 +962,7 @@ py_library(
     deps = [
         ":context_stack_base",
         ":tf_computation_context",
+        ":type_utils",
         "//tensorflow_federated/proto/v0:computation_py_pb2",
         "//tensorflow_federated/python/common_libs:py_typecheck",
         "//tensorflow_federated/python/common_libs:serialization_utils",
diff --git a/tensorflow_federated/python/core/impl/eager_executor.py b/tensorflow_federated/python/core/impl/eager_executor.py
index 3346c5097f912cab426a9dd998c0ab591582531e..7c90b188319134e677a7219c3290de9923fb0cbb 100644
--- a/tensorflow_federated/python/core/impl/eager_executor.py
+++ b/tensorflow_federated/python/core/impl/eager_executor.py
@@ -240,9 +240,7 @@ def to_representation_for_type(value, type_spec=None, device=None):
     if isinstance(value, list):
       value = tensorflow_utils.make_data_set_from_elements(
           None, value, type_spec.element)
-    py_typecheck.check_type(
-        value,
-        (tf.data.Dataset, tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset))
+    py_typecheck.check_type(value, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     element_type = type_utils.tf_dtypes_and_shapes_to_type(
         tf.compat.v1.data.get_output_types(value),
         tf.compat.v1.data.get_output_shapes(value))
diff --git a/tensorflow_federated/python/core/impl/executor_service_utils.py b/tensorflow_federated/python/core/impl/executor_service_utils.py
index 30f2c3546a836c380a75c6618d1fc0c71ec30774..0af55234f48cfa8ee9a528cd839faa366dcc2275 100644
--- a/tensorflow_federated/python/core/impl/executor_service_utils.py
+++ b/tensorflow_federated/python/core/impl/executor_service_utils.py
@@ -127,7 +127,7 @@ def serialize_sequence_value(value):
     of `executor_pb2.Value` with the serialized content of `value`, and
     `type_spec` is the type of the serialized value.
   """
-  py_typecheck.check_type(value, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+  py_typecheck.check_type(value, type_utils.TF_DATASET_REPRESENTATION_TYPES)
   # TFF must store the type spec here because TF will lose the ordering of the
   # names for `tf.data.Dataset` that return elements of `collections.Mapping`
   # type. This allows TFF to preserve and restore the key ordering upon
@@ -219,7 +219,7 @@ def serialize_value(value, type_spec=None):
         executor_pb2.Value(tuple=executor_pb2.Value.Tuple(element=tup_elems)))
     return result_proto, type_spec
   elif isinstance(type_spec, computation_types.SequenceType):
-    if not isinstance(value, tensorflow_utils.DATASET_REPRESENTATION_TYPES):
+    if not isinstance(value, type_utils.TF_DATASET_REPRESENTATION_TYPES):
       raise TypeError(
           'Cannot serialize Python type {!s} as TFF type {!s}.'.format(
               py_typecheck.type_string(type(value)),
diff --git a/tensorflow_federated/python/core/impl/tensorflow_serialization.py b/tensorflow_federated/python/core/impl/tensorflow_serialization.py
index 63a79a14657093df298af10e330fe2d055639b5c..7f16c51a8fe1fcf5bb239a11679df4c57d0cd56f 100644
--- a/tensorflow_federated/python/core/impl/tensorflow_serialization.py
+++ b/tensorflow_federated/python/core/impl/tensorflow_serialization.py
@@ -35,6 +35,7 @@ from tensorflow_federated.python.common_libs import serialization_utils
 from tensorflow_federated.python.core.api import computation_types
 from tensorflow_federated.python.core.impl import context_stack_base
 from tensorflow_federated.python.core.impl import tf_computation_context
+from tensorflow_federated.python.core.impl import type_utils
 from tensorflow_federated.python.core.impl.compiler import type_serialization
 from tensorflow_federated.python.core.impl.utils import function_utils
 from tensorflow_federated.python.core.impl.utils import tensorflow_utils
@@ -359,8 +360,7 @@ def serialize_dataset(
     SerializationError: if there was an error in TensorFlow during
       serialization.
   """
-  py_typecheck.check_type(dataset,
-                          tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+  py_typecheck.check_type(dataset, type_utils.TF_DATASET_REPRESENTATION_TYPES)
   module = tf.Module()
   module.dataset = dataset
   module.dataset_fn = tf.function(lambda: module.dataset, input_signature=())
diff --git a/tensorflow_federated/python/core/impl/type_utils.py b/tensorflow_federated/python/core/impl/type_utils.py
index 4e1da1693c93c1b4832b75bfbd0e67e2490c854d..b867095035f2477a3b179b919764d45c18b5fe37 100644
--- a/tensorflow_federated/python/core/impl/type_utils.py
+++ b/tensorflow_federated/python/core/impl/type_utils.py
@@ -32,6 +32,10 @@ from tensorflow_federated.python.core.api import typed_object
 from tensorflow_federated.python.core.impl.compiler import placement_literals
 
 
+TF_DATASET_REPRESENTATION_TYPES = (tf.data.Dataset, tf.compat.v1.data.Dataset,
+                                   tf.compat.v2.data.Dataset)
+
+
 def infer_type(arg):
   """Infers the TFF type of the argument (a `computation_types.Type` instance).
 
@@ -57,9 +61,7 @@ def infer_type(arg):
     return arg.type_signature
   elif tf.is_tensor(arg):
     return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
-  elif isinstance(
-      arg,
-      (tf.data.Dataset, tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset)):
+  elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
     return computation_types.SequenceType(
         tf_dtypes_and_shapes_to_type(
             tf.compat.v1.data.get_output_types(arg),
diff --git a/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py b/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py
index f0c251bbb7145c0cbbdbafce08b1acbd940562bf..abf099353405a0652dab519a2775fddf992cccee 100644
--- a/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py
+++ b/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py
@@ -368,10 +368,6 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph):
         'graph.'.format(parameter_type))
 
 
-DATASET_REPRESENTATION_TYPES = (tf.data.Dataset, tf.compat.v1.data.Dataset,
-                                tf.compat.v2.data.Dataset)
-
-
 def make_dataset_from_variant_tensor(variant_tensor, type_spec):
   """Constructs a `tf.data.Dataset` from a variant tensor and type spec.
 
@@ -1041,7 +1037,7 @@ def fetch_value_in_session(sess, value):
   py_typecheck.check_type(sess, tf.compat.v1.Session)
   # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
   # `tf.data.Dataset`s and what the API would look like to support this.
-  if isinstance(value, DATASET_REPRESENTATION_TYPES):
+  if isinstance(value, type_utils.TF_DATASET_REPRESENTATION_TYPES):
     with sess.graph.as_default():
       iterator = tf.compat.v1.data.make_one_shot_iterator(value)
       next_element = iterator.get_next()
@@ -1057,7 +1053,7 @@ def fetch_value_in_session(sess, value):
     dataset_results = {}
     flat_tensors = []
     for idx, v in enumerate(flattened_value):
-      if isinstance(v, DATASET_REPRESENTATION_TYPES):
+      if isinstance(v, type_utils.TF_DATASET_REPRESENTATION_TYPES):
         dataset_tensors = fetch_value_in_session(sess, v)
         if not dataset_tensors:
           # An empty list has been returned; we must pack the shape information
@@ -1193,7 +1189,7 @@ def coerce_dataset_elements_to_tff_type_spec(dataset, element_type):
     ValueError: if the elements of `dataset` cannot be coerced into
       `element_type`.
   """
-  py_typecheck.check_type(dataset, DATASET_REPRESENTATION_TYPES)
+  py_typecheck.check_type(dataset, type_utils.TF_DATASET_REPRESENTATION_TYPES)
   py_typecheck.check_type(element_type, computation_types.Type)
 
   if isinstance(element_type, computation_types.TensorType):
diff --git a/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py b/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py
index 9753e644a1f0260413162d477c86465a33ddacf7..f1cb8f954d6503fa8b4bad02a54d0d9ab93274cd 100644
--- a/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py
+++ b/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py
@@ -45,7 +45,7 @@ class GraphUtilsTest(test.TestCase):
       self.assertEqual(type_spec.dtype, val.dtype.base_dtype)
       self.assertEqual(repr(type_spec.shape), repr(val.shape))
     elif binding_oneof == 'sequence':
-      self.assertIsInstance(val, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+      self.assertIsInstance(val, type_utils.TF_DATASET_REPRESENTATION_TYPES)
       sequence_oneof = binding.sequence.WhichOneof('binding')
       self.assertEqual(sequence_oneof, 'variant_tensor_name')
       variant_tensor = graph.get_tensor_by_name(
@@ -148,7 +148,7 @@ class GraphUtilsTest(test.TestCase):
     with tf.Graph().as_default():
       x = self._checked_stamp_parameter('foo',
                                         computation_types.SequenceType(tf.bool))
-      self.assertIsInstance(x, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+      self.assertIsInstance(x, type_utils.TF_DATASET_REPRESENTATION_TYPES)
       test.assert_nested_struct_eq(
           tf.compat.v1.data.get_output_types(x), tf.bool)
       test.assert_nested_struct_eq(
@@ -158,7 +158,7 @@ class GraphUtilsTest(test.TestCase):
     with tf.Graph().as_default():
       x = self._checked_stamp_parameter(
           'foo', computation_types.SequenceType((tf.int32, [50])))
-      self.assertIsInstance(x, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+      self.assertIsInstance(x, type_utils.TF_DATASET_REPRESENTATION_TYPES)
       test.assert_nested_struct_eq(
           tf.compat.v1.data.get_output_types(x), tf.int32)
       test.assert_nested_struct_eq(
@@ -171,7 +171,7 @@ class GraphUtilsTest(test.TestCase):
           computation_types.SequenceType(
               collections.OrderedDict([('A', (tf.float32, [3, 4, 5])),
                                        ('B', (tf.int32, [1]))])))
-      self.assertIsInstance(x, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+      self.assertIsInstance(x, type_utils.TF_DATASET_REPRESENTATION_TYPES)
       test.assert_nested_struct_eq(
           tf.compat.v1.data.get_output_types(x), {
               'A': tf.float32,
@@ -464,7 +464,7 @@ class GraphUtilsTest(test.TestCase):
     output_map = {'foo': tf.data.experimental.to_variant(data_set)}
     result = tensorflow_utils.assemble_result_from_graph(
         type_spec, binding, output_map)
-    self.assertIsInstance(result, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(result, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         str(tf.compat.v1.data.get_output_types(result)),
         'OrderedDict([(\'X\', tf.int32), (\'Y\', tf.int32)])')
@@ -486,7 +486,7 @@ class GraphUtilsTest(test.TestCase):
     output_map = {'foo': tf.data.experimental.to_variant(data_set)}
     result = tensorflow_utils.assemble_result_from_graph(
         type_spec, binding, output_map)
-    self.assertIsInstance(result, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(result, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         str(tf.compat.v1.data.get_output_types(result)),
         'TestNamedTuple(X=tf.int32, Y=tf.int32)')
@@ -566,7 +566,7 @@ class GraphUtilsTest(test.TestCase):
   def test_make_data_set_from_elements_with_empty_list(self):
     ds = tensorflow_utils.make_data_set_from_elements(
         tf.compat.v1.get_default_graph(), [], tf.float32)
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         tf.compat.v1.Session().run(ds.reduce(1.0, lambda x, y: x + y)), 1.0)
 
@@ -575,7 +575,7 @@ class GraphUtilsTest(test.TestCase):
     ds = tensorflow_utils.make_data_set_from_elements(
         tf.compat.v1.get_default_graph(), [],
         computation_types.TensorType(tf.float32, [None, 10]))
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         tf.compat.v1.data.get_output_shapes(ds).as_list(),
         tf.TensorShape([0, 10]).as_list())
@@ -589,14 +589,14 @@ class GraphUtilsTest(test.TestCase):
             computation_types.TensorType(tf.float32, [None, 10]),
             computation_types.TensorType(tf.float32, [None, 5])
         ])
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(tf.compat.v1.data.get_output_shapes(ds), ([0, 10], [0, 5]))
 
   @test.graph_mode_test
   def test_make_data_set_from_elements_with_list_of_ints(self):
     ds = tensorflow_utils.make_data_set_from_elements(
         tf.compat.v1.get_default_graph(), [1, 2, 3, 4], tf.int32)
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         tf.compat.v1.Session().run(ds.reduce(0, lambda x, y: x + y)), 10)
 
@@ -610,7 +610,7 @@ class GraphUtilsTest(test.TestCase):
             'a': 3,
             'b': 4,
         }], [('a', tf.int32), ('b', tf.int32)])
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         tf.compat.v1.Session().run(
             ds.reduce(0, lambda x, y: x + y['a'] + y['b'])), 10)
@@ -628,7 +628,7 @@ class GraphUtilsTest(test.TestCase):
                 ('b', 4),
             ]),
         ], [('a', tf.int32), ('b', tf.int32)])
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         tf.compat.v1.Session().run(
             ds.reduce(0, lambda x, y: x + y['a'] + y['b'])), 10)
@@ -640,7 +640,7 @@ class GraphUtilsTest(test.TestCase):
             [[1], [2]],
             [[3], [4]],
         ], [[tf.int32], [tf.int32]])
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         tf.compat.v1.Session().run(
             ds.reduce(0, lambda x, y: x + tf.reduce_sum(y))), 10)
@@ -658,7 +658,7 @@ class GraphUtilsTest(test.TestCase):
                 ('b', 4),
             ]),
         ], [('a', tf.int32), ('b', tf.int32)])
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
     self.assertEqual(
         tf.compat.v1.Session().run(
             ds.reduce(0, lambda x, y: x + y['a'] + y['b'])), 10)
@@ -674,7 +674,7 @@ class GraphUtilsTest(test.TestCase):
             'b': [4],
         }], [('a', [tf.int32]), ('b', [tf.int32])])
 
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
 
     def reduce_fn(x, y):
       return x + tf.reduce_sum(y['a']) + tf.reduce_sum(y['b'])
@@ -692,7 +692,7 @@ class GraphUtilsTest(test.TestCase):
             'b': 4,
         }], [('a', tf.int32), ('b', tf.int32)])
 
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
 
     def reduce_fn(x, y):
       return x + tf.reduce_sum(y['a']) + tf.reduce_sum(y['b'])
@@ -709,7 +709,7 @@ class GraphUtilsTest(test.TestCase):
             'a': np.array([3], dtype=np.int32),
             'b': np.array([4], dtype=np.int32),
         }], [('a', (tf.int32, [1])), ('b', (tf.int32, [1]))])
-    self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
+    self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES)
 
     def reduce_fn(x, y):
       return x + tf.reduce_sum(y['a']) + tf.reduce_sum(y['b'])