提交 99036cc9 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Adds failing test for progress on partially available workers.

PiperOrigin-RevId: 345482667
上级 147a6c60
......@@ -83,6 +83,44 @@ class WorkerFailureTest(parameterized.TestCase):
result = map_add_one([0, 1])
self.assertEqual(result, [1, 2])
@parameterized.named_parameters(
(
'native_remote_request_reply',
remote_runtime_test_utils.create_localhost_remote_context(
_WORKER_PORTS),
remote_runtime_test_utils.create_localhost_worker_contexts(
[_WORKER_PORTS[0]]),
),
(
'native_remote_streaming',
remote_runtime_test_utils.create_localhost_remote_context(
_WORKER_PORTS, rpc_mode='STREAMING'),
remote_runtime_test_utils.create_localhost_worker_contexts(
[_WORKER_PORTS[0]]),
),
)
def test_computations_run_with_partially_available_workers(
self, tff_context, server_contexts):
self.skipTest('b/174679820')
@tff.tf_computation(tf.int32)
def add_one(x):
return x + 1
@tff.federated_computation(tff.type_at_clients(tf.int32))
def map_add_one(federated_arg):
return tff.federated_map(add_one, federated_arg)
context_stack = tff.framework.get_context_stack()
with context_stack.install(tff_context):
with contextlib.ExitStack() as stack:
for server_context in server_contexts:
stack.enter_context(server_context)
result = map_add_one([0, 1])
self.assertEqual(result, [1, 2])
# TODO(b/172025644): Promote streaming plus intermediate aggregation to a
# proper backend test when the final cleanup issues are diagnosed and fixed.
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册