diff --git a/tensorflow_federated/python/core/impl/BUILD b/tensorflow_federated/python/core/impl/BUILD index 9ea79f34529b17a0e266cecc72c66d220ea787ab..f91cbaad975b75648ec7dbe3a736c37bab1da7c0 100644 --- a/tensorflow_federated/python/core/impl/BUILD +++ b/tensorflow_federated/python/core/impl/BUILD @@ -962,6 +962,7 @@ py_library( deps = [ ":context_stack_base", ":tf_computation_context", + ":type_utils", "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:serialization_utils", diff --git a/tensorflow_federated/python/core/impl/eager_executor.py b/tensorflow_federated/python/core/impl/eager_executor.py index 3346c5097f912cab426a9dd998c0ab591582531e..7c90b188319134e677a7219c3290de9923fb0cbb 100644 --- a/tensorflow_federated/python/core/impl/eager_executor.py +++ b/tensorflow_federated/python/core/impl/eager_executor.py @@ -240,9 +240,7 @@ def to_representation_for_type(value, type_spec=None, device=None): if isinstance(value, list): value = tensorflow_utils.make_data_set_from_elements( None, value, type_spec.element) - py_typecheck.check_type( - value, - (tf.data.Dataset, tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset)) + py_typecheck.check_type(value, type_utils.TF_DATASET_REPRESENTATION_TYPES) element_type = type_utils.tf_dtypes_and_shapes_to_type( tf.compat.v1.data.get_output_types(value), tf.compat.v1.data.get_output_shapes(value)) diff --git a/tensorflow_federated/python/core/impl/executor_service_utils.py b/tensorflow_federated/python/core/impl/executor_service_utils.py index 30f2c3546a836c380a75c6618d1fc0c71ec30774..0af55234f48cfa8ee9a528cd839faa366dcc2275 100644 --- a/tensorflow_federated/python/core/impl/executor_service_utils.py +++ b/tensorflow_federated/python/core/impl/executor_service_utils.py @@ -127,7 +127,7 @@ def serialize_sequence_value(value): of `executor_pb2.Value` with the serialized content of `value`, and `type_spec` is the type of the serialized value. """ - py_typecheck.check_type(value, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + py_typecheck.check_type(value, type_utils.TF_DATASET_REPRESENTATION_TYPES) # TFF must store the type spec here because TF will lose the ordering of the # names for `tf.data.Dataset` that return elements of `collections.Mapping` # type. This allows TFF to preserve and restore the key ordering upon @@ -219,7 +219,7 @@ def serialize_value(value, type_spec=None): executor_pb2.Value(tuple=executor_pb2.Value.Tuple(element=tup_elems))) return result_proto, type_spec elif isinstance(type_spec, computation_types.SequenceType): - if not isinstance(value, tensorflow_utils.DATASET_REPRESENTATION_TYPES): + if not isinstance(value, type_utils.TF_DATASET_REPRESENTATION_TYPES): raise TypeError( 'Cannot serialize Python type {!s} as TFF type {!s}.'.format( py_typecheck.type_string(type(value)), diff --git a/tensorflow_federated/python/core/impl/tensorflow_serialization.py b/tensorflow_federated/python/core/impl/tensorflow_serialization.py index 63a79a14657093df298af10e330fe2d055639b5c..7f16c51a8fe1fcf5bb239a11679df4c57d0cd56f 100644 --- a/tensorflow_federated/python/core/impl/tensorflow_serialization.py +++ b/tensorflow_federated/python/core/impl/tensorflow_serialization.py @@ -35,6 +35,7 @@ from tensorflow_federated.python.common_libs import serialization_utils from tensorflow_federated.python.core.api import computation_types from tensorflow_federated.python.core.impl import context_stack_base from tensorflow_federated.python.core.impl import tf_computation_context +from tensorflow_federated.python.core.impl import type_utils from tensorflow_federated.python.core.impl.compiler import type_serialization from tensorflow_federated.python.core.impl.utils import function_utils from tensorflow_federated.python.core.impl.utils import tensorflow_utils @@ -359,8 +360,7 @@ def serialize_dataset( SerializationError: if there was an error in TensorFlow during serialization. """ - py_typecheck.check_type(dataset, - tensorflow_utils.DATASET_REPRESENTATION_TYPES) + py_typecheck.check_type(dataset, type_utils.TF_DATASET_REPRESENTATION_TYPES) module = tf.Module() module.dataset = dataset module.dataset_fn = tf.function(lambda: module.dataset, input_signature=()) diff --git a/tensorflow_federated/python/core/impl/type_utils.py b/tensorflow_federated/python/core/impl/type_utils.py index 4e1da1693c93c1b4832b75bfbd0e67e2490c854d..b867095035f2477a3b179b919764d45c18b5fe37 100644 --- a/tensorflow_federated/python/core/impl/type_utils.py +++ b/tensorflow_federated/python/core/impl/type_utils.py @@ -32,6 +32,10 @@ from tensorflow_federated.python.core.api import typed_object from tensorflow_federated.python.core.impl.compiler import placement_literals +TF_DATASET_REPRESENTATION_TYPES = (tf.data.Dataset, tf.compat.v1.data.Dataset, + tf.compat.v2.data.Dataset) + + def infer_type(arg): """Infers the TFF type of the argument (a `computation_types.Type` instance). @@ -57,9 +61,7 @@ def infer_type(arg): return arg.type_signature elif tf.is_tensor(arg): return computation_types.TensorType(arg.dtype.base_dtype, arg.shape) - elif isinstance( - arg, - (tf.data.Dataset, tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset)): + elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES): return computation_types.SequenceType( tf_dtypes_and_shapes_to_type( tf.compat.v1.data.get_output_types(arg), diff --git a/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py b/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py index f0c251bbb7145c0cbbdbafce08b1acbd940562bf..abf099353405a0652dab519a2775fddf992cccee 100644 --- a/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py +++ b/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py @@ -368,10 +368,6 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph): 'graph.'.format(parameter_type)) -DATASET_REPRESENTATION_TYPES = (tf.data.Dataset, tf.compat.v1.data.Dataset, - tf.compat.v2.data.Dataset) - - def make_dataset_from_variant_tensor(variant_tensor, type_spec): """Constructs a `tf.data.Dataset` from a variant tensor and type spec. @@ -1041,7 +1037,7 @@ def fetch_value_in_session(sess, value): py_typecheck.check_type(sess, tf.compat.v1.Session) # TODO(b/113123634): Investigate handling `list`s and `tuple`s of # `tf.data.Dataset`s and what the API would look like to support this. - if isinstance(value, DATASET_REPRESENTATION_TYPES): + if isinstance(value, type_utils.TF_DATASET_REPRESENTATION_TYPES): with sess.graph.as_default(): iterator = tf.compat.v1.data.make_one_shot_iterator(value) next_element = iterator.get_next() @@ -1057,7 +1053,7 @@ def fetch_value_in_session(sess, value): dataset_results = {} flat_tensors = [] for idx, v in enumerate(flattened_value): - if isinstance(v, DATASET_REPRESENTATION_TYPES): + if isinstance(v, type_utils.TF_DATASET_REPRESENTATION_TYPES): dataset_tensors = fetch_value_in_session(sess, v) if not dataset_tensors: # An empty list has been returned; we must pack the shape information @@ -1193,7 +1189,7 @@ def coerce_dataset_elements_to_tff_type_spec(dataset, element_type): ValueError: if the elements of `dataset` cannot be coerced into `element_type`. """ - py_typecheck.check_type(dataset, DATASET_REPRESENTATION_TYPES) + py_typecheck.check_type(dataset, type_utils.TF_DATASET_REPRESENTATION_TYPES) py_typecheck.check_type(element_type, computation_types.Type) if isinstance(element_type, computation_types.TensorType): diff --git a/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py b/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py index 9753e644a1f0260413162d477c86465a33ddacf7..f1cb8f954d6503fa8b4bad02a54d0d9ab93274cd 100644 --- a/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py +++ b/tensorflow_federated/python/core/impl/utils/tensorflow_utils_test.py @@ -45,7 +45,7 @@ class GraphUtilsTest(test.TestCase): self.assertEqual(type_spec.dtype, val.dtype.base_dtype) self.assertEqual(repr(type_spec.shape), repr(val.shape)) elif binding_oneof == 'sequence': - self.assertIsInstance(val, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(val, type_utils.TF_DATASET_REPRESENTATION_TYPES) sequence_oneof = binding.sequence.WhichOneof('binding') self.assertEqual(sequence_oneof, 'variant_tensor_name') variant_tensor = graph.get_tensor_by_name( @@ -148,7 +148,7 @@ class GraphUtilsTest(test.TestCase): with tf.Graph().as_default(): x = self._checked_stamp_parameter('foo', computation_types.SequenceType(tf.bool)) - self.assertIsInstance(x, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(x, type_utils.TF_DATASET_REPRESENTATION_TYPES) test.assert_nested_struct_eq( tf.compat.v1.data.get_output_types(x), tf.bool) test.assert_nested_struct_eq( @@ -158,7 +158,7 @@ class GraphUtilsTest(test.TestCase): with tf.Graph().as_default(): x = self._checked_stamp_parameter( 'foo', computation_types.SequenceType((tf.int32, [50]))) - self.assertIsInstance(x, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(x, type_utils.TF_DATASET_REPRESENTATION_TYPES) test.assert_nested_struct_eq( tf.compat.v1.data.get_output_types(x), tf.int32) test.assert_nested_struct_eq( @@ -171,7 +171,7 @@ class GraphUtilsTest(test.TestCase): computation_types.SequenceType( collections.OrderedDict([('A', (tf.float32, [3, 4, 5])), ('B', (tf.int32, [1]))]))) - self.assertIsInstance(x, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(x, type_utils.TF_DATASET_REPRESENTATION_TYPES) test.assert_nested_struct_eq( tf.compat.v1.data.get_output_types(x), { 'A': tf.float32, @@ -464,7 +464,7 @@ class GraphUtilsTest(test.TestCase): output_map = {'foo': tf.data.experimental.to_variant(data_set)} result = tensorflow_utils.assemble_result_from_graph( type_spec, binding, output_map) - self.assertIsInstance(result, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(result, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( str(tf.compat.v1.data.get_output_types(result)), 'OrderedDict([(\'X\', tf.int32), (\'Y\', tf.int32)])') @@ -486,7 +486,7 @@ class GraphUtilsTest(test.TestCase): output_map = {'foo': tf.data.experimental.to_variant(data_set)} result = tensorflow_utils.assemble_result_from_graph( type_spec, binding, output_map) - self.assertIsInstance(result, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(result, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( str(tf.compat.v1.data.get_output_types(result)), 'TestNamedTuple(X=tf.int32, Y=tf.int32)') @@ -566,7 +566,7 @@ class GraphUtilsTest(test.TestCase): def test_make_data_set_from_elements_with_empty_list(self): ds = tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [], tf.float32) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( tf.compat.v1.Session().run(ds.reduce(1.0, lambda x, y: x + y)), 1.0) @@ -575,7 +575,7 @@ class GraphUtilsTest(test.TestCase): ds = tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [], computation_types.TensorType(tf.float32, [None, 10])) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( tf.compat.v1.data.get_output_shapes(ds).as_list(), tf.TensorShape([0, 10]).as_list()) @@ -589,14 +589,14 @@ class GraphUtilsTest(test.TestCase): computation_types.TensorType(tf.float32, [None, 10]), computation_types.TensorType(tf.float32, [None, 5]) ]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual(tf.compat.v1.data.get_output_shapes(ds), ([0, 10], [0, 5])) @test.graph_mode_test def test_make_data_set_from_elements_with_list_of_ints(self): ds = tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [1, 2, 3, 4], tf.int32) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( tf.compat.v1.Session().run(ds.reduce(0, lambda x, y: x + y)), 10) @@ -610,7 +610,7 @@ class GraphUtilsTest(test.TestCase): 'a': 3, 'b': 4, }], [('a', tf.int32), ('b', tf.int32)]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( tf.compat.v1.Session().run( ds.reduce(0, lambda x, y: x + y['a'] + y['b'])), 10) @@ -628,7 +628,7 @@ class GraphUtilsTest(test.TestCase): ('b', 4), ]), ], [('a', tf.int32), ('b', tf.int32)]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( tf.compat.v1.Session().run( ds.reduce(0, lambda x, y: x + y['a'] + y['b'])), 10) @@ -640,7 +640,7 @@ class GraphUtilsTest(test.TestCase): [[1], [2]], [[3], [4]], ], [[tf.int32], [tf.int32]]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( tf.compat.v1.Session().run( ds.reduce(0, lambda x, y: x + tf.reduce_sum(y))), 10) @@ -658,7 +658,7 @@ class GraphUtilsTest(test.TestCase): ('b', 4), ]), ], [('a', tf.int32), ('b', tf.int32)]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) self.assertEqual( tf.compat.v1.Session().run( ds.reduce(0, lambda x, y: x + y['a'] + y['b'])), 10) @@ -674,7 +674,7 @@ class GraphUtilsTest(test.TestCase): 'b': [4], }], [('a', [tf.int32]), ('b', [tf.int32])]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) def reduce_fn(x, y): return x + tf.reduce_sum(y['a']) + tf.reduce_sum(y['b']) @@ -692,7 +692,7 @@ class GraphUtilsTest(test.TestCase): 'b': 4, }], [('a', tf.int32), ('b', tf.int32)]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) def reduce_fn(x, y): return x + tf.reduce_sum(y['a']) + tf.reduce_sum(y['b']) @@ -709,7 +709,7 @@ class GraphUtilsTest(test.TestCase): 'a': np.array([3], dtype=np.int32), 'b': np.array([4], dtype=np.int32), }], [('a', (tf.int32, [1])), ('b', (tf.int32, [1]))]) - self.assertIsInstance(ds, tensorflow_utils.DATASET_REPRESENTATION_TYPES) + self.assertIsInstance(ds, type_utils.TF_DATASET_REPRESENTATION_TYPES) def reduce_fn(x, y): return x + tf.reduce_sum(y['a']) + tf.reduce_sum(y['b'])