value_impl.py 19.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright 2018, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of the abstract interface Value in api/value_base."""

import abc
17
import collections
18
from typing import Any, Union
19

20
import attr
21
import tensorflow as tf
22

23
from tensorflow_federated.proto.v0 import computation_pb2 as pb
24
from tensorflow_federated.python.common_libs import py_typecheck
25
from tensorflow_federated.python.common_libs import structure
26
from tensorflow_federated.python.core.api import computation_base
27
from tensorflow_federated.python.core.api import computation_types
28
from tensorflow_federated.python.core.api import value_base
29
from tensorflow_federated.python.core.impl.compiler import building_block_factory
30
from tensorflow_federated.python.core.impl.compiler import building_blocks
31
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
32
from tensorflow_federated.python.core.impl.compiler import tensorflow_computation_factory
33
from tensorflow_federated.python.core.impl.context_stack import context_base
34
from tensorflow_federated.python.core.impl.context_stack import context_stack_base
35
from tensorflow_federated.python.core.impl.context_stack import symbol_binding_context
36
from tensorflow_federated.python.core.impl.types import placement_literals
37
from tensorflow_federated.python.core.impl.types import type_conversions
38
from tensorflow_federated.python.core.impl.utils import function_utils
39
from tensorflow_federated.python.core.impl.utils import tensorflow_utils
40
41


42
43
44
45
46
47
def _unfederated(type_signature):
  """Strips one level of federation from a type, if present.

  Args:
    type_signature: A TFF type.

  Returns:
    The `member` type when `type_signature` is federated; otherwise
    `type_signature` itself, unchanged.
  """
  if not type_signature.is_federated():
    return type_signature
  return type_signature.member


48
# Note: not a `ValueImpl` method because of the `__setattr__` override
def _is_federated_named_tuple(vimpl: 'ValueImpl') -> bool:
  """Returns whether `vimpl` is a federated value whose member is a struct."""
  type_spec = vimpl.type_signature
  if not type_spec.is_federated():
    return False
  return type_spec.member.is_struct()
52
53


54
# Note: not a `ValueImpl` method because of the `__setattr__` override
def _is_named_tuple(vimpl: 'ValueImpl') -> bool:
  """Returns whether `vimpl` itself (not a federated member) is a struct."""
  type_spec = vimpl.type_signature
  return type_spec.is_struct()
57
58


59
def _check_struct_or_federated_struct(
    vimpl: 'ValueImpl',
    attribute: str,
):
  """Ensures `vimpl` supports named-field access.

  Args:
    vimpl: The value whose type is inspected.
    attribute: The attribute name being accessed; used only in the error.

  Raises:
    AttributeError: If `vimpl` is neither a struct nor a federated struct.
  """
  is_structural = _is_named_tuple(vimpl) or _is_federated_named_tuple(vimpl)
  if is_structural:
    return
  message = (
      f'`tff.Value` of non-structural type {vimpl.type_signature} has no '
      f'attribute {attribute}')
  raise AttributeError(message)
67
68


69
70
71
72
73
74
75
76
77
78
79
def _check_symbol_binding_context(context: context_base.Context):
  """Verifies that `context` is able to bind computations to references.

  Args:
    context: The context in which a TFF value is about to be materialized.

  Raises:
    context_base.ContextError: If `context` is not a `SymbolBindingContext`.
  """
  if isinstance(context, symbol_binding_context.SymbolBindingContext):
    return
  raise context_base.ContextError(
      'TFF values should only be materialized inside a context which can bind '
      'references, generally a `FederatedComputationContext`. Attempted to '
      'materialize a TFF value in a context {c} of type {t}.'.format(
          c=context, t=type(context)))


80
class ValueImpl(value_base.Value, metaclass=abc.ABCMeta):
  """A generic base class for values that appear in TFF computations.

  If the value in this class is of `StructType` or `FederatedType`
  containing a `StructType`, the inner fields can be accessed by name
  (e.g. `my_value_impl.x = ...` or `y = my_value_impl.y`).

  Note that setting nested fields (e.g. `my_value_impl.x.y = ...`) will not
  work properly because it translates to
  `my_value_impl.__getattr__('x').__setattr__('y')`, but the object returned
  by `__getattr__` cannot proxy writes back to the original `ValueImpl`.
  """

  def __init__(
      self,
      comp: building_blocks.ComputationBuildingBlock,
      context_stack: context_stack_base.ContextStack,
  ):
    """Constructs a value of the given type.

    Args:
      comp: An instance of building_blocks.ComputationBuildingBlock that
        contains the logic that computes this value.
      context_stack: The context stack to use.

    Raises:
      context_base.ContextError: If the current context of `context_stack`
        is not a context that can bind references.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    _check_symbol_binding_context(context_stack.current)
    # We override `__setattr__` for `ValueImpl` and so must assign fields using
    # the `__setattr__` impl on the superclass (rather than simply using
    # e.g. `self._comp = comp`).
    super().__setattr__('_comp', comp)
    super().__setattr__('_context_stack', context_stack)

  @property
  def type_signature(self):
    """The TFF type of the wrapped building block."""
    return self._comp.type_signature

  @classmethod
  def get_comp(cls, value):
    """Returns the building block wrapped by `value`, an instance of `cls`."""
    py_typecheck.check_type(value, cls)
    return value._comp  # pylint: disable=protected-access

  @classmethod
  def get_context_stack(cls, value):
    """Returns the context stack held by `value`, an instance of `cls`."""
    py_typecheck.check_type(value, cls)
    return value._context_stack  # pylint: disable=protected-access

  def __repr__(self):
    return repr(self._comp)

  def __str__(self):
    return str(self._comp)

  def __dir__(self):
    """Lists `type_signature` plus `dir()` of the (unfederated) struct type."""
    attributes = ['type_signature']
    type_signature = _unfederated(self.type_signature)
    if type_signature.is_struct():
      attributes.extend(dir(type_signature))
    return attributes

  def __getattr__(self, name):
    """Selects the named field of a (possibly federated) struct value.

    Args:
      name: The field name to select.

    Returns:
      A `ValueImpl` for the selected field. For federated structs this wraps
      a federated getattr call; for structs it wraps either the struct's own
      element or a `Selection`.

    Raises:
      AttributeError: If the value is not a (federated) struct, if `name` is
        not one of its fields, or if the instance has not been initialized.
    """
    py_typecheck.check_type(name, str)
    # `__getattr__` only runs after normal lookup fails. If `_comp` itself is
    # missing (an instance created via `__new__` without `__init__`, as `copy`
    # and `pickle` reconstruction do before probing hooks such as
    # `__setstate__`), every path below reads `self.type_signature`, whose
    # `self._comp` lookup would re-enter `__getattr__` and recurse until a
    # `RecursionError`. Report a plain `AttributeError` instead.
    try:
      object.__getattribute__(self, '_comp')
    except AttributeError:
      raise AttributeError(name) from None
    _check_struct_or_federated_struct(self, name)
    if _is_federated_named_tuple(self):
      if name not in structure.name_list(self.type_signature.member):
        raise AttributeError(
            'There is no such attribute \'{}\' in this federated tuple. Valid '
            'attributes: ({})'.format(
                name, ', '.join(dir(self.type_signature.member))))
      return ValueImpl(
          building_block_factory.create_federated_getattr_call(
              self._comp, name), self._context_stack)
    if name not in dir(self.type_signature):
      raise AttributeError(
          'There is no such attribute \'{}\' in this tuple. Valid attributes: ({})'
          .format(name, ', '.join(dir(self.type_signature))))
    if self._comp.is_struct():
      # A literal struct exposes its elements directly; no selection needed.
      return ValueImpl(getattr(self._comp, name), self._context_stack)
    return ValueImpl(
        building_blocks.Selection(self._comp, name=name), self._context_stack)

  def __setattr__(self, name, value):
    """Replaces the named field of a (possibly federated) struct value.

    `value` is converted with `to_value`, and the wrapped building block is
    swapped for one that computes the updated struct. For federated structs
    that is a federated setattr call. For plain structs it is a call to a
    setattr lambda, bound to a reference in the current context.

    Raises:
      AttributeError: If the value is not a (federated) struct.
    """
    py_typecheck.check_type(name, str)
    _check_struct_or_federated_struct(self, name)
    value_comp = ValueImpl.get_comp(to_value(value, None, self._context_stack))
    if _is_federated_named_tuple(self):
      new_comp = building_block_factory.create_federated_setattr_call(
          self._comp, name, value_comp)
      super().__setattr__('_comp', new_comp)
      return
    setattr_lambda = building_block_factory.create_named_tuple_setattr_lambda(
        self.type_signature, name, value_comp)
    new_comp = building_blocks.Call(setattr_lambda, self._comp)
    fc_context = self._context_stack.current
    ref = fc_context.bind_computation_to_reference(new_comp)
    super().__setattr__('_comp', ref)

  def __bool__(self):
    # Truthiness of a symbolic value is not known at tracing time.
    raise TypeError(
        'Federated computation values do not support boolean operations. '
        'If you were attempting to perform logic on tensors, consider moving '
        'this logic into a tff.tf_computation.')

  def __len__(self):
    """Returns the number of elements of a (possibly federated) struct.

    Raises:
      TypeError: If the (unfederated) type is not a struct.
    """
    type_signature = _unfederated(self.type_signature)
    if not type_signature.is_struct():
      raise TypeError(
          'Operator len() is only supported for (possibly federated) structure '
          'types, but the object on which it has been invoked is of type {}.'
          .format(self.type_signature))
    return len(type_signature)

  def __getitem__(self, key: Union[int, str, slice]):
    """Selects element(s) of a (possibly federated) struct value.

    Args:
      key: A field name (delegated to `__getattr__`), a non-negative index,
        or a slice.

    Returns:
      A `ValueImpl` for the selected element, or for a struct of the sliced
      elements.

    Raises:
      TypeError: If the value is not a (federated) struct.
      IndexError: If an integer index is out of range, or a slice selects no
        elements.
    """
    py_typecheck.check_type(key, (int, str, slice))
    if isinstance(key, str):
      return getattr(self, key)
    if _is_federated_named_tuple(self):
      return ValueImpl(
          building_block_factory.create_federated_getitem_call(self._comp, key),
          self._context_stack)
    if not _is_named_tuple(self):
      raise TypeError(
          'Operator getitem() is only supported for structure types, but the '
          'object on which it has been invoked is of type {}.'.format(
              self.type_signature))
    elem_length = len(self.type_signature)
    if isinstance(key, int):
      if key < 0 or key >= elem_length:
        raise IndexError(
            'The index of the selected element {} is out of range.'.format(key))
      if self._comp.is_struct():
        return ValueImpl(self._comp[key], self._context_stack)
      else:
        return ValueImpl(
            building_blocks.Selection(self._comp, index=key),
            self._context_stack)
    elif isinstance(key, slice):
      index_range = range(*key.indices(elem_length))
      if not index_range:
        raise IndexError('Attempted to slice 0 elements, which is not '
                         'currently supported.')
      return to_value([self[k] for k in index_range], None, self._context_stack)

  def __iter__(self):
    """Yields each element of a (possibly federated) struct, by index.

    Raises:
      TypeError: If the (unfederated) type is not a struct.
    """
    type_signature = _unfederated(self.type_signature)
    if not type_signature.is_struct():
      raise TypeError(
          'Operator iter() is only supported for (possibly federated) structure '
          'types, but the object on which it has been invoked is of type {}.'
          .format(self.type_signature))
    for index in range(len(type_signature)):
      yield self[index]

  def __call__(self, *args, **kwargs):
    """Invokes a functional value on the given arguments.

    Arguments are converted with `to_value` and packed to match the function's
    parameter type. The resulting call is bound to a reference in the current
    context.

    Raises:
      SyntaxError: If the value is not of a function type.
    """
    if not self.type_signature.is_function():
      raise SyntaxError(
          'Function-like invocation is only supported for values of functional '
          'types, but the value being invoked is of type {} that does not '
          'support invocation.'.format(self.type_signature))
    if args or kwargs:
      args = [to_value(x, None, self._context_stack) for x in args]
      kwargs = {
          k: to_value(v, None, self._context_stack) for k, v in kwargs.items()
      }
      arg = function_utils.pack_args(self.type_signature.parameter, args,
                                     kwargs, self._context_stack.current)
      arg = ValueImpl.get_comp(to_value(arg, None, self._context_stack))
    else:
      arg = None
    fc_context = self._context_stack.current
    call = building_blocks.Call(self._comp, arg)
    ref = fc_context.bind_computation_to_reference(call)
    return ValueImpl(ref, self._context_stack)

  def __add__(self, other):
    """Adds two values of equivalent type via the `GENERIC_PLUS` intrinsic.

    Raises:
      TypeError: If `other`'s type is not equivalent to this value's type.
    """
    other = to_value(other, None, self._context_stack)
    if not self.type_signature.is_equivalent_to(other.type_signature):
      raise TypeError('Cannot add {} and {}.'.format(self.type_signature,
                                                     other.type_signature))
    call = building_blocks.Call(
        building_blocks.Intrinsic(
            intrinsic_defs.GENERIC_PLUS.uri,
            computation_types.FunctionType(
                [self.type_signature, self.type_signature],
                self.type_signature)),
        ValueImpl.get_comp(to_value([self, other], None, self._context_stack)))
    fc_context = self._context_stack.current
    ref = fc_context.bind_computation_to_reference(call)
    return ValueImpl(ref, self._context_stack)
