Commit 6bfc93a3 authored by Taylor Cramer's avatar Taylor Cramer Committed by tensorflow-copybara
Browse files

Delay instantiation of GRPC connections until after main()

Pre-`main` instantiation of GRPC connections will no longer be supported in the future. This CL modifies the test annotation which accepted pre-initialized context harnesses to instead accept no-arg lambdas. This allows initialization of the test contexts to be delayed until the test itself is run.

PiperOrigin-RevId: 412120040
parent 0fa0ed2c
......@@ -329,16 +329,17 @@ class TensorFlowComputationTest(tf.test.TestCase, parameterized.TestCase):
# pyformat: disable
@test_contexts.with_contexts(
('native_local', tff.backends.native.create_local_python_execution_context()),
# pylint: disable=unnecessary-lambda
('native_local', lambda: tff.backends.native.create_local_python_execution_context()),
('native_remote',
remote_runtime_test_utils.create_localhost_remote_context(test_contexts.WORKER_PORTS),
remote_runtime_test_utils.create_inprocess_worker_contexts(test_contexts.WORKER_PORTS)),
lambda: remote_runtime_test_utils.create_localhost_remote_context(test_contexts.WORKER_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(test_contexts.WORKER_PORTS)),
('native_remote_intermediate_aggregator',
remote_runtime_test_utils.create_localhost_remote_context(test_contexts.AGGREGATOR_PORTS),
remote_runtime_test_utils.create_inprocess_aggregator_contexts(test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
('native_sizing', tff.backends.native.create_sizing_execution_context()),
lambda: remote_runtime_test_utils.create_localhost_remote_context(test_contexts.AGGREGATOR_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
('native_sizing', lambda: tff.backends.native.create_sizing_execution_context()),
('native_thread_debug',
tff.backends.native.create_thread_debugging_execution_context()),
lambda: tff.backends.native.create_thread_debugging_execution_context()),
)
# pyformat: enable
def test_takes_infinite_dataset(self):
......@@ -355,16 +356,17 @@ class TensorFlowComputationTest(tf.test.TestCase, parameterized.TestCase):
# pyformat: disable
@test_contexts.with_contexts(
('native_local', tff.backends.native.create_local_python_execution_context()),
# pylint: disable=unnecessary-lambda
('native_local', lambda: tff.backends.native.create_local_python_execution_context()),
('native_remote',
remote_runtime_test_utils.create_localhost_remote_context(test_contexts.WORKER_PORTS),
remote_runtime_test_utils.create_inprocess_worker_contexts(test_contexts.WORKER_PORTS)),
lambda: remote_runtime_test_utils.create_localhost_remote_context(test_contexts.WORKER_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(test_contexts.WORKER_PORTS)),
('native_remote_intermediate_aggregator',
remote_runtime_test_utils.create_localhost_remote_context(test_contexts.AGGREGATOR_PORTS),
remote_runtime_test_utils.create_inprocess_aggregator_contexts(test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
('native_sizing', tff.backends.native.create_sizing_execution_context()),
lambda: remote_runtime_test_utils.create_localhost_remote_context(test_contexts.AGGREGATOR_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
('native_sizing', lambda: tff.backends.native.create_sizing_execution_context()),
('native_thread_debug',
tff.backends.native.create_thread_debugging_execution_context()),
lambda: tff.backends.native.create_thread_debugging_execution_context()),
)
# pyformat: enable
def test_returns_infinite_dataset(self):
......@@ -451,7 +453,8 @@ class NonDeterministicTest(parameterized.TestCase):
class SizingExecutionContextTest(parameterized.TestCase):
@test_contexts.with_context(
tff.backends.native.create_sizing_execution_context())
# pylint: disable=unnecessary-lambda
lambda: tff.backends.native.create_sizing_execution_context())
def test_get_size_info(self):
num_clients = 10
to_float = lambda x: tf.cast(x, tf.float32)
......
......@@ -160,21 +160,27 @@ class WorkerFailureTest(parameterized.TestCase):
self.assertEqual(sum_arg(1), 10)
@parameterized.named_parameters((
'native_remote',
remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
remote_runtime_test_utils.create_inprocess_worker_contexts(_WORKER_PORTS),
), (
'native_remote_intermediate_aggregator',
remote_runtime_test_utils.create_localhost_remote_context(
_AGGREGATOR_PORTS),
remote_runtime_test_utils.create_inprocess_aggregator_contexts(
_WORKER_PORTS, _AGGREGATOR_PORTS),
))
@parameterized.named_parameters(
# pylint: disable=g-long-lambda
# pylint: disable=unnecessary-lambda
(
'native_remote',
lambda: remote_runtime_test_utils.create_localhost_remote_context(
_WORKER_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(
_WORKER_PORTS),
),
(
'native_remote_intermediate_aggregator',
lambda: remote_runtime_test_utils.create_localhost_remote_context(
_AGGREGATOR_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(
_WORKER_PORTS, _AGGREGATOR_PORTS),
))
class RemoteRuntimeConfigurationChangeTest(parameterized.TestCase):
def test_computations_run_with_changing_clients(self, context,
server_contexts):
def test_computations_run_with_changing_clients(self, context_fn,
server_contexts_fn):
@tff.tf_computation(tf.int32)
@tf.function
......@@ -186,10 +192,10 @@ class RemoteRuntimeConfigurationChangeTest(parameterized.TestCase):
return tff.federated_map(add_one, federated_arg)
context_stack = tff.framework.get_context_stack()
with context_stack.install(context):
with context_stack.install(context_fn()):
with contextlib.ExitStack() as stack:
for server_context in server_contexts:
for server_context in server_contexts_fn():
stack.enter_context(server_context)
result_two_clients = map_add_one([0, 1])
self.assertEqual(result_two_clients, [1, 2])
......
......@@ -32,25 +32,28 @@ def _create_local_mergeable_comp_context():
def _get_all_contexts():
"""Returns a list containing a (name, context_fn) tuple for each context."""
# pyformat: disable
return [
('native_local', tff.backends.native.create_local_python_execution_context()),
('native_mergeable', _create_local_mergeable_comp_context()),
# pylint: disable=unnecessary-lambda
# native_local_cpp removed by copybara
('native_local', lambda: tff.backends.native.create_local_python_execution_context()),
('native_mergeable', lambda: _create_local_mergeable_comp_context()),
('native_remote',
remote_runtime_test_utils.create_localhost_remote_context(WORKER_PORTS),
remote_runtime_test_utils.create_inprocess_worker_contexts(WORKER_PORTS)),
lambda: remote_runtime_test_utils.create_localhost_remote_context(WORKER_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(WORKER_PORTS)),
('native_remote_intermediate_aggregator',
remote_runtime_test_utils.create_localhost_remote_context(AGGREGATOR_PORTS),
remote_runtime_test_utils.create_inprocess_aggregator_contexts(WORKER_PORTS, AGGREGATOR_PORTS)),
('native_sizing', tff.backends.native.create_sizing_execution_context()),
lambda: remote_runtime_test_utils.create_localhost_remote_context(AGGREGATOR_PORTS),
lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(WORKER_PORTS, AGGREGATOR_PORTS)),
('native_sizing', lambda: tff.backends.native.create_sizing_execution_context()),
('native_thread_debug',
tff.backends.native.create_thread_debugging_execution_context()),
('test', tff.backends.test.create_test_execution_context()),
lambda: tff.backends.native.create_thread_debugging_execution_context()),
('test', lambda: tff.backends.test.create_test_execution_context()),
]
# pyformat: enable
def with_context(context):
def with_context(context_fn):
"""A decorator for running tests in the given `context`."""
def decorator_context(fn):
......@@ -58,7 +61,7 @@ def with_context(context):
@functools.wraps(fn)
def wrapper_context(self):
context_stack = tff.framework.get_context_stack()
with context_stack.install(context):
with context_stack.install(context_fn()):
fn(self)
return wrapper_context
......@@ -66,7 +69,7 @@ def with_context(context):
return decorator_context
def with_environment(server_contexts):
def with_environment(server_contexts_fn):
"""A decorator for running tests in an environment."""
def decorator_environment(fn):
......@@ -74,7 +77,7 @@ def with_environment(server_contexts):
@functools.wraps(fn)
def wrapper_environment(self):
with contextlib.ExitStack() as stack:
for server_context in server_contexts:
for server_context in server_contexts_fn():
stack.enter_context(server_context)
fn(self)
......@@ -91,11 +94,11 @@ def with_contexts(*args):
named_contexts = _get_all_contexts()
@parameterized.named_parameters(*named_contexts)
def wrapper_contexts(self, context, server_contexts=None):
with_context_decorator = with_context(context)
def wrapper_contexts(self, context_fn, server_contexts_fn=None):
with_context_decorator = with_context(context_fn)
decorated_fn = with_context_decorator(fn)
if server_contexts is not None:
with_environment_decorator = with_environment(server_contexts)
if server_contexts_fn is not None:
with_environment_decorator = with_environment(server_contexts_fn)
decorated_fn = with_environment_decorator(decorated_fn)
decorated_fn(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment