提交 5e1ad040 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Adds failing test for bad interaction between caching and remote runtime configuration.

PiperOrigin-RevId: 347400725
上级 5721e8f9
......@@ -36,7 +36,7 @@ py_test(
py_test(
name = "perf_regression_test",
size = "medium",
timeout = "moderate",
srcs = ["perf_regression_test.py"],
python_version = "PY3",
srcs_version = "PY3",
......@@ -45,7 +45,7 @@ py_test(
py_test(
name = "remote_runtime_integration_test",
size = "small",
timeout = "moderate",
srcs = ["remote_runtime_integration_test.py"],
python_version = "PY3",
srcs_version = "PY3",
......
......@@ -151,5 +151,57 @@ class StreamingWithIntermediateAggTest(absltest.TestCase):
self.assertEqual(result, 3)
@parameterized.named_parameters((
'native_remote_request_reply',
remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
), (
'native_remote_streaming',
remote_runtime_test_utils.create_localhost_remote_context(
_WORKER_PORTS, rpc_mode='STREAMING'),
remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
), (
'native_remote_intermediate_aggregator',
remote_runtime_test_utils.create_localhost_remote_context(
_AGGREGATOR_PORTS),
remote_runtime_test_utils.create_localhost_aggregator_contexts(
_WORKER_PORTS, _AGGREGATOR_PORTS),
))
class RemoteRuntimeConfigurationChangeTest(absltest.TestCase):
def test_computations_run_with_changing_clients(self, context,
server_contexts):
self.skipTest('b/175155128')
@tff.tf_computation(tf.int32)
@tf.function
def add_one(x):
return x + 1
@tff.federated_computation(tff.type_at_clients(tf.int32))
def map_add_one(federated_arg):
return tff.federated_map(add_one, federated_arg)
context_stack = tff.framework.get_context_stack()
with context_stack.install(context):
with contextlib.ExitStack() as stack:
for server_context in server_contexts:
stack.enter_context(server_context)
result_two_clients = map_add_one([0, 1])
self.assertEqual(result_two_clients, [1, 2])
# Moving to three clients should be fine
result_three_clients = map_add_one([0, 1, 2])
# Running a 0-client function should also be OK
self.assertEqual(add_one(0), 1)
self.assertEqual(result_three_clients, [1, 2, 3])
# Changing back to 2 clients should still succeed.
second_result_two_clients = map_add_one([0, 1])
self.assertEqual(second_result_two_clients, [1, 2])
# Similarly, 3 clients again should be fine.
second_result_three_clients = map_add_one([0, 1, 2])
self.assertEqual(second_result_three_clients, [1, 2, 3])
if __name__ == '__main__':
absltest.main()
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册