271

272

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def _wrap_computation_as_value(
    proto: pb.Computation,
    context_stack: context_stack_base.ContextStack) -> value_base.Value:
  """Wraps a serialized computation as a `tff.Value`.

  The computation is invoked with no argument, and that invocation is bound to
  a reference in the current context of `context_stack`.

  Args:
    proto: A pb.Computation to invoke.
    context_stack: The context stack to use.

  Returns:
    A `ValueImpl` wrapping the bound reference.
  """
  py_typecheck.check_type(proto, pb.Computation)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  invocation = building_blocks.Call(building_blocks.CompiledComputation(proto))
  bound_ref = context_stack.current.bind_computation_to_reference(invocation)
  return ValueImpl(bound_ref, context_stack)


294
def _wrap_constant_as_value(const, context_stack):
  """Wraps a Python constant as a `tff.Value`.

  Args:
    const: Python constant to be converted to TFF value. Anything convertible to
      Tensor via `tf.constant` can be passed in.
    context_stack: The context stack to use.

  Returns:
    An instance of `value_base.Value` produced by a parameterless TensorFlow
    computation that returns `tf.constant(const)`.
  """

  def _make_constant():
    return tf.constant(const)

  tf_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      fn=_make_constant, parameter_type=None)
  return _wrap_computation_as_value(tf_proto, context_stack)
