Skip to content
Snippets Groups Projects
Commit fb91b2de authored by Michael Reneer's avatar Michael Reneer Committed by tensorflow-copybara
Browse files

Move the `tensorflow_computation_context` modules into the `impl/tensorflow_context` package.

PiperOrigin-RevId: 323072438
parent 1f9daf6c
No related branches found
No related tags found
No related merge requests found
...@@ -368,12 +368,12 @@ py_library( ...@@ -368,12 +368,12 @@ py_library(
srcs = ["tensorflow_serialization.py"], srcs = ["tensorflow_serialization.py"],
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":tf_computation_context",
"//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:serialization_utils", "//tensorflow_federated/python/common_libs:serialization_utils",
"//tensorflow_federated/python/core/api:computation_types", "//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_base", "//tensorflow_federated/python/core/impl/context_stack:context_stack_base",
"//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation_context",
"//tensorflow_federated/python/core/impl/types:type_conversions", "//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/impl/types:type_serialization", "//tensorflow_federated/python/core/impl/types:type_serialization",
"//tensorflow_federated/python/core/impl/utils:function_utils", "//tensorflow_federated/python/core/impl/utils:function_utils",
...@@ -406,38 +406,6 @@ py_library( ...@@ -406,38 +406,6 @@ py_library(
deps = ["//tensorflow_federated/python/core/api:computations"], deps = ["//tensorflow_federated/python/core/api:computations"],
) )
py_library(
name = "tf_computation_context",
srcs = ["tf_computation_context.py"],
srcs_version = "PY3",
deps = [
":computation_impl",
":tensorflow_deserialization",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/impl/types:type_conversions",
],
)
py_test(
name = "tf_computation_context_test",
size = "small",
srcs = ["tf_computation_context_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":tf_computation_context",
"//tensorflow_federated/python/common_libs:test",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
"//tensorflow_federated/python/core/impl/executors:default_executor",
"//tensorflow_federated/python/core/impl/types:placement_literals",
],
)
py_library( py_library(
name = "tree_to_cc_transformations", name = "tree_to_cc_transformations",
srcs = ["tree_to_cc_transformations.py"], srcs = ["tree_to_cc_transformations.py"],
......
...@@ -7,6 +7,7 @@ package_group( ...@@ -7,6 +7,7 @@ package_group(
# Impl Dependencies # Impl Dependencies
"//tensorflow_federated/python/core/impl/executors/...", "//tensorflow_federated/python/core/impl/executors/...",
"//tensorflow_federated/python/core/impl/tensorflow_context/...",
"//tensorflow_federated/python/core/impl/utils/...", "//tensorflow_federated/python/core/impl/utils/...",
"//tensorflow_federated/python/core/impl/wrappers/...", "//tensorflow_federated/python/core/impl/wrappers/...",
......
...@@ -25,3 +25,34 @@ py_library( ...@@ -25,3 +25,34 @@ py_library(
srcs_version = "PY3", srcs_version = "PY3",
visibility = ["//tensorflow_federated/tools:__subpackages__"], visibility = ["//tensorflow_federated/tools:__subpackages__"],
) )
py_library(
name = "tensorflow_computation_context",
srcs = ["tensorflow_computation_context.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/impl:computation_impl",
"//tensorflow_federated/python/core/impl:tensorflow_deserialization",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/impl/types:type_conversions",
],
)
py_test(
name = "tensorflow_computation_context_test",
size = "small",
srcs = ["tensorflow_computation_context_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":tensorflow_computation_context",
"//tensorflow_federated/python/common_libs:test",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
"//tensorflow_federated/python/core/impl/types:placement_literals",
],
)
...@@ -18,14 +18,13 @@ from tensorflow_federated.python.common_libs import test ...@@ -18,14 +18,13 @@ from tensorflow_federated.python.common_libs import test
from tensorflow_federated.python.core.api import computation_types from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics from tensorflow_federated.python.core.api import intrinsics
from tensorflow_federated.python.core.impl import tf_computation_context from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation_context
from tensorflow_federated.python.core.impl.executors import default_executor
from tensorflow_federated.python.core.impl.types import placement_literals from tensorflow_federated.python.core.impl.types import placement_literals
class TensorFlowComputationContextTest(test.TestCase): class TensorFlowComputationContextTest(test.TestCase):
def test_invoke_federated_computation_fails(self): def test_invoke_raises_value_error_with_federated_computation(self):
@computations.federated_computation( @computations.federated_computation(
computation_types.FederatedType(tf.int32, placement_literals.SERVER, computation_types.FederatedType(tf.int32, placement_literals.SERVER,
...@@ -33,13 +32,14 @@ class TensorFlowComputationContextTest(test.TestCase): ...@@ -33,13 +32,14 @@ class TensorFlowComputationContextTest(test.TestCase):
def foo(x): def foo(x):
return intrinsics.federated_broadcast(x) return intrinsics.federated_broadcast(x)
context = tf_computation_context.TensorFlowComputationContext( context = tensorflow_computation_context.TensorFlowComputationContext(
tf.compat.v1.get_default_graph()) tf.compat.v1.get_default_graph())
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
'Expected a TensorFlow computation.'): 'Expected a TensorFlow computation.'):
context.invoke(foo, None) context.invoke(foo, None)
def test_invoke_tf_computation(self): def test_invoke_returns_result_with_tf_computation(self):
make_10 = computations.tf_computation(lambda: tf.constant(10)) make_10 = computations.tf_computation(lambda: tf.constant(10))
add_one = computations.tf_computation(lambda x: tf.add(x, 1), tf.int32) add_one = computations.tf_computation(lambda x: tf.add(x, 1), tf.int32)
...@@ -55,17 +55,22 @@ class TensorFlowComputationContextTest(test.TestCase): ...@@ -55,17 +55,22 @@ class TensorFlowComputationContextTest(test.TestCase):
@computations.tf_computation @computations.tf_computation
def foo(): def foo():
# Test invoking one tf_computation inside
# another.
zero = tf.Variable(0, name='zero') zero = tf.Variable(0, name='zero')
ten = tf.Variable(make_10()) ten = tf.Variable(make_10())
return (add_one_with_v2(add_one_with_v1(add_one(make_10()))) + zero + return (add_one_with_v2(add_one_with_v1(add_one(make_10()))) + zero +
ten - ten) ten - ten)
self.assertEqual(str(foo.type_signature), '( -> int32)') graph = tf.compat.v1.Graph()
self.assertEqual(foo(), 13) context = tensorflow_computation_context.TensorFlowComputationContext(graph)
self.assertEqual(foo.type_signature.compact_representation(), '( -> int32)')
x = context.invoke(foo, None)
with tf.compat.v1.Session(graph=graph) as sess:
if context.init_ops:
sess.run(context.init_ops)
result = sess.run(x)
self.assertEqual(result, 13)
if __name__ == '__main__': if __name__ == '__main__':
default_executor.initialize_default_execution_context()
test.main() test.main()
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
"""Utilities for deserializing TensorFlow computations. """Utilities for deserializing TensorFlow computations.
Note: This is separate from `tensorflow_serialization.py` to avoid a circular Note: This is separate from `tensorflow_serialization.py` to avoid a circular
dependency through `tf_computation_context.py`. The context code depends on dependency through `tensorflow_computation_context.py`. The context code depends
the deserialization code (to implement invocation), whereas the serialization on the deserialization code (to implement invocation), whereas the serialization
code depends on the context code (to invoke the Python function in context). code depends on the context code (to invoke the Python function in context).
""" """
...@@ -37,8 +37,8 @@ def deserialize_and_call_tf_computation(computation_proto, arg, graph): ...@@ -37,8 +37,8 @@ def deserialize_and_call_tf_computation(computation_proto, arg, graph):
implementation may rely on different mechanisms. The caller should not be implementation may rely on different mechanisms. The caller should not be
concerned with the specifics of the implementation. At this point, the method concerned with the specifics of the implementation. At this point, the method
is expected to only be used within the body of another TF computation (within is expected to only be used within the body of another TF computation (within
an instance of `tf_computation_context.TensorFlowComputationContext` at the an instance of `tensorflow_computation_context.TensorFlowComputationContext`
top of the stack), and potentially also in certain types of interpreted at the top of the stack), and potentially also in certain types of interpreted
execution contexts (TBD). execution contexts (TBD).
Args: Args:
......
...@@ -27,8 +27,8 @@ from tensorflow_federated.proto.v0 import computation_pb2 as pb ...@@ -27,8 +27,8 @@ from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import serialization_utils from tensorflow_federated.python.common_libs import serialization_utils
from tensorflow_federated.python.core.api import computation_types from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.impl import tf_computation_context
from tensorflow_federated.python.core.impl.context_stack import context_stack_base from tensorflow_federated.python.core.impl.context_stack import context_stack_base
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation_context
from tensorflow_federated.python.core.impl.types import type_conversions from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.impl.types import type_serialization from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.core.impl.utils import function_utils from tensorflow_federated.python.core.impl.utils import function_utils
...@@ -268,7 +268,7 @@ def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack): ...@@ -268,7 +268,7 @@ def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
signature.parameters)) signature.parameters))
parameter_value = None parameter_value = None
parameter_binding = None parameter_binding = None
context = tf_computation_context.TensorFlowComputationContext(graph) context = tensorflow_computation_context.TensorFlowComputationContext(graph)
with context_stack.install(context): with context_stack.install(context):
with variable_utils.record_variable_creation_scope() as all_variables: with variable_utils.record_variable_creation_scope() as all_variables:
if parameter_value is not None: if parameter_value is not None:
......
...@@ -9,6 +9,7 @@ package_group( ...@@ -9,6 +9,7 @@ package_group(
"//tensorflow_federated/python/core/impl/compiler/...", "//tensorflow_federated/python/core/impl/compiler/...",
"//tensorflow_federated/python/core/impl/context_stack/...", "//tensorflow_federated/python/core/impl/context_stack/...",
"//tensorflow_federated/python/core/impl/executors/...", "//tensorflow_federated/python/core/impl/executors/...",
"//tensorflow_federated/python/core/impl/tensorflow_context/...",
"//tensorflow_federated/python/core/impl/utils/...", "//tensorflow_federated/python/core/impl/utils/...",
"//tensorflow_federated/python/core/impl/wrappers/...", "//tensorflow_federated/python/core/impl/wrappers/...",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment