Commit 9db55cdf authored by Taylor Cramer, committed by tensorflow-copybara

Remove the CachingExecutor

The CachingExecutor formerly played a critical role in the TFF execution stack: it prevented repeated execution of the myriad duplicate computations in TFF's ASTs. Nowadays, TFF no longer generates these duplicate computations, so it is no longer necessary to cache at this level to ensure performance. Additionally, this type of caching plays poorly with nondeterministic computations.

Prior to this CL, the CachingExecutor was already unused outside of tests. This should be a no-op for TFF customers.
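
A minimal sketch of the nondeterminism hazard, assuming the public `tff.tf_computation` API of this era (`sample` is a hypothetical computation, not code from this CL):

import tensorflow as tf
import tensorflow_federated as tff

@tff.tf_computation
def sample():
  # Each execution should draw a fresh random value.
  return tf.random.uniform(shape=())

# Under a value-caching executor, both calls can resolve to the same cache
# identifier and return one memoized result; without caching, each call
# re-executes the graph and the two draws are (almost surely) different.
first = sample()
second = sample()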

PiperOrigin-RevId: 411112532
parent fd8b4ca4
@@ -28,7 +28,6 @@ py_library(
"//tensorflow_federated/python/core/impl/context_stack:get_context_stack",
"//tensorflow_federated/python/core/impl/context_stack:set_default_context",
"//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context",
"//tensorflow_federated/python/core/impl/executors:caching_executor",
"//tensorflow_federated/python/core/impl/executors:cardinality_carrying_base",
"//tensorflow_federated/python/core/impl/executors:data_backend_base",
"//tensorflow_federated/python/core/impl/executors:data_executor",
@@ -24,7 +24,6 @@ from tensorflow_federated.python.core.impl.context_stack.context_stack_base impo
from tensorflow_federated.python.core.impl.context_stack.get_context_stack import get_context_stack
from tensorflow_federated.python.core.impl.context_stack.set_default_context import set_default_context
from tensorflow_federated.python.core.impl.execution_contexts.sync_execution_context import ExecutionContext
from tensorflow_federated.python.core.impl.executors.caching_executor import CachingExecutor
from tensorflow_federated.python.core.impl.executors.cardinality_carrying_base import CardinalityCarrying
from tensorflow_federated.python.core.impl.executors.data_backend_base import DataBackend
from tensorflow_federated.python.core.impl.executors.data_executor import DataExecutor
@@ -22,44 +22,6 @@ py_library(
visibility = ["//tensorflow_federated/tools/python_package:python_package_tool"],
)
py_library(
name = "caching_executor",
srcs = ["caching_executor.py"],
srcs_version = "PY3",
deps = [
":executor_base",
":executor_utils",
":executor_value_base",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/computation:computation_impl",
"//tensorflow_federated/python/core/impl/types:computation_types",
],
)
py_test(
name = "caching_executor_test",
size = "small",
timeout = "moderate",
srcs = ["caching_executor_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":caching_executor",
":eager_tf_executor",
":executor_base",
":executor_stacks",
":executor_test_utils",
":reference_resolving_executor",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/computation:computation_impl",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
],
)
py_library(
name = "cardinalities_utils",
srcs = ["cardinalities_utils.py"],
@@ -231,7 +193,7 @@ py_library(
srcs_version = "PY3",
tags = ["nokokoro"],
deps = [
"//tensorflow_federated/python/core/impl/executors:data_conversions",
":data_conversions",
"//tensorflow_federated/python/core/impl/types:placements",
],
)
@@ -349,7 +311,6 @@ py_library(
srcs = ["executor_stacks.py"],
srcs_version = "PY3",
deps = [
":caching_executor",
":eager_tf_executor",
":executor_base",
":executor_factory",
@@ -806,7 +767,6 @@ py_test(
shard_count = 10,
srcs_version = "PY3",
deps = [
":caching_executor",
":eager_tf_executor",
":reference_resolving_executor",
":sizing_executor",
@@ -836,7 +796,6 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":caching_executor",
":eager_tf_executor",
":executor_base",
":thread_delegating_executor",
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An executor that caches and reuses values on repeated calls."""
import asyncio
import collections
import cachetools
import numpy as np
import tensorflow as tf
from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.computation import computation_impl
from tensorflow_federated.python.core.impl.executors import executor_base
from tensorflow_federated.python.core.impl.executors import executor_utils
from tensorflow_federated.python.core.impl.executors import executor_value_base
from tensorflow_federated.python.core.impl.types import computation_types
class HashableWrapper(collections.abc.Hashable):
"""A wrapper around non-hashable objects to be compared by identity."""
def __init__(self, target):
self._target = target
def __hash__(self):
return hash(id(self._target))
def __eq__(self, other):
    return isinstance(other, HashableWrapper) and other._target is self._target  # pylint: disable=protected-access
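# Illustrative note (not part of the original file): wrappers compare by
# target identity rather than value, so for x = [1], HashableWrapper(x) ==
# HashableWrapper(x) is True, while HashableWrapper([1]) ==
# HashableWrapper([1]) is False because the two lists are distinct objects.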
def _get_hashable_key(value, type_spec):
"""Return a hashable key for value `value` of TFF type `type_spec`.
Args:
value: An argument to `create_value()`.
type_spec: An optional type signature.
Returns:
A hashable key to use such that the same `value` always maps to the same
key, and different ones map to different keys.
Raises:
TypeError: If there is no hashable key for this type of a value.
"""
if type_spec.is_struct():
if not isinstance(value, structure.Struct):
try:
value = structure.from_container(value)
except Exception as e:
raise TypeError(
'Failed to convert value with type_spec {} to `Struct`'.format(
repr(type_spec))) from e
type_specs = structure.iter_elements(type_spec)
r_elem = []
for v, (field_name, field_type) in zip(value, type_specs):
r_elem.append((field_name, _get_hashable_key(v, field_type)))
return structure.Struct(r_elem)
elif type_spec.is_federated():
if type_spec.all_equal:
return _get_hashable_key(value, type_spec.member)
else:
return tuple([_get_hashable_key(x, type_spec.member) for x in value])
elif isinstance(value, pb.Computation):
return value.SerializeToString(deterministic=True)
elif isinstance(value, np.ndarray):
return ('<dtype={},shape={}>'.format(value.dtype,
value.shape), value.tobytes())
elif (isinstance(value, collections.abc.Hashable) and
not isinstance(value, (tf.Tensor, tf.Variable))):
    # TODO(b/139200385): Currently Tensor and Variable return True for
    # `isinstance(value, collections.abc.Hashable)` even though they are not
    # hashable. Hence this workaround.
return value
else:
return HashableWrapper(value)
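# Illustrative note (not part of the original file): example keys under the
# rules above, assuming a matching `type_spec` in each case:
#   - an np.ndarray of zeros with shape (2,) maps to
#     ('<dtype=float64,shape=(2,)>', value.tobytes()),
#   - a pb.Computation maps to value.SerializeToString(deterministic=True),
#   - a tf.Tensor is treated as non-hashable and falls through to the
#     identity-keyed HashableWrapper.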
class CachedValueIdentifier(collections.abc.Hashable):
"""An identifier for a cached value."""
def __init__(self, identifier):
py_typecheck.check_type(identifier, str)
self._identifier = identifier
def __hash__(self):
return hash(self._identifier)
def __eq__(self, other):
# pylint: disable=protected-access
return (isinstance(other, CachedValueIdentifier) and
self._identifier == other._identifier)
# pylint: enable=protected-access
def __repr__(self):
return 'CachedValueIdentifier({!r})'.format(self._identifier)
def __str__(self):
return self._identifier
class CachedValue(executor_value_base.ExecutorValue):
"""A value held by the caching executor."""
def __init__(self, identifier, hashable_key, type_spec, target_future):
"""Creates a cached value.
Args:
identifier: An instance of `CachedValueIdentifier`.
      hashable_key: A hashable source value key, if any, or `None` if not
        applicable in this context, for use during cleanup.
type_spec: The type signature of the target, an instance of `tff.Type`.
target_future: An asyncio future that returns an instance of
`executor_value_base.ExecutorValue` that represents a value embedded in
the target executor.
Raises:
TypeError: If the arguments are of the wrong types.
"""
py_typecheck.check_type(identifier, CachedValueIdentifier)
py_typecheck.check_type(hashable_key, collections.abc.Hashable)
py_typecheck.check_type(type_spec, computation_types.Type)
if not asyncio.isfuture(target_future):
raise TypeError('Expected an asyncio future, got {}'.format(
py_typecheck.type_string(type(target_future))))
self._identifier = identifier
self._hashable_key = hashable_key
self._type_spec = type_spec
self._target_future = target_future
self._computed_result = None
@property
def type_signature(self):
return self._type_spec
@property
def identifier(self):
return self._identifier
@property
def hashable_key(self):
return self._hashable_key
@property
def target_future(self):
return self._target_future
async def compute(self):
if self._computed_result is None:
target_value = await self._target_future
self._computed_result = await target_value.compute()
return self._computed_result
_DEFAULT_CACHE_SIZE = 1000
class CachingExecutor(executor_base.Executor):
"""The caching executor only performs caching."""
# TODO(b/134543154): Factor out default cache settings to supply elsewhere,
# possibly as a part of executor stack configuration.
  # TODO(b/134543154): It might be desirable to still keep around things that
# are currently in use (referenced) regardless of what's in the cache. This
# can be added later on.
def __init__(self, target_executor, cache=None):
"""Creates a new instance of this executor.
Args:
target_executor: An instance of `executor_base.Executor`.
cache: The cache to use (must be an instance of `cachetools.Cache`). If
unspecified, by default we construct a 1000-element LRU cache.
"""
py_typecheck.check_type(target_executor, executor_base.Executor)
if cache is not None:
py_typecheck.check_type(cache, cachetools.Cache)
else:
cache = cachetools.LRUCache(_DEFAULT_CACHE_SIZE)
self._target_executor = target_executor
self._cache = cache
self._num_values_created = 0
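  # Illustrative note (not part of the original file): callers could swap the
  # eviction policy via the `cache` argument, e.g.
  #   CachingExecutor(inner_executor, cache=cachetools.LRUCache(maxsize=100))
  # or any other `cachetools.Cache` subclass, such as `cachetools.TTLCache`.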
def close(self):
self._cache.clear()
self._target_executor.close()
def __del__(self):
for k in list(self._cache):
del self._cache[k]
async def create_value(self, value, type_spec=None):
type_spec = computation_types.to_type(type_spec)
if isinstance(value, computation_impl.ConcreteComputation):
return await self.create_value(
computation_impl.ConcreteComputation.get_proto(value),
executor_utils.reconcile_value_with_type_spec(value, type_spec))
py_typecheck.check_type(type_spec, computation_types.Type)
hashable_key = _get_hashable_key(value, type_spec)
try:
identifier = self._cache.get(hashable_key)
except TypeError as err:
raise RuntimeError(
'Failed to perform a hash table lookup with a value of Python '
'type {} and TFF type {}, and payload {}: {}'.format(
py_typecheck.type_string(type(value)), type_spec, value, err))
if isinstance(identifier, CachedValueIdentifier):
cached_value = self._cache.get(identifier)
      # It may be that the same payload appeared with a mismatching type spec,
      # which may be a legitimate use case if (as it happens) the payload alone
      # does not uniquely determine the type, so we simply opt not to reuse the
      # cached value and fall back on the regular behavior.
if (cached_value is not None and type_spec is not None and
not cached_value.type_signature.is_equivalent_to(type_spec)):
identifier = None
else:
identifier = None
if identifier is None:
self._num_values_created = self._num_values_created + 1
identifier = CachedValueIdentifier(str(self._num_values_created))
self._cache[hashable_key] = identifier
target_future = asyncio.ensure_future(
self._target_executor.create_value(value, type_spec))
cached_value = None
if cached_value is None:
cached_value = CachedValue(identifier, hashable_key, type_spec,
target_future)
self._cache[identifier] = cached_value
try:
await cached_value.target_future
except Exception:
      # Invalidate the entire cache if the inner executor had an exception.
      # TODO(b/145514490): This is a bit heavy handed; there may be caches where
# only the current cache item needs to be invalidated; however this
# currently only occurs when an inner RemoteExecutor has the backend go
# down.
self._cache = {}
raise
# No type check is necessary here; we have either checked
# `is_equivalent_to` or just constructed `target_value`
# explicitly with `type_spec`.
return cached_value
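  # Illustrative note (not part of the original file): `self._cache` acts as a
  # two-level map. A payload's hashable key maps to a `CachedValueIdentifier`,
  # and the identifier maps to the `CachedValue` itself; the derived values
  # built below are then keyed by composed identifier strings such as '1(2)'
  # for calls, '<1,2>' for structs, and '1[0]' for selections.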
async def create_call(self, comp, arg=None):
py_typecheck.check_type(comp, CachedValue)
py_typecheck.check_type(comp.type_signature, computation_types.FunctionType)
to_gather = [comp.target_future]
if arg is not None:
py_typecheck.check_type(arg, CachedValue)
comp.type_signature.parameter.check_assignable_from(arg.type_signature)
to_gather.append(arg.target_future)
identifier_str = '{}({})'.format(comp.identifier, arg.identifier)
else:
identifier_str = '{}()'.format(comp.identifier)
gathered = await asyncio.gather(*to_gather)
type_spec = comp.type_signature.result
identifier = CachedValueIdentifier(identifier_str)
try:
cached_value = self._cache[identifier]
except KeyError:
target_future = asyncio.ensure_future(
self._target_executor.create_call(*gathered))
cached_value = CachedValue(identifier, None, type_spec, target_future)
self._cache[identifier] = cached_value
try:
target_value = await cached_value.target_future
except Exception:
      # TODO(b/145514490): This is a bit heavy handed; there may be caches where
# only the current cache item needs to be invalidated; however this
# currently only occurs when an inner RemoteExecutor has the backend go
# down.
self._cache = {}
raise
type_spec.check_assignable_from(target_value.type_signature)
return cached_value
async def create_struct(self, elements):
if not isinstance(elements, structure.Struct):
elements = structure.from_container(elements)
element_strings = []
element_kv_pairs = structure.to_elements(elements)
to_gather = []
type_elements = []
for k, v in element_kv_pairs:
py_typecheck.check_type(v, CachedValue)
to_gather.append(v.target_future)
if k is not None:
py_typecheck.check_type(k, str)
element_strings.append('{}={}'.format(k, v.identifier))
type_elements.append((k, v.type_signature))
else:
element_strings.append(str(v.identifier))
type_elements.append(v.type_signature)
type_spec = computation_types.StructType(type_elements)
gathered = await asyncio.gather(*to_gather)
identifier = CachedValueIdentifier('<{}>'.format(','.join(element_strings)))
try:
cached_value = self._cache[identifier]
except KeyError:
target_future = asyncio.ensure_future(
self._target_executor.create_struct(
structure.Struct(
(k, v) for (k, _), v in zip(element_kv_pairs, gathered))))
cached_value = CachedValue(identifier, None, type_spec, target_future)
self._cache[identifier] = cached_value
try:
target_value = await cached_value.target_future
except Exception:
      # TODO(b/145514490): This is a bit heavy handed; there may be caches where
# only the current cache item needs to be invalidated; however this
# currently only occurs when an inner RemoteExecutor has the backend go
# down.
self._cache = {}
raise
type_spec.check_assignable_from(target_value.type_signature)
return cached_value
async def create_selection(self, source, index):
py_typecheck.check_type(source, CachedValue)
py_typecheck.check_type(source.type_signature, computation_types.StructType)
source_val = await source.target_future
identifier_str = f'{source.identifier}[{index}]'
type_spec = source.type_signature[index]
identifier = CachedValueIdentifier(identifier_str)
try:
cached_value = self._cache[identifier]
except KeyError:
target_future = asyncio.ensure_future(
self._target_executor.create_selection(source_val, index))
cached_value = CachedValue(identifier, None, type_spec, target_future)
self._cache[identifier] = cached_value
try:
target_value = await cached_value.target_future
except Exception:
      # TODO(b/145514490): This is a bit heavy handed; there may be caches where
# only the current cache item needs to be invalidated; however this
# currently only occurs when an inner RemoteExecutor has the backend go
# down.
self._cache = {}
raise
type_spec.check_assignable_from(target_value.type_signature)
return cached_value
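
For context, a minimal sketch of how the pre-CL stack opted into caching, using only names that appear elsewhere in this CL (the post-CL wiring simply omits the CachingExecutor wrap):

import cachetools

from tensorflow_federated.python.core.impl.executors import caching_executor
from tensorflow_federated.python.core.impl.executors import eager_tf_executor
from tensorflow_federated.python.core.impl.executors import thread_delegating_executor

# Pre-CL wiring when use_caching=True: a leaf TF executor, wrapped in thread
# delegation, wrapped in the (now removed) caching layer.
leaf_ex = eager_tf_executor.EagerTFExecutor()
threaded_ex = thread_delegating_executor.ThreadDelegatingExecutor(leaf_ex)
cached_ex = caching_executor.CachingExecutor(
    threaded_ex, cache=cachetools.LRUCache(maxsize=1000))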
@@ -27,7 +27,6 @@ import tensorflow as tf
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.impl.compiler import local_computation_factory_base
from tensorflow_federated.python.core.impl.compiler import tensorflow_computation_factory
from tensorflow_federated.python.core.impl.executors import caching_executor
from tensorflow_federated.python.core.impl.executors import eager_tf_executor
from tensorflow_federated.python.core.impl.executors import executor_base
from tensorflow_federated.python.core.impl.executors import executor_factory
@@ -253,12 +252,9 @@ class SizingExecutorFactory(ResourceManagingExecutorFactory):
# pylint:disable=missing-function-docstring
def _wrap_executor_in_threading_stack(ex: executor_base.Executor,
use_caching: Optional[bool] = False,
support_sequence_ops: bool = False,
can_resolve_references=True):
threaded_ex = thread_delegating_executor.ThreadDelegatingExecutor(ex)
if use_caching:
threaded_ex = caching_executor.CachingExecutor(threaded_ex)
if support_sequence_ops:
if not can_resolve_references:
raise ValueError(
@@ -281,13 +277,11 @@ class UnplacedExecutorFactory(executor_factory.ExecutorFactory):
def __init__(self,
*,
use_caching: bool,
support_sequence_ops: bool = False,
can_resolve_references: bool = True,
server_device: Optional[tf.config.LogicalDevice] = None,
client_devices: Optional[Sequence[tf.config.LogicalDevice]] = (),
leaf_executor_fn=eager_tf_executor.EagerTFExecutor):
self._use_caching = use_caching
self._support_sequence_ops = support_sequence_ops
self._can_resolve_references = can_resolve_references
self._server_device = server_device
@@ -322,7 +316,6 @@ class UnplacedExecutorFactory(executor_factory.ExecutorFactory):
leaf_ex = self._leaf_executor_fn(device=device)
return _wrap_executor_in_threading_stack(
leaf_ex,
use_caching=self._use_caching,
support_sequence_ops=self._support_sequence_ops,
can_resolve_references=self._can_resolve_references)
@@ -669,7 +662,6 @@ def local_executor_factory(
if max_fanout < 2:
raise ValueError('Max fanout must be greater than 1.')
unplaced_ex_factory = UnplacedExecutorFactory(
use_caching=False,
support_sequence_ops=support_sequence_ops,
can_resolve_references=reference_resolving_clients,
server_device=server_tf_device,
@@ -754,7 +746,6 @@ def thread_debugging_executor_factory(
"""
py_typecheck.check_type(clients_per_thread, int)
unplaced_ex_factory = UnplacedExecutorFactory(
use_caching=False,
can_resolve_references=False,
leaf_executor_fn=leaf_executor_fn)
federating_executor_factory = FederatingExecutorFactory(
@@ -804,7 +795,7 @@ def sizing_executor_factory(
if max_fanout < 2:
raise ValueError('Max fanout must be greater than 1.')
unplaced_ex_factory = UnplacedExecutorFactory(
use_caching=False, leaf_executor_fn=leaf_executor_fn)
leaf_executor_fn=leaf_executor_fn)
federating_executor_factory = FederatingExecutorFactory(
clients_per_thread=clients_per_thread,
unplaced_ex_factory=unplaced_ex_factory,
@@ -988,7 +979,7 @@ def remote_executor_factory(
num_clients = cardinalities.get(placements.CLIENTS, default_num_clients)
return _configure_remote_workers(num_clients, remote_executors)
unplaced_ex_factory = UnplacedExecutorFactory(use_caching=False)
unplaced_ex_factory = UnplacedExecutorFactory()
composing_executor_factory = ComposingExecutorFactory(
max_fanout=max_fanout,
unplaced_ex_factory=unplaced_ex_factory,
@@ -89,7 +89,6 @@ class MultiGPUTest(tf.test.TestCase, parameterized.TestCase):
server_tf_device = None if not tf_devices else tf_devices[0]
gpu_devices = tf.config.list_logical_devices('GPU')
unplaced_factory = executor_stacks.UnplacedExecutorFactory(
use_caching=True,
server_device=server_tf_device,
client_devices=gpu_devices)
unplaced_executor = unplaced_factory.create_executor()
@@ -373,22 +373,16 @@ class ExecutorStacksTest(parameterized.TestCase):
class UnplacedExecutorFactoryTest(parameterized.TestCase):
def test_constructs_executor_factory(self):
unplaced_factory = executor_stacks.UnplacedExecutorFactory(use_caching=True)
unplaced_factory = executor_stacks.UnplacedExecutorFactory()
self.assertIsInstance(unplaced_factory, executor_factory.ExecutorFactory)
def test_constructs_executor_factory_without_caching(self):
unplaced_factory_no_caching = executor_stacks.UnplacedExecutorFactory(
use_caching=False)
self.assertIsInstance(unplaced_factory_no_caching,
executor_factory.ExecutorFactory)
def test_create_executor_returns_executor(self):
unplaced_factory = executor_stacks.UnplacedExecutorFactory(use_caching=True)
unplaced_factory = executor_stacks.UnplacedExecutorFactory()
unplaced_executor = unplaced_factory.create_executor(cardinalities={})
self.assertIsInstance(unplaced_executor, executor_base.Executor)
def test_create_executor_raises_with_nonempty_cardinalitites(self):
unplaced_factory = executor_stacks.UnplacedExecutorFactory(use_caching=True)
unplaced_factory = executor_stacks.UnplacedExecutorFactory()
with self.assertRaises(ValueError):
unplaced_factory.create_executor(cardinalities={placements.SERVER: 1})
@@ -399,7 +393,7 @@ class UnplacedExecutorFactoryTest(parameterized.TestCase):
tf_devices = tf.config.list_logical_devices(tf_device)
server_tf_device = None if not tf_devices else tf_devices[0]
unplaced_factory = executor_stacks.UnplacedExecutorFactory(
use_caching=False, server_device=server_tf_device)
server_device=server_tf_device)
unplaced_executor = unplaced_factory.create_executor()
self.assertIsInstance(unplaced_executor, executor_base.Executor)
@@ -409,7 +403,7 @@ class UnplacedExecutorFactoryTest(parameterized.TestCase):
def test_create_executor_with_client_devices(self, tf_device):
tf_devices = tf.config.list_logical_devices(tf_device)