308
309


310
def _wrap_sequence_as_value(elements, element_type, context_stack):
  """Wraps `elements` as a TFF sequence with elements of type `element_type`.

  Args:
    elements: A Python list of objects to be wrapped as a TFF sequence value.
    element_type: An instance of `Type` that determines the type of elements of
      the sequence.
    context_stack: The context stack to use.

  Returns:
    An instance of `tff.Value`.

  Raises:
    TypeError: If `elements` and `element_type` are of incompatible types.
  """
  # TODO(b/113116813): Add support for other representations of sequences.
  py_typecheck.check_type(elements, list)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  # Types are inferred lazily so the first mismatch is reported immediately.
  inferred_types = (type_conversions.infer_type(e) for e in elements)
  for inferred_type in inferred_types:
    if element_type.is_assignable_from(inferred_type):
      continue
    raise TypeError(
        'Expected all sequence elements to be {}, found {}.'.format(
            element_type, inferred_type))

  def _build_dataset():
    # Looked up at trace time so the dataset lands in the graph being built.
    default_graph = tf.compat.v1.get_default_graph()
    return tensorflow_utils.make_data_set_from_elements(
        default_graph, elements, element_type)

  dataset_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      fn=_build_dataset, parameter_type=None)
  return _wrap_computation_as_value(dataset_proto, context_stack)
