Commit f72d215c authored by Krzysztof Ostrowski's avatar Krzysztof Ostrowski Committed by A. Unique TensorFlower
Browse files

Adds a set_default_executor() capability, but does not expose it yet due to...

Adds a set_default_executor() capability, but does not expose it yet due to PY3 compatibility issues. Also updates protobuf commit id in the workspace.

PiperOrigin-RevId: 255193729
parent 4aaf7e0e
......@@ -739,6 +739,32 @@ py_test(
],
)
py_library(
name = "set_default_executor",
srcs = ["set_default_executor.py"],
srcs_version = "PY3",
deps = [
":context_stack_impl",
":execution_context",
":executor_base",
"//tensorflow_federated/python/common_libs:py_typecheck",
],
)
py_test(
name = "set_default_executor_test",
size = "small",
srcs = ["set_default_executor_test.py"],
python_version = "PY3",
deps = [
":context_stack_impl",
":eager_executor",
":set_default_executor",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
],
)
py_library(
name = "tensorflow_deserialization",
srcs = ["tensorflow_deserialization.py"],
......
......@@ -28,15 +28,32 @@ from tensorflow_federated.python.core.impl import context_stack_base
from tensorflow_federated.python.core.impl import reference_executor
def _make_default_context(stack):
return reference_executor.ReferenceExecutor(
compiler_pipeline.CompilerPipeline(stack))
class ContextStackImpl(context_stack_base.ContextStack, threading.local):
"""An implementation of a common thread-local context stack to run against."""
def __init__(self):
super(ContextStackImpl, self).__init__()
self._stack = [
reference_executor.ReferenceExecutor(
compiler_pipeline.CompilerPipeline(self))
]
self._stack = [_make_default_context(self)]
def set_default_context(self, ctx=None):
"""Places `ctx` at the bottom of the stack.
Args:
ctx: Either an instance of `context_base.Context`, or `None`, with the
latter resulting in the default reference executor getting installed at
the bottom of the stack (as is the default).
"""
if ctx is not None:
py_typecheck.check_type(ctx, context_base.Context)
else:
ctx = _make_default_context(self)
assert self._stack
self._stack[0] = ctx
@property
def current(self):
......
......@@ -62,6 +62,18 @@ class ContextStackTest(absltest.TestCase):
self.assertIsInstance(ctx_stack.current,
reference_executor.ReferenceExecutor)
def test_set_default_context(self):
ctx_stack = context_stack_impl.context_stack
self.assertIsInstance(ctx_stack.current,
reference_executor.ReferenceExecutor)
foo = TestContext('foo')
ctx_stack.set_default_context(foo)
self.assertIs(ctx_stack.current, foo)
ctx_stack.set_default_context()
self.assertIsInstance(ctx_stack.current,
reference_executor.ReferenceExecutor)
if __name__ == '__main__':
absltest.main()
# Lint as: python3
# Copyright 2018, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A utility to change the default executor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.impl import context_stack_impl
from tensorflow_federated.python.core.impl import execution_context
from tensorflow_federated.python.core.impl import executor_base
def set_default_executor(executor=None):
"""Places an `executor`-backed execution context at the top of the stack.
Args:
executor: Either an instance of `executor_base.Executor`, or `None` which
causes the default reference executor to be installed (as is the default).
"""
if executor is not None:
py_typecheck.check_type(executor, executor_base.Executor)
context = execution_context.ExecutionContext(executor)
else:
context = None
context_stack_impl.context_stack.set_default_context(context)
# Lint as: python3
# Copyright 2018, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the set_default_executor.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl import context_stack_impl
from tensorflow_federated.python.core.impl import eager_executor
from tensorflow_federated.python.core.impl import set_default_executor
class TestSetDefaultExecutor(absltest.TestCase):
def test_basic_functionality(self):
@computations.tf_computation(computation_types.SequenceType(tf.int32))
def comp(ds):
return ds.take(5).reduce(np.int32(0), lambda x, y: x + y)
set_default_executor.set_default_executor(eager_executor.EagerExecutor())
ds = tf.data.Dataset.range(1).map(lambda x: tf.constant(5)).repeat()
v = comp(ds)
self.assertEqual(v, 25)
set_default_executor.set_default_executor()
self.assertIn('ReferenceExecutor',
str(type(context_stack_impl.context_stack.current).__name__))
if __name__ == '__main__':
tf.compat.v1.enable_v2_behavior()
absltest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment