提交 2a0ce580 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Adjusts caching behavior of remote runtime to reconfigure on cardinalities change.

Previously, the remote runtime inherited the executor-stack caching behavior of the local runtime. This was incorrect, as in the local runtime there is a 1:1 mapping between cardinalities and in-memory objects, but this does not hold in the remote runtime--the same service can host different numbers of clients across time.

PiperOrigin-RevId: 347836956
上级 5ae26b55
......@@ -421,7 +421,7 @@ class FederatingExecutorFactory(executor_factory.ExecutorFactory):
def create_minimal_length_flat_stack_fn(
max_clients_per_stack: int,
federated_stack_factory: FederatingExecutorFactory
federated_stack_factory: executor_factory.ExecutorFactory
) -> Callable[[executor_factory.CardinalitiesType],
List[executor_base.Executor]]:
"""Creates a function returning a list of executors to run `cardinalities`.
......@@ -436,8 +436,8 @@ def create_minimal_length_flat_stack_fn(
Args:
max_clients_per_stack: Integer determining the maximum number of clients a
single executor in the list returned by the function may execute.
federated_stack_factory: The `FederatingExecutorFactory` for use in actually
constructing these executors.
federated_stack_factory: The `executor_factory.ExecutorFactory` for use in
actually constructing these executors.
Returns:
A callable taking a parameter of type `executor_factory.CardinalitiesType`,
......@@ -769,6 +769,64 @@ def sizing_executor_factory(
return SizingExecutorFactory(_factory_fn)
class ReconstructOnChangeExecutorFactory(executor_factory.ExecutorFactory):
"""ExecutorFactory exposing hook to construct executors on environment change.
When the initialization parameter `change_query` returns `True`,
ReconstructOnChangeExecutorFactory` constructs a new executor, bypassing
any previously constructed executors.
"""
def __init__(self,
underlying_stack: executor_factory.ExecutorFactory,
ensure_closed: Optional[Sequence[executor_base.Executor]] = None,
change_query: Callable[[executor_factory.CardinalitiesType],
bool] = lambda _: True):
self._change_query = change_query
self._underlying_stack = underlying_stack
self._executors = {}
if ensure_closed is None:
ensure_closed = ()
self._ensure_closed = ensure_closed
def create_executor(
self, cardinalities: executor_factory.CardinalitiesType
) -> executor_base.Executor:
"""Returns a new or existing executor, depending on `change_query`.
`create_executor` constructs a new executor whenever `change_query` returns
`True` when called with argument `cardinalities`. If `change_query` returns
`False`, `create_executor` is free to inspect its internal executor cache
and return a previously constructed executor if one is available.
Args:
cardinalities: A mapping from placement literals to ints.
Returns:
An `executor_base.Executor` obeying the semantics above.
"""
py_typecheck.check_type(cardinalities, dict)
key = _get_hashable_key(cardinalities)
if self._change_query(cardinalities):
reconstructed = self._underlying_stack.create_executor(cardinalities)
self._executors[key] = reconstructed
return reconstructed
elif self._executors.get(key):
return self._executors[key]
else:
constructed = self._underlying_stack.create_executor(cardinalities)
self._executors[key] = constructed
return constructed
def clean_up_executors(self):
for _, ex in self._executors.items():
ex.close()
self._executors = {}
for ex in self._ensure_closed:
ex.close()
self._underlying_stack.clean_up_executors()
def remote_executor_factory(
channels: List[grpc.Channel],
rpc_mode: str = 'REQUEST_REPLY',
......@@ -876,6 +934,19 @@ def remote_executor_factory(
flat_stack_fn=flat_stack_fn,
)
return ResourceManagingExecutorFactory(
executor_stack_fn=composing_executor_factory.create_executor,
ensure_closed=remote_executors)
class _ChangeQuery:
"""Stateful callable tracking cardinalities of remote runtime."""
def __init__(self):
self._cardinalities = None
def __call__(self,
cardinalities: executor_factory.CardinalitiesType) -> bool:
cardinalities_changed = self._cardinalities != cardinalities
self._cardinalities = cardinalities
return cardinalities_changed
return ReconstructOnChangeExecutorFactory(
underlying_stack=composing_executor_factory,
ensure_closed=remote_executors,
change_query=_ChangeQuery())
......@@ -171,7 +171,6 @@ class RemoteRuntimeConfigurationChangeTest(absltest.TestCase):
def test_computations_run_with_changing_clients(self, context,
server_contexts):
self.skipTest('b/175155128')
@tff.tf_computation(tf.int32)
@tf.function
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册