342
343


344
def _dictlike_items_to_value(items, context_stack, container_type) -> ValueImpl:
  """Builds a named struct value from `(key, value)` pairs.

  Args:
    items: An iterable of `(name, value)` pairs; each value is converted with
      `to_value`.
    context_stack: The context stack to use.
    container_type: The Python type passed to `building_blocks.Struct` as its
      container type.

  Returns:
    A `ValueImpl` wrapping a `building_blocks.Struct` of the converted values.
  """
  named_comps = []
  for key, element in items:
    converted = to_value(element, None, context_stack)
    named_comps.append((key, ValueImpl.get_comp(converted)))
  struct = building_blocks.Struct(named_comps, container_type)
  return ValueImpl(struct, context_stack)


Taylor Cramer's avatar
Taylor Cramer committed
351
352
353
354
def to_value(
    arg: Any,
    type_spec,
    context_stack: context_stack_base.ContextStack,
    parameter_type_hint=None,
) -> ValueImpl:
  """Converts the argument into an instance of `tff.Value`.

  The following non-`tff.Value` arguments are currently convertible:

  * Building blocks, wrapped directly.
  * Lists, tuples, `structure.Struct`s, named tuples, attrs classes, and
    dictionaries, all of which become struct values backed by
    `building_blocks.Struct`. Plain (non-ordered) dictionaries are converted
    in sorted key order.
  * Placement literals, converted into `building_blocks.Placement` values.
  * Computations, and polymorphic functions when `parameter_type_hint` is
    given.
  * Python lists, when `type_spec` is a sequence type.
  * Python constants of type `str`, `int`, `float`, `bool`
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
    of numpy scalar types)

  Args:
    arg: Either an instance of `tff.Value`, or an argument convertible to
      `tff.Value`. The argument must not be `None`.
    type_spec: An optional `computation_types.Type` or value convertible to it
      by `computation_types.to_type` which specifies the desired type signature
      of the resulting value. This allows for disambiguating the target type
      (e.g., when two TFF types can be mapped to the same Python
      representations), or `None` if none available, in which case TFF tries to
      determine the type of the TFF value automatically.
    context_stack: The context stack to use.
    parameter_type_hint: An optional `computation_types.Type` or value
      convertible to it by `computation_types.to_type` which specifies an
      argument type to use in the case that `arg` is a
      `function_utils.PolymorphicFunction`.

  Returns:
    An instance of `tff.Value` corresponding to the given `arg`, and of TFF type
    matching the `type_spec` if specified (not `None`).

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. Raises explicit error message if TensorFlow constructs
      are encountered, as TensorFlow code should be sealed away from TFF
      federated context.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  _check_symbol_binding_context(context_stack.current)
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)

  def _comp_of(element):
    # Recursively converts a nested element and unwraps its building block.
    return ValueImpl.get_comp(to_value(element, None, context_stack))

  # NOTE: branch order matters (e.g. named tuples must be caught before the
  # generic tuple branch, and a sequence `type_spec` takes precedence over the
  # container branches).
  if isinstance(arg, ValueImpl):
    result = arg
  elif isinstance(arg, building_blocks.ComputationBuildingBlock):
    result = ValueImpl(arg, context_stack)
  elif isinstance(arg, placement_literals.PlacementLiteral):
    result = ValueImpl(building_blocks.Placement(arg), context_stack)
  elif isinstance(
      arg, (computation_base.Computation, function_utils.PolymorphicFunction)):
    if isinstance(arg, function_utils.PolymorphicFunction):
      if parameter_type_hint is None:
        raise TypeError(
            'Polymorphic computations cannot be converted to TFF values '
            'without a type hint. Consider explicitly specifying the '
            'argument types of a computation before passing it to a '
            'function that requires a TFF value (such as a TFF intrinsic '
            'like `federated_map`). If you are a TFF developer and think '
            'this should be supported, consider providing `parameter_type_hint` '
            'as an argument to the encompassing `to_value` conversion.')
      hint_type = computation_types.to_type(parameter_type_hint)
      arg = arg.fn_for_argument_type(hint_type)
    py_typecheck.check_type(arg, computation_base.Computation)
    result = ValueImpl(arg.to_compiled_building_block(), context_stack)
  elif type_spec is not None and type_spec.is_sequence():
    result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
  elif isinstance(arg, structure.Struct):
    struct_elements = [
        (name, _comp_of(element))
        for name, element in structure.iter_elements(arg)
    ]
    result = ValueImpl(building_blocks.Struct(struct_elements), context_stack)
  elif py_typecheck.is_named_tuple(arg):
    result = _dictlike_items_to_value(arg._asdict().items(), context_stack,
                                      type(arg))
  elif py_typecheck.is_attrs(arg):
    attr_fields = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False)
    result = _dictlike_items_to_value(attr_fields.items(), context_stack,
                                      type(arg))
  elif isinstance(arg, dict):
    # Unordered dicts are sorted so that the resulting field order is stable.
    is_ordered = isinstance(arg, collections.OrderedDict)
    dict_items = arg.items() if is_ordered else sorted(arg.items())
    result = _dictlike_items_to_value(dict_items, context_stack, type(arg))
  elif isinstance(arg, (tuple, list)):
    element_comps = [_comp_of(element) for element in arg]
    result = ValueImpl(
        building_blocks.Struct(element_comps, type(arg)), context_stack)
  elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
    result = _wrap_constant_as_value(arg, context_stack)
  elif isinstance(arg, (tf.Tensor, tf.Variable)):
    raise TypeError(
        'TensorFlow construct {} has been encountered in a federated '
        'context. TFF does not support mixing TF and federated orchestration '
        'code. Please wrap any TensorFlow constructs with '
        '`tff.tf_computation`.'.format(arg))
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a TFF value.'.format(
            py_typecheck.type_string(type(arg))))

  py_typecheck.check_type(result, ValueImpl)
  type_mismatch = (
      type_spec is not None and
      not type_spec.is_assignable_from(result.type_signature))
  if type_mismatch:
    raise TypeError(
        'The supplied argument maps to TFF type {}, which is incompatible with '
        'the requested type {}.'.format(result.type_signature, type_spec))
  return result