From fb91b2de80cbcc75cc7f8c6c0de293b67098698f Mon Sep 17 00:00:00 2001
From: Michael Reneer <michaelreneer@google.com>
Date: Fri, 24 Jul 2020 14:27:58 -0700
Subject: [PATCH] Move the `tensorflow_computation_context` modules into the
 `impl/tensorflow_context` package.

PiperOrigin-RevId: 323072438
---
 tensorflow_federated/python/core/impl/BUILD   | 34 +------------------
 .../python/core/impl/context_stack/BUILD      |  1 +
 .../python/core/impl/tensorflow_context/BUILD | 31 +++++++++++++++++
 .../tensorflow_computation_context.py}        |  0
 .../tensorflow_computation_context_test.py}   | 25 ++++++++------
 .../core/impl/tensorflow_deserialization.py   |  8 ++---
 .../core/impl/tensorflow_serialization.py     |  4 +--
 .../python/core/impl/types/BUILD              |  1 +
 8 files changed, 55 insertions(+), 49 deletions(-)
 rename tensorflow_federated/python/core/impl/{tf_computation_context.py => tensorflow_context/tensorflow_computation_context.py} (100%)
 rename tensorflow_federated/python/core/impl/{tf_computation_context_test.py => tensorflow_context/tensorflow_computation_context_test.py} (74%)

diff --git a/tensorflow_federated/python/core/impl/BUILD b/tensorflow_federated/python/core/impl/BUILD
index f07bb10f8..2224fff3b 100644
--- a/tensorflow_federated/python/core/impl/BUILD
+++ b/tensorflow_federated/python/core/impl/BUILD
@@ -368,12 +368,12 @@ py_library(
     srcs = ["tensorflow_serialization.py"],
     srcs_version = "PY3",
     deps = [
-        ":tf_computation_context",
         "//tensorflow_federated/proto/v0:computation_py_pb2",
         "//tensorflow_federated/python/common_libs:py_typecheck",
         "//tensorflow_federated/python/common_libs:serialization_utils",
         "//tensorflow_federated/python/core/api:computation_types",
         "//tensorflow_federated/python/core/impl/context_stack:context_stack_base",
+        "//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation_context",
         "//tensorflow_federated/python/core/impl/types:type_conversions",
         "//tensorflow_federated/python/core/impl/types:type_serialization",
         "//tensorflow_federated/python/core/impl/utils:function_utils",
@@ -406,38 +406,6 @@ py_library(
     deps = ["//tensorflow_federated/python/core/api:computations"],
 )
 
-py_library(
-    name = "tf_computation_context",
-    srcs = ["tf_computation_context.py"],
-    srcs_version = "PY3",
-    deps = [
-        ":computation_impl",
-        ":tensorflow_deserialization",
-        "//tensorflow_federated/python/common_libs:py_typecheck",
-        "//tensorflow_federated/python/core/api:computation_base",
-        "//tensorflow_federated/python/core/impl/context_stack:context_base",
-        "//tensorflow_federated/python/core/impl/types:type_analysis",
-        "//tensorflow_federated/python/core/impl/types:type_conversions",
-    ],
-)
-
-py_test(
-    name = "tf_computation_context_test",
-    size = "small",
-    srcs = ["tf_computation_context_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
-    deps = [
-        ":tf_computation_context",
-        "//tensorflow_federated/python/common_libs:test",
-        "//tensorflow_federated/python/core/api:computation_types",
-        "//tensorflow_federated/python/core/api:computations",
-        "//tensorflow_federated/python/core/api:intrinsics",
-        "//tensorflow_federated/python/core/impl/executors:default_executor",
-        "//tensorflow_federated/python/core/impl/types:placement_literals",
-    ],
-)
-
 py_library(
     name = "tree_to_cc_transformations",
     srcs = ["tree_to_cc_transformations.py"],
diff --git a/tensorflow_federated/python/core/impl/context_stack/BUILD b/tensorflow_federated/python/core/impl/context_stack/BUILD
index 5d2b07aa4..a77234af8 100644
--- a/tensorflow_federated/python/core/impl/context_stack/BUILD
+++ b/tensorflow_federated/python/core/impl/context_stack/BUILD
@@ -7,6 +7,7 @@ package_group(
 
         # Impl Dependencies
         "//tensorflow_federated/python/core/impl/executors/...",
+        "//tensorflow_federated/python/core/impl/tensorflow_context/...",
         "//tensorflow_federated/python/core/impl/utils/...",
         "//tensorflow_federated/python/core/impl/wrappers/...",
 
diff --git a/tensorflow_federated/python/core/impl/tensorflow_context/BUILD b/tensorflow_federated/python/core/impl/tensorflow_context/BUILD
index 952bbf130..7034f2866 100644
--- a/tensorflow_federated/python/core/impl/tensorflow_context/BUILD
+++ b/tensorflow_federated/python/core/impl/tensorflow_context/BUILD
@@ -25,3 +25,34 @@ py_library(
     srcs_version = "PY3",
     visibility = ["//tensorflow_federated/tools:__subpackages__"],
 )
+
+py_library(
+    name = "tensorflow_computation_context",
+    srcs = ["tensorflow_computation_context.py"],
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow_federated/python/common_libs:py_typecheck",
+        "//tensorflow_federated/python/core/api:computation_base",
+        "//tensorflow_federated/python/core/impl:computation_impl",
+        "//tensorflow_federated/python/core/impl:tensorflow_deserialization",
+        "//tensorflow_federated/python/core/impl/context_stack:context_base",
+        "//tensorflow_federated/python/core/impl/types:type_analysis",
+        "//tensorflow_federated/python/core/impl/types:type_conversions",
+    ],
+)
+
+py_test(
+    name = "tensorflow_computation_context_test",
+    size = "small",
+    srcs = ["tensorflow_computation_context_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [
+        ":tensorflow_computation_context",
+        "//tensorflow_federated/python/common_libs:test",
+        "//tensorflow_federated/python/core/api:computation_types",
+        "//tensorflow_federated/python/core/api:computations",
+        "//tensorflow_federated/python/core/api:intrinsics",
+        "//tensorflow_federated/python/core/impl/types:placement_literals",
+    ],
+)
diff --git a/tensorflow_federated/python/core/impl/tf_computation_context.py b/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py
similarity index 100%
rename from tensorflow_federated/python/core/impl/tf_computation_context.py
rename to tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py
diff --git a/tensorflow_federated/python/core/impl/tf_computation_context_test.py b/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context_test.py
similarity index 74%
rename from tensorflow_federated/python/core/impl/tf_computation_context_test.py
rename to tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context_test.py
index 2b317b5c6..ab427c777 100644
--- a/tensorflow_federated/python/core/impl/tf_computation_context_test.py
+++ b/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context_test.py
@@ -18,14 +18,13 @@ from tensorflow_federated.python.common_libs import test
 from tensorflow_federated.python.core.api import computation_types
 from tensorflow_federated.python.core.api import computations
 from tensorflow_federated.python.core.api import intrinsics
-from tensorflow_federated.python.core.impl import tf_computation_context
-from tensorflow_federated.python.core.impl.executors import default_executor
+from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation_context
 from tensorflow_federated.python.core.impl.types import placement_literals
 
 
 class TensorFlowComputationContextTest(test.TestCase):
 
-  def test_invoke_federated_computation_fails(self):
+  def test_invoke_raises_value_error_with_federated_computation(self):
 
     @computations.federated_computation(
         computation_types.FederatedType(tf.int32, placement_literals.SERVER,
@@ -33,13 +32,14 @@ class TensorFlowComputationContextTest(test.TestCase):
     def foo(x):
       return intrinsics.federated_broadcast(x)
 
-    context = tf_computation_context.TensorFlowComputationContext(
+    context = tensorflow_computation_context.TensorFlowComputationContext(
         tf.compat.v1.get_default_graph())
+
     with self.assertRaisesRegex(ValueError,
                                 'Expected a TensorFlow computation.'):
       context.invoke(foo, None)
 
-  def test_invoke_tf_computation(self):
+  def test_invoke_returns_result_with_tf_computation(self):
     make_10 = computations.tf_computation(lambda: tf.constant(10))
     add_one = computations.tf_computation(lambda x: tf.add(x, 1), tf.int32)
 
@@ -55,17 +55,22 @@ class TensorFlowComputationContextTest(test.TestCase):
 
     @computations.tf_computation
     def foo():
-      # Test invoking one tf_computation inside
-      # another.
       zero = tf.Variable(0, name='zero')
       ten = tf.Variable(make_10())
       return (add_one_with_v2(add_one_with_v1(add_one(make_10()))) + zero +
               ten - ten)
 
-    self.assertEqual(str(foo.type_signature), '( -> int32)')
-    self.assertEqual(foo(), 13)
+    graph = tf.compat.v1.Graph()
+    context = tensorflow_computation_context.TensorFlowComputationContext(graph)
+
+    self.assertEqual(foo.type_signature.compact_representation(), '( -> int32)')
+    x = context.invoke(foo, None)
 
+    with tf.compat.v1.Session(graph=graph) as sess:
+      if context.init_ops:
+        sess.run(context.init_ops)
+      result = sess.run(x)
+    self.assertEqual(result, 13)
 
 if __name__ == '__main__':
-  default_executor.initialize_default_execution_context()
   test.main()
diff --git a/tensorflow_federated/python/core/impl/tensorflow_deserialization.py b/tensorflow_federated/python/core/impl/tensorflow_deserialization.py
index a1752553a..114aed2f6 100644
--- a/tensorflow_federated/python/core/impl/tensorflow_deserialization.py
+++ b/tensorflow_federated/python/core/impl/tensorflow_deserialization.py
@@ -14,8 +14,8 @@
 """Utilities for deserializing TensorFlow computations.
 
 Note: This is separate from `tensorflow_serialization.py` to avoid a circular
-dependency through `tf_computation_context.py`. The context code depends on
-the deserialization code (to implement invocation), whereas the serialization
+dependency through `tensorflow_computation_context.py`. The context code depends
+on the deserialization code (to implement invocation), whereas the serialization
 code depends on the context code (to invoke the Python function in context).
 """
 
@@ -37,8 +37,8 @@ def deserialize_and_call_tf_computation(computation_proto, arg, graph):
   implementation may rely on different mechanisms. The caller should not be
   concerned with the specifics of the implementation. At this point, the method
   is expected to only be used within the body of another TF computation (within
-  an instance of `tf_computation_context.TensorFlowComputationContext` at the
-  top of the stack), and potentially also in certain types of interpreted
+  an instance of `tensorflow_computation_context.TensorFlowComputationContext`
+  at the top of the stack), and potentially also in certain types of interpreted
   execution contexts (TBD).
 
   Args:
diff --git a/tensorflow_federated/python/core/impl/tensorflow_serialization.py b/tensorflow_federated/python/core/impl/tensorflow_serialization.py
index cd3159005..206473d2e 100644
--- a/tensorflow_federated/python/core/impl/tensorflow_serialization.py
+++ b/tensorflow_federated/python/core/impl/tensorflow_serialization.py
@@ -27,8 +27,8 @@ from tensorflow_federated.proto.v0 import computation_pb2 as pb
 from tensorflow_federated.python.common_libs import py_typecheck
 from tensorflow_federated.python.common_libs import serialization_utils
 from tensorflow_federated.python.core.api import computation_types
-from tensorflow_federated.python.core.impl import tf_computation_context
 from tensorflow_federated.python.core.impl.context_stack import context_stack_base
+from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation_context
 from tensorflow_federated.python.core.impl.types import type_conversions
 from tensorflow_federated.python.core.impl.types import type_serialization
 from tensorflow_federated.python.core.impl.utils import function_utils
@@ -268,7 +268,7 @@ def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
                 signature.parameters))
       parameter_value = None
       parameter_binding = None
-    context = tf_computation_context.TensorFlowComputationContext(graph)
+    context = tensorflow_computation_context.TensorFlowComputationContext(graph)
     with context_stack.install(context):
       with variable_utils.record_variable_creation_scope() as all_variables:
         if parameter_value is not None:
diff --git a/tensorflow_federated/python/core/impl/types/BUILD b/tensorflow_federated/python/core/impl/types/BUILD
index 0bc6405f8..5450820e7 100644
--- a/tensorflow_federated/python/core/impl/types/BUILD
+++ b/tensorflow_federated/python/core/impl/types/BUILD
@@ -9,6 +9,7 @@ package_group(
         "//tensorflow_federated/python/core/impl/compiler/...",
         "//tensorflow_federated/python/core/impl/context_stack/...",
         "//tensorflow_federated/python/core/impl/executors/...",
+        "//tensorflow_federated/python/core/impl/tensorflow_context/...",
         "//tensorflow_federated/python/core/impl/utils/...",
         "//tensorflow_federated/python/core/impl/wrappers/...",
 
-- 
GitLab