提交 5ae26b55 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Adds reference resolving executors back to remote stack for unplaced executor...

Adds reference resolving executors back to remote stack for unplaced executor composing can delegate to, and ensures composing is always topped by an RRE.

PiperOrigin-RevId: 347690967
上级 1dcb490d
......@@ -546,7 +546,8 @@ class ComposingExecutorFactory(executor_factory.ExecutorFactory):
RuntimeError: If hierarchy construction fails.
"""
if len(executors) == 1:
return self._create_composing_stack(target_executors=executors)
return reference_resolving_executor.ReferenceResolvingExecutor(
self._create_composing_stack(target_executors=executors))
while len(executors) > 1:
new_executors = []
offset = 0
......@@ -868,7 +869,7 @@ def remote_executor_factory(
flat_stack_fn = _configure_remote_workers
unplaced_ex_factory = UnplacedExecutorFactory(
use_caching=False, can_resolve_references=False)
use_caching=False, can_resolve_references=True)
composing_executor_factory = ComposingExecutorFactory(
max_fanout=max_fanout,
unplaced_ex_factory=unplaced_ex_factory,
......
......@@ -234,6 +234,51 @@ class FederatedComputationTest(parameterized.TestCase):
self.assertLen(result1, 2)
self.assertLen(result2, 3)
@with_contexts
def test_runs_unplaced_lambda(self):
@tff.federated_computation(tf.int32, tf.int32)
def bar(x, y):
del y # Unused
return x
result = bar(1, 2)
self.assertEqual(result, 1)
@with_contexts
def test_runs_server_placed_lambda(self):
@tff.federated_computation(tf.int32, tf.int32)
def foo(x, y):
del y # Unused
return x
@tff.federated_computation(
tff.FederatedType(
collections.OrderedDict(x=tf.int32, y=tf.int32), tff.SERVER))
def bar(server_tuple):
return tff.federated_map(foo, server_tuple)
result = bar(collections.OrderedDict(x=1, y=2))
self.assertEqual(result, 1)
@with_contexts
def test_runs_clients_placed_lambda(self):
@tff.federated_computation(tf.int32, tf.int32)
def foo(x, y):
del y # Unused
return x
@tff.federated_computation(
tff.FederatedType(
collections.OrderedDict(x=tf.int32, y=tf.int32), tff.CLIENTS))
def bar(clients_tuple):
return tff.federated_map(foo, clients_tuple)
result = bar([collections.OrderedDict(x=1, y=2)])
self.assertEqual(result, [1])
class TensorFlowComputationTest(parameterized.TestCase):
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册