Commit 14e830a8 authored by Keith Rush, committed by tensorflow-copybara

Ensures remote executor configuration only creates a new event loop when necessary.

PiperOrigin-RevId: 339961597
parent 4ed46d9b
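Background for this change: under the Python 3.8-era asyncio semantics this code targets, asyncio.get_event_loop() only yields a usable loop when one is installed for the current thread; on a non-main thread with no loop set it raises RuntimeError, and a loop that has been closed stays closed but is still returned. These are exactly the two cases the new _get_event_loop helper below guards against. A minimal sketch (illustrative only, not part of the commit) demonstrating both failure modes:

  # Sketch: the two cases in which creating a fresh event loop is
  # actually necessary.
  import asyncio
  import threading

  def probe():
    try:
      loop = asyncio.get_event_loop()
      print('usable loop, closed=%s' % loop.is_closed())
    except RuntimeError as e:
      # On a non-main thread with no loop set, get_event_loop() raises.
      print('must create a new loop: %s' % e)

  t = threading.Thread(target=probe)
  t.start()
  t.join()  # prints 'must create a new loop: ...'

  loop = asyncio.get_event_loop()
  loop.close()
  # A closed loop persists as the thread's current loop:
  print(asyncio.get_event_loop().is_closed())  # True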
executor_stacks.py

@@ -815,13 +815,29 @@ def remote_executor_factory(
           thread_pool_executor=thread_pool_executor,
           dispose_batch_size=dispose_batch_size))
 
+  def _get_event_loop():
+    should_close_loop = False
+    try:
+      loop = asyncio.get_event_loop()
+      if loop.is_closed():
+        loop = asyncio.new_event_loop()
+        should_close_loop = True
+    except RuntimeError:
+      loop = asyncio.new_event_loop()
+      should_close_loop = True
+    return loop, should_close_loop
+
   def _configure_remote_executor(ex, cardinalities, loop):
     """Configures `ex` to run the appropriate number of clients."""
-    loop.run_until_complete(ex.set_cardinalities(cardinalities))
+    if loop.is_running():
+      asyncio.run_coroutine_threadsafe(
+          ex.set_cardinalities(cardinalities), loop)
+    else:
+      loop.run_until_complete(ex.set_cardinalities(cardinalities))
     return
 
   def _configure_remote_workers(cardinalities):
-    loop = asyncio.new_event_loop()
+    loop, must_close_loop = _get_event_loop()
     try:
       if not cardinalities.get(placement_literals.CLIENTS):
         for ex in remote_executors:
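Note on the dispatch in _configure_remote_executor above: run_until_complete raises RuntimeError when called in a thread where an event loop is already running, so when the loop is live the coroutine is instead scheduled with asyncio.run_coroutine_threadsafe (fire-and-forget here, since the returned future is not awaited). A minimal sketch of the same pattern, with a stand-in coroutine set_up in place of ex.set_cardinalities:

  # Sketch: run on an idle loop, schedule on a running one.
  import asyncio

  async def set_up():
    return 'configured'

  def dispatch(loop):
    if loop.is_running():
      # Thread-safe: schedules the coroutine on the running loop and
      # returns a concurrent.futures.Future without blocking.
      return asyncio.run_coroutine_threadsafe(set_up(), loop)
    # Only valid when no loop is running in the current thread.
    return loop.run_until_complete(set_up())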
@@ -839,7 +855,9 @@ def remote_executor_factory(
             ex, {placement_literals.CLIENTS: num_clients_to_host}, loop)
         live_workers.append(ex)
     finally:
-      loop.close()
+      if must_close_loop:
+        loop.stop()
+        loop.close()
     return [_wrap_executor_in_threading_stack(e) for e in live_workers]
 
   flat_stack_fn = _configure_remote_workers
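Taken together, the two hunks above implement an acquire-or-create pattern: reuse the thread's loop when possible, and tear down only loops this code created itself. A self-contained sketch of that lifecycle (the helper name with_event_loop is illustrative, not from the commit):

  # Sketch: run `fn` against the current loop if usable, otherwise
  # against a private loop that is stopped and closed afterwards.
  import asyncio

  def with_event_loop(fn):
    must_close_loop = False
    try:
      loop = asyncio.get_event_loop()
      if loop.is_closed():
        loop = asyncio.new_event_loop()
        must_close_loop = True
    except RuntimeError:
      loop = asyncio.new_event_loop()
      must_close_loop = True
    try:
      return fn(loop)
    finally:
      if must_close_loop:
        loop.stop()   # does not raise on an idle loop
        loop.close()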
executor_stacks_test.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import math
 from unittest import mock
@@ -696,6 +697,21 @@ class RemoteExecutorFactoryTest(absltest.TestCase):
       remote_ex_factory.create_executor({placement_literals.CLIENTS: 10})
       mock_obj.assert_called_once()
 
+  def test_configuration_succeeds_while_event_loop_is_running(self):
+    loop = asyncio.get_event_loop()
+    channels = [
+        grpc.insecure_channel('localhost:1'),
+        grpc.insecure_channel('localhost:2')
+    ]
+
+    async def coro_func():
+      remote_ex_factory = executor_stacks.remote_executor_factory(channels)
+      remote_ex_factory.create_executor({placement_literals.CLIENTS: 1})
+
+    loop.run_until_complete(coro_func())
+    loop.stop()
+    loop.close()
+
 
 if __name__ == '__main__':
   absltest.main()
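The new test drives create_executor from inside a coroutine, so configuration runs while the ambient event loop is live; with the old code this path failed because a second loop cannot be run in a thread that already has a running loop. A minimal reproduction of that failure mode (illustrative, independent of TFF):

  # Sketch: run_until_complete on a second loop fails while another
  # loop is running in the same thread.
  import asyncio

  async def configure_while_running():
    inner = asyncio.new_event_loop()
    try:
      inner.run_until_complete(asyncio.sleep(0))
    except RuntimeError as e:
      # 'Cannot run the event loop while another loop is running'
      print(e)
    finally:
      inner.close()

  outer = asyncio.new_event_loop()
  outer.run_until_complete(configure_while_running())
  outer.close()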
BUILD

@@ -40,6 +40,7 @@ py_test(
 py_test(
     name = "remote_runtime_integration_test",
     size = "small",
+    srcs = ["remote_runtime_integration_test.py"],
     python_version = "PY3",
     srcs_version = "PY3",
remote_runtime_integration_test.py

@@ -29,15 +29,22 @@ _AGGREGATOR_PORTS = [portpicker.pick_unused_port() for _ in range(2)]
 # TODO(b/168744510): This module is intended to be short-lived, and the
 # coverage here should be moved down to unit tests when we have a better mocking
 # infrastructure deeper in the runtime.
-class RemoteRuntimeIntegrationTest(parameterized.TestCase):
+class WorkerFailureTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
-      ('native_remote',
+      ('native_remote_request_reply',
       remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
-       remote_runtime_test_utils.create_localhost_worker_contexts(
-           _WORKER_PORTS),
+       remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)
       ),
+      ('native_remote_streaming',
+       remote_runtime_test_utils.create_localhost_remote_context(
+           _WORKER_PORTS, rpc_mode='STREAMING'),
+       remote_runtime_test_utils.create_localhost_worker_contexts(
+           _WORKER_PORTS),
+       remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)
+      ),
       ('native_remote_intermediate_aggregator',
        remote_runtime_test_utils.create_localhost_remote_context(
            _AGGREGATOR_PORTS),
@@ -77,5 +84,34 @@ class RemoteRuntimeIntegrationTest(parameterized.TestCase):
     self.assertEqual(result, [1, 2])
 
 
+# TODO(b/172025644): Promote streaming plus intermediate aggregation to a
+# proper backend test when the final cleanup issues are diagnosed and fixed.
+class StreamingWithIntermediateAggTest(absltest.TestCase):
+
+  def test_runs_computation_streaming_with_intermediate_agg(self):
+
+    @tff.tf_computation(tf.int32)
+    def add_one(x):
+      return x + 1
+
+    @tff.federated_computation(tff.type_at_clients(tf.int32))
+    def map_add_one_and_sum(federated_arg):
+      return tff.federated_sum(tff.federated_map(add_one, federated_arg))
+
+    execution_context = remote_runtime_test_utils.create_localhost_remote_context(
+        _AGGREGATOR_PORTS, rpc_mode='STREAMING')
+    worker_contexts = remote_runtime_test_utils.create_localhost_aggregator_contexts(
+        _WORKER_PORTS, _AGGREGATOR_PORTS, rpc_mode='STREAMING')
+
+    context_stack = tff.framework.get_context_stack()
+    with context_stack.install(execution_context):
+      with contextlib.ExitStack() as stack:
+        for server_context in worker_contexts:
+          stack.enter_context(server_context)
+        result = map_add_one_and_sum([0, 1])
+
+    self.assertEqual(result, 3)
+
+
 if __name__ == '__main__':
   absltest.main()
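For reference, the expected value in the new streaming test follows directly from the computation: each client value is mapped through add_one and the results are summed. A plain-Python sanity check (illustrative):

  # Sanity check of the expected results: map add_one, then sum.
  client_values = [0, 1]
  assert [x + 1 for x in client_values] == [1, 2]  # federated_map step
  assert sum(x + 1 for x in client_values) == 3    # federated_sum step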