caching_executor.py 13.9 KB
Newer Older
Krzysztof Ostrowski's avatar
Krzysztof Ostrowski committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An executor that caches and reuses values on repeated calls."""

16
import asyncio
Krzysztof Ostrowski's avatar
Krzysztof Ostrowski committed
17
import collections
18
import cachetools
19

Krzysztof Ostrowski's avatar
Krzysztof Ostrowski committed
20
import numpy as np
21
import tensorflow as tf
Krzysztof Ostrowski's avatar
Krzysztof Ostrowski committed
22
23
24

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import py_typecheck
25
from tensorflow_federated.python.common_libs import structure
Krzysztof Ostrowski's avatar
Krzysztof Ostrowski committed
26
27
28
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.impl import computation_impl
from tensorflow_federated.python.core.impl import type_utils
29
30
from tensorflow_federated.python.core.impl.executors import executor_base
from tensorflow_federated.python.core.impl.executors import executor_value_base
Krzysztof Ostrowski's avatar
Krzysztof Ostrowski committed
31
32


33
class HashableWrapper(collections.abc.Hashable):
  """A wrapper around non-hashable objects to be compared by identity.

  Two wrappers are equal, and hash equally, exactly when they wrap the very
  same object (as determined by `is`). The wrapper keeps a reference to its
  target, so the target's `id()` cannot be recycled while the wrapper lives.
  """

  def __init__(self, target):
    self._target = target

  def __hash__(self):
    return hash(id(self._target))

  def __eq__(self, other):
    # `isinstance` (not `other is HashableWrapper`, which compares against the
    # class object itself) so that two wrappers of the same target are equal
    # and identity-keyed cache lookups can actually hit.
    if not isinstance(other, HashableWrapper):
      return NotImplemented
    return other._target is self._target  # pylint: disable=protected-access


def _get_hashable_key(value, type_spec):
  """Computes a hashable cache key for `value` of TFF type `type_spec`.

  Structures are keyed element-wise (as a `structure.Struct` of keys), federated
  values are keyed by their member(s), serialized computations by their
  deterministic serialization, and numpy arrays by dtype/shape plus raw bytes.
  Other hashable values (excluding `tf.Tensor`/`tf.Variable`) are used as-is;
  anything else is wrapped in a `HashableWrapper`, i.e. keyed by object
  identity, so only the very same object maps to the same key.

  Args:
    value: An argument to `create_value()`.
    type_spec: The TFF type of `value`. Must not be `None`: its `is_struct()`
      method is called unconditionally.

  Returns:
    A hashable key to use such that the same `value` always maps to the same
    key, and different ones map to different keys.

  Raises:
    TypeError: If a value whose type is a struct cannot be converted to a
      `structure.Struct`.
  """
  if type_spec.is_struct():
    if not isinstance(value, structure.Struct):
      try:
        value = structure.from_container(value)
      except Exception as e:
        raise TypeError(
            'Failed to convert value with type_spec {} to `Struct`'.format(
                repr(type_spec))) from e
    element_types = structure.iter_elements(type_spec)
    # NOTE(review): `zip` stops at the shorter of the two, so a value with
    # fewer/more elements than `type_spec` is keyed on the common prefix only.
    return structure.Struct([
        (field_name, _get_hashable_key(element, field_type))
        for element, (field_name, field_type) in zip(value, element_types)
    ])

  if type_spec.is_federated():
    member_type = type_spec.member
    if type_spec.all_equal:
      return _get_hashable_key(value, member_type)
    return tuple(_get_hashable_key(member, member_type) for member in value)

  if isinstance(value, pb.Computation):
    return value.SerializeToString(deterministic=True)

  if isinstance(value, np.ndarray):
    descriptor = '<dtype={},shape={}>'.format(value.dtype, value.shape)
    return (descriptor, value.tobytes())

  # TODO(b/139200385): Currently Tensor and Variable returns True for
  # `isinstance(value, collections.abc.Hashable)` even when it's not hashable.
  # Hence this workaround.
  if (isinstance(value, collections.abc.Hashable) and
      not isinstance(value, (tf.Tensor, tf.Variable))):
    return value

  return HashableWrapper(value)


93
class CachedValueIdentifier(collections.abc.Hashable):
  """A string-backed, hashable identifier for a value held in the cache.

  Two identifiers are equal exactly when both are `CachedValueIdentifier`
  instances wrapping equal strings; the hash is that of the wrapped string.
  """

  def __init__(self, identifier):
    py_typecheck.check_type(identifier, str)
    self._identifier = identifier

  def __hash__(self):
    return hash(self._identifier)

  def __eq__(self, other):
    if not isinstance(other, CachedValueIdentifier):
      return False
    return self._identifier == other._identifier  # pylint: disable=protected-access

  def __repr__(self):
    return 'CachedValueIdentifier({!r})'.format(self._identifier)

  def __str__(self):
    return self._identifier


class CachedValue(executor_value_base.ExecutorValue):
  """A value held by the caching executor.

  Pairs a future for the value embedded in the target executor with the
  identifier and (optional) source-value key under which the caching executor
  tracks it.
  """

  def __init__(self, identifier, hashable_key, type_spec, target_future):
    """Creates a cached value.

    Args:
      identifier: An instance of `CachedValueIdentifier`.
      hashable_key: A hashable source value key, if any, or `None` if not
        applicable in this context.
      type_spec: The type signature of the target, an instance of `tff.Type`.
      target_future: An asyncio future that returns an instance of
        `executor_value_base.ExecutorValue` that represents a value embedded in
        the target executor.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(identifier, CachedValueIdentifier)
    py_typecheck.check_type(hashable_key, collections.abc.Hashable)
    py_typecheck.check_type(type_spec, computation_types.Type)
    if not asyncio.isfuture(target_future):
      future_type_name = py_typecheck.type_string(type(target_future))
      raise TypeError('Expected an asyncio future, got {}'.format(
          future_type_name))
    self._identifier = identifier
    self._hashable_key = hashable_key
    self._type_spec = type_spec
    self._target_future = target_future
    # Memoized output of `compute()`; `None` means "not computed yet".
    self._computed_result = None

  @property
  def identifier(self):
    return self._identifier

  @property
  def hashable_key(self):
    return self._hashable_key

  @property
  def type_signature(self):
    return self._type_spec

  @property
  def target_future(self):
    return self._target_future

  async def compute(self):
    """Returns the computed target value, memoizing it after the first call.

    A computed result of `None` is indistinguishable from "not computed yet",
    so it is recomputed on every call.
    """
    if self._computed_result is not None:
      return self._computed_result
    embedded_value = await self._target_future
    self._computed_result = await embedded_value.compute()
    return self._computed_result


# Maximum size of the `cachetools.LRUCache` that `CachingExecutor` constructs
# when no explicit `cache` is supplied.
_DEFAULT_CACHE_SIZE = 1_000


Krzysztof Ostrowski's avatar
Krzysztof Ostrowski committed
172
173
174
class CachingExecutor(executor_base.Executor):
  """An executor that caches values and reuses them on repeated calls.

  All work is delegated to `target_executor`; this executor only performs
  caching. Every value it returns is a `CachedValue` stored in the cache under a
  `CachedValueIdentifier`:

  * `create_value` assigns a fresh numeric identifier and additionally maps a
    hashable key derived from the payload (see `_get_hashable_key`) to that
    identifier, so creating the same payload again reuses the embedded value.
  * `create_call`, `create_struct` and `create_selection` derive identifiers
    from the identifiers of their operands (e.g. `'3(5)'`, `'<a=1,2>'`,
    `'4[0]'`, `'4.x'`), so repeating the same operation on the same cached
    operands reuses the earlier result.

  If any operation in the target executor raises, the whole cache is cleared.
  """

  # TODO(b/134543154): Factor out default cache settings to supply elsewhere,
  # possibly as a part of executor stack configuration.

  # TODO(b/134543154): It might be desirable to still keep around things that
  # are currently in use (referenced) regardless of what's in the cache. This
  # can be added later on.

  def __init__(self, target_executor, cache=None):
    """Creates a new instance of this executor.

    Args:
      target_executor: An instance of `executor_base.Executor`.
      cache: The cache to use (must be an instance of `cachetools.Cache`). If
        unspecified, by default we construct a 1000-element LRU cache.

    Raises:
      TypeError: If `target_executor` or `cache` is of the wrong type.
    """
    py_typecheck.check_type(target_executor, executor_base.Executor)
    if cache is not None:
      py_typecheck.check_type(cache, cachetools.Cache)
    else:
      cache = cachetools.LRUCache(_DEFAULT_CACHE_SIZE)
    self._target_executor = target_executor
    self._cache = cache
    # Counter used to mint unique identifiers for values from `create_value`.
    self._num_values_created = 0

  def close(self):
    """Clears the cache and closes the target executor."""
    self._cache.clear()
    self._target_executor.close()

  def __del__(self):
    # `__init__` may have raised before `_cache` was assigned; avoid a spurious
    # AttributeError during garbage collection in that case.
    cache = getattr(self, '_cache', None)
    if cache is not None:
      cache.clear()

  async def _await_target(self, cached_value):
    """Awaits `cached_value.target_future`, invalidating the cache on failure.

    Args:
      cached_value: The `CachedValue` whose target future to await.

    Returns:
      The value embedded in the target executor.

    Raises:
      Whatever the target future raised, after the cache has been cleared.
    """
    try:
      return await cached_value.target_future
    except Exception:
      # Invalidate the entire cache if the inner executor raised.
      # TODO(b/145514490): This is a bit heavy handed, there maybe caches where
      # only the current cache item needs to be invalidated; however this
      # currently only occurs when an inner RemoteExecutor has the backend go
      # down.
      # Clear in place (as `close()` does) rather than rebinding to `{}`:
      # rebinding would replace the bounded `cachetools.Cache` with an
      # unbounded dict, after which entries would never be evicted.
      self._cache.clear()
      raise

  def _get_or_create_derived_value(self, identifier, type_spec, create_target):
    """Returns the cached value for `identifier`, creating it on a miss.

    Args:
      identifier: The `CachedValueIdentifier` to look up.
      type_spec: The TFF type recorded on a newly created `CachedValue`.
      create_target: A zero-argument callable returning a coroutine that embeds
        the value in the target executor; invoked only on a cache miss.

    Returns:
      The `CachedValue` registered under `identifier`.
    """
    cached_value = self._cache.get(identifier)
    if cached_value is None:
      target_future = asyncio.ensure_future(create_target())
      cached_value = CachedValue(identifier, None, type_spec, target_future)
      self._cache[identifier] = cached_value
    return cached_value

  async def create_value(self, value, type_spec=None):
    """Embeds `value` in the target executor, reusing a cached copy if any.

    Args:
      value: The payload to embed. A `computation_impl.ComputationImpl` is
        replaced by its proto before caching.
      type_spec: The TFF type of `value`, or anything convertible by
        `computation_types.to_type`.

    Returns:
      A `CachedValue` for the embedded value.

    Raises:
      TypeError: If `type_spec` does not resolve to a `computation_types.Type`.
      RuntimeError: If the cache lookup with the derived key fails with a
        `TypeError`.
    """
    type_spec = computation_types.to_type(type_spec)
    if isinstance(value, computation_impl.ComputationImpl):
      return await self.create_value(
          computation_impl.ComputationImpl.get_proto(value),
          type_utils.reconcile_value_with_type_spec(value, type_spec))
    py_typecheck.check_type(type_spec, computation_types.Type)
    hashable_key = _get_hashable_key(value, type_spec)
    try:
      identifier = self._cache.get(hashable_key)
    except TypeError as err:
      raise RuntimeError(
          'Failed to perform a hash table lookup with a value of Python '
          'type {} and TFF type {}, and payload {}: {}'.format(
              py_typecheck.type_string(type(value)), type_spec, value,
              err)) from err
    cached_value = None
    if isinstance(identifier, CachedValueIdentifier):
      cached_value = self._cache.get(identifier)
      # It may be that the same payload appeared with a mismatching type spec,
      # which may be a legitimate use case if (as it happens) the payload alone
      # does not uniquely determine the type, so we simply opt not to reuse the
      # cache value and fallback on the regular behavior.
      if (cached_value is not None and
          not cached_value.type_signature.is_equivalent_to(type_spec)):
        cached_value = None
    if cached_value is None:
      # A miss, a type mismatch, or the cache evicted the identifier -> value
      # entry while keeping the key -> identifier entry (the two are evicted
      # independently). In every case, embed the value afresh.
      self._num_values_created += 1
      identifier = CachedValueIdentifier(str(self._num_values_created))
      self._cache[hashable_key] = identifier
      target_future = asyncio.ensure_future(
          self._target_executor.create_value(value, type_spec))
      cached_value = CachedValue(identifier, hashable_key, type_spec,
                                 target_future)
      self._cache[identifier] = cached_value
    await self._await_target(cached_value)
    # No type check is necessary here; we have either checked
    # `is_equivalent_to` or just constructed the target value explicitly with
    # `type_spec`.
    return cached_value

  async def create_call(self, comp, arg=None):
    """Calls cached function `comp` on optional cached `arg`.

    Args:
      comp: A `CachedValue` of a functional TFF type.
      arg: An optional `CachedValue` assignable to the function's parameter.

    Returns:
      A `CachedValue` for the call's result.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(comp, CachedValue)
    py_typecheck.check_type(comp.type_signature, computation_types.FunctionType)
    to_gather = [comp.target_future]
    if arg is not None:
      py_typecheck.check_type(arg, CachedValue)
      comp.type_signature.parameter.check_assignable_from(arg.type_signature)
      to_gather.append(arg.target_future)
      identifier_str = '{}({})'.format(comp.identifier, arg.identifier)
    else:
      identifier_str = '{}()'.format(comp.identifier)
    gathered = await asyncio.gather(*to_gather)
    type_spec = comp.type_signature.result
    cached_value = self._get_or_create_derived_value(
        CachedValueIdentifier(identifier_str), type_spec,
        lambda: self._target_executor.create_call(*gathered))
    target_value = await self._await_target(cached_value)
    type_spec.check_assignable_from(target_value.type_signature)
    return cached_value

  async def create_struct(self, elements):
    """Builds a struct out of cached values.

    Args:
      elements: A `structure.Struct` (or container convertible to one) whose
        elements are `CachedValue`s; element names, if any, must be strings.

    Returns:
      A `CachedValue` for the struct.

    Raises:
      TypeError: If an element or name is of the wrong type.
    """
    if not isinstance(elements, structure.Struct):
      elements = structure.from_container(elements)
    element_kv_pairs = structure.to_elements(elements)
    element_strings = []
    type_elements = []
    to_gather = []
    for k, v in element_kv_pairs:
      py_typecheck.check_type(v, CachedValue)
      to_gather.append(v.target_future)
      if k is not None:
        py_typecheck.check_type(k, str)
        element_strings.append('{}={}'.format(k, v.identifier))
        type_elements.append((k, v.type_signature))
      else:
        element_strings.append(str(v.identifier))
        type_elements.append(v.type_signature)
    type_spec = computation_types.StructType(type_elements)
    gathered = await asyncio.gather(*to_gather)
    identifier = CachedValueIdentifier('<{}>'.format(','.join(element_strings)))

    def create_target():
      return self._target_executor.create_struct(
          structure.Struct(
              (name, target_val)
              for (name, _), target_val in zip(element_kv_pairs, gathered)))

    cached_value = self._get_or_create_derived_value(identifier, type_spec,
                                                     create_target)
    target_value = await self._await_target(cached_value)
    type_spec.check_assignable_from(target_value.type_signature)
    return cached_value

  async def create_selection(self, source, index=None, name=None):
    """Selects an element of cached struct `source` by `index` or `name`.

    Args:
      source: A `CachedValue` of a struct TFF type.
      index: The element index; mutually exclusive with `name`.
      name: The element name; required when `index` is `None`.

    Returns:
      A `CachedValue` for the selected element.

    Raises:
      TypeError: If the arguments are of the wrong types, or both or neither
        of `index` and `name` are given.
    """
    py_typecheck.check_type(source, CachedValue)
    py_typecheck.check_type(source.type_signature, computation_types.StructType)
    source_val = await source.target_future
    if index is not None:
      py_typecheck.check_none(name)
      identifier_str = '{}[{}]'.format(source.identifier, index)
      type_spec = source.type_signature[index]
    else:
      py_typecheck.check_not_none(name)
      identifier_str = '{}.{}'.format(source.identifier, name)
      type_spec = getattr(source.type_signature, name)
    cached_value = self._get_or_create_derived_value(
        CachedValueIdentifier(identifier_str), type_spec,
        lambda: self._target_executor.create_selection(
            source_val, index=index, name=name))
    target_value = await self._await_target(cached_value)
    type_spec.check_assignable_from(target_value.type_signature)
    return cached_value