From fb91b2de80cbcc75cc7f8c6c0de293b67098698f Mon Sep 17 00:00:00 2001 From: Michael Reneer <michaelreneer@google.com> Date: Fri, 24 Jul 2020 14:27:58 -0700 Subject: [PATCH] Move the `tensorflow_computation_context` modules into the `impl/tensorflow_context` package. PiperOrigin-RevId: 323072438 --- tensorflow_federated/python/core/impl/BUILD | 34 +------------------ .../python/core/impl/context_stack/BUILD | 1 + .../python/core/impl/tensorflow_context/BUILD | 31 +++++++++++++++++ .../tensorflow_computation_context.py} | 0 .../tensorflow_computation_context_test.py} | 25 ++++++++------ .../core/impl/tensorflow_deserialization.py | 8 ++--- .../core/impl/tensorflow_serialization.py | 4 +-- .../python/core/impl/types/BUILD | 1 + 8 files changed, 55 insertions(+), 49 deletions(-) rename tensorflow_federated/python/core/impl/{tf_computation_context.py => tensorflow_context/tensorflow_computation_context.py} (100%) rename tensorflow_federated/python/core/impl/{tf_computation_context_test.py => tensorflow_context/tensorflow_computation_context_test.py} (74%) diff --git a/tensorflow_federated/python/core/impl/BUILD b/tensorflow_federated/python/core/impl/BUILD index f07bb10f8..2224fff3b 100644 --- a/tensorflow_federated/python/core/impl/BUILD +++ b/tensorflow_federated/python/core/impl/BUILD @@ -368,12 +368,12 @@ py_library( srcs = ["tensorflow_serialization.py"], srcs_version = "PY3", deps = [ - ":tf_computation_context", "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:serialization_utils", "//tensorflow_federated/python/core/api:computation_types", "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", + "//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation_context", "//tensorflow_federated/python/core/impl/types:type_conversions", "//tensorflow_federated/python/core/impl/types:type_serialization", "//tensorflow_federated/python/core/impl/utils:function_utils", @@ -406,38 +406,6 @@ py_library( deps = ["//tensorflow_federated/python/core/api:computations"], ) -py_library( - name = "tf_computation_context", - srcs = ["tf_computation_context.py"], - srcs_version = "PY3", - deps = [ - ":computation_impl", - ":tensorflow_deserialization", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/api:computation_base", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_conversions", - ], -) - -py_test( - name = "tf_computation_context_test", - size = "small", - srcs = ["tf_computation_context_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":tf_computation_context", - "//tensorflow_federated/python/common_libs:test", - "//tensorflow_federated/python/core/api:computation_types", - "//tensorflow_federated/python/core/api:computations", - "//tensorflow_federated/python/core/api:intrinsics", - "//tensorflow_federated/python/core/impl/executors:default_executor", - "//tensorflow_federated/python/core/impl/types:placement_literals", - ], -) - py_library( name = "tree_to_cc_transformations", srcs = ["tree_to_cc_transformations.py"], diff --git a/tensorflow_federated/python/core/impl/context_stack/BUILD b/tensorflow_federated/python/core/impl/context_stack/BUILD index 5d2b07aa4..a77234af8 100644 --- a/tensorflow_federated/python/core/impl/context_stack/BUILD +++ b/tensorflow_federated/python/core/impl/context_stack/BUILD @@ -7,6 +7,7 @@ package_group( # Impl Dependencies "//tensorflow_federated/python/core/impl/executors/...", + "//tensorflow_federated/python/core/impl/tensorflow_context/...", "//tensorflow_federated/python/core/impl/utils/...", "//tensorflow_federated/python/core/impl/wrappers/...", diff --git a/tensorflow_federated/python/core/impl/tensorflow_context/BUILD b/tensorflow_federated/python/core/impl/tensorflow_context/BUILD index 952bbf130..7034f2866 100644 --- a/tensorflow_federated/python/core/impl/tensorflow_context/BUILD +++ b/tensorflow_federated/python/core/impl/tensorflow_context/BUILD @@ -25,3 +25,34 @@ py_library( srcs_version = "PY3", visibility = ["//tensorflow_federated/tools:__subpackages__"], ) + +py_library( + name = "tensorflow_computation_context", + srcs = ["tensorflow_computation_context.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow_federated/python/common_libs:py_typecheck", + "//tensorflow_federated/python/core/api:computation_base", + "//tensorflow_federated/python/core/impl:computation_impl", + "//tensorflow_federated/python/core/impl:tensorflow_deserialization", + "//tensorflow_federated/python/core/impl/context_stack:context_base", + "//tensorflow_federated/python/core/impl/types:type_analysis", + "//tensorflow_federated/python/core/impl/types:type_conversions", + ], +) + +py_test( + name = "tensorflow_computation_context_test", + size = "small", + srcs = ["tensorflow_computation_context_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":tensorflow_computation_context", + "//tensorflow_federated/python/common_libs:test", + "//tensorflow_federated/python/core/api:computation_types", + "//tensorflow_federated/python/core/api:computations", + "//tensorflow_federated/python/core/api:intrinsics", + "//tensorflow_federated/python/core/impl/types:placement_literals", + ], +) diff --git a/tensorflow_federated/python/core/impl/tf_computation_context.py b/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py similarity index 100% rename from tensorflow_federated/python/core/impl/tf_computation_context.py rename to tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py diff --git a/tensorflow_federated/python/core/impl/tf_computation_context_test.py b/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context_test.py similarity index 74% rename from tensorflow_federated/python/core/impl/tf_computation_context_test.py rename to tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context_test.py index 2b317b5c6..ab427c777 100644 --- a/tensorflow_federated/python/core/impl/tf_computation_context_test.py +++ b/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context_test.py @@ -18,14 +18,13 @@ from tensorflow_federated.python.common_libs import test from tensorflow_federated.python.core.api import computation_types from tensorflow_federated.python.core.api import computations from tensorflow_federated.python.core.api import intrinsics -from tensorflow_federated.python.core.impl import tf_computation_context -from tensorflow_federated.python.core.impl.executors import default_executor +from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation_context from tensorflow_federated.python.core.impl.types import placement_literals class TensorFlowComputationContextTest(test.TestCase): - def test_invoke_federated_computation_fails(self): + def test_invoke_raises_value_error_with_federated_computation(self): @computations.federated_computation( computation_types.FederatedType(tf.int32, placement_literals.SERVER, @@ -33,13 +32,14 @@ class TensorFlowComputationContextTest(test.TestCase): def foo(x): return intrinsics.federated_broadcast(x) - context = tf_computation_context.TensorFlowComputationContext( + context = tensorflow_computation_context.TensorFlowComputationContext( tf.compat.v1.get_default_graph()) + with self.assertRaisesRegex(ValueError, 'Expected a TensorFlow computation.'): context.invoke(foo, None) - def test_invoke_tf_computation(self): + def test_invoke_returns_result_with_tf_computation(self): make_10 = computations.tf_computation(lambda: tf.constant(10)) add_one = computations.tf_computation(lambda x: tf.add(x, 1), tf.int32) @@ -55,17 +55,22 @@ class TensorFlowComputationContextTest(test.TestCase): @computations.tf_computation def foo(): - # Test invoking one tf_computation inside - # another. zero = tf.Variable(0, name='zero') ten = tf.Variable(make_10()) return (add_one_with_v2(add_one_with_v1(add_one(make_10()))) + zero + ten - ten) - self.assertEqual(str(foo.type_signature), '( -> int32)') - self.assertEqual(foo(), 13) + graph = tf.compat.v1.Graph() + context = tensorflow_computation_context.TensorFlowComputationContext(graph) + + self.assertEqual(foo.type_signature.compact_representation(), '( -> int32)') + x = context.invoke(foo, None) + with tf.compat.v1.Session(graph=graph) as sess: + if context.init_ops: + sess.run(context.init_ops) + result = sess.run(x) + self.assertEqual(result, 13) if __name__ == '__main__': - default_executor.initialize_default_execution_context() test.main() diff --git a/tensorflow_federated/python/core/impl/tensorflow_deserialization.py b/tensorflow_federated/python/core/impl/tensorflow_deserialization.py index a1752553a..114aed2f6 100644 --- a/tensorflow_federated/python/core/impl/tensorflow_deserialization.py +++ b/tensorflow_federated/python/core/impl/tensorflow_deserialization.py @@ -14,8 +14,8 @@ """Utilities for deserializing TensorFlow computations. Note: This is separate from `tensorflow_serialization.py` to avoid a circular -dependency through `tf_computation_context.py`. The context code depends on -the deserialization code (to implement invocation), whereas the serialization +dependency through `tensorflow_computation_context.py`. The context code depends +on the deserialization code (to implement invocation), whereas the serialization code depends on the context code (to invoke the Python function in context). """ @@ -37,8 +37,8 @@ def deserialize_and_call_tf_computation(computation_proto, arg, graph): implementation may rely on different mechanisms. The caller should not be concerned with the specifics of the implementation. At this point, the method is expected to only be used within the body of another TF computation (within - an instance of `tf_computation_context.TensorFlowComputationContext` at the - top of the stack), and potentially also in certain types of interpreted + an instance of `tensorflow_computation_context.TensorFlowComputationContext` + at the top of the stack), and potentially also in certain types of interpreted execution contexts (TBD). Args: diff --git a/tensorflow_federated/python/core/impl/tensorflow_serialization.py b/tensorflow_federated/python/core/impl/tensorflow_serialization.py index cd3159005..206473d2e 100644 --- a/tensorflow_federated/python/core/impl/tensorflow_serialization.py +++ b/tensorflow_federated/python/core/impl/tensorflow_serialization.py @@ -27,8 +27,8 @@ from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import serialization_utils from tensorflow_federated.python.core.api import computation_types -from tensorflow_federated.python.core.impl import tf_computation_context from tensorflow_federated.python.core.impl.context_stack import context_stack_base +from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation_context from tensorflow_federated.python.core.impl.types import type_conversions from tensorflow_federated.python.core.impl.types import type_serialization from tensorflow_federated.python.core.impl.utils import function_utils @@ -268,7 +268,7 @@ def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack): signature.parameters)) parameter_value = None parameter_binding = None - context = tf_computation_context.TensorFlowComputationContext(graph) + context = tensorflow_computation_context.TensorFlowComputationContext(graph) with context_stack.install(context): with variable_utils.record_variable_creation_scope() as all_variables: if parameter_value is not None: diff --git a/tensorflow_federated/python/core/impl/types/BUILD b/tensorflow_federated/python/core/impl/types/BUILD index 0bc6405f8..5450820e7 100644 --- a/tensorflow_federated/python/core/impl/types/BUILD +++ b/tensorflow_federated/python/core/impl/types/BUILD @@ -9,6 +9,7 @@ package_group( "//tensorflow_federated/python/core/impl/compiler/...", "//tensorflow_federated/python/core/impl/context_stack/...", "//tensorflow_federated/python/core/impl/executors/...", + "//tensorflow_federated/python/core/impl/tensorflow_context/...", "//tensorflow_federated/python/core/impl/utils/...", "//tensorflow_federated/python/core/impl/wrappers/...", -- GitLab