Commit 73a6e4ff authored by Taylor Cramer's avatar Taylor Cramer Committed by tensorflow-copybara
Browse files

Use OrderedDict or tuple to represent structures with unknown Python containers

PiperOrigin-RevId: 411172320
parent f9bab81f
......@@ -268,11 +268,10 @@ class FederatedSampleTest(tf.test.TestCase):
x1 = -1.0
y1 = 5.0
test_type = collections.namedtuple('NestedScalars', ['x', 'y'])
value = call_federated_sample(
result = call_federated_sample(
[test_type(x0, y0),
test_type(x1, y1),
test_type(2.0, -10.0)])
result = value._asdict()
i0 = list(result['x']).index(x0)
i1 = list(result['y']).index(y1)
......@@ -334,11 +333,10 @@ class FederatedSampleTest(tf.test.TestCase):
x = 0.0
y = 5.0
test_type = collections.namedtuple('NestedScalars', ['x', 'y'])
value = call_federated_sample(
result = call_federated_sample(
[test_type(x, y),
test_type(3.4, 5.6),
test_type(1.0, 1.0)])
result = value._asdict()
self.assertIn(y, result['y'])
self.assertIn(x, result['x'])
......@@ -360,13 +358,13 @@ class FederatedSampleTest(tf.test.TestCase):
tuple_type = collections.namedtuple('NestedScalars', ['x', 'y'])
dict_type = collections.namedtuple('NestedScalars', ['a', 'b'])
value = call_federated_sample([
result = call_federated_sample([
nested_test_type(tuple_type(1.2, 2.2), dict_type(1.3, 8.8)),
nested_test_type(tuple_type(-9.1, 3.1), dict_type(1.2, -5.4))
])._asdict(recursive=True)
])
self.assertIn(1.2, value['tuple_1']['x'])
self.assertIn(8.8, value['tuple_2']['b'])
self.assertIn(1.2, result['tuple_1']['x'])
self.assertIn(8.8, result['tuple_2']['b'])
class SecureQuantizedSumStaticAssertsTest(tf.test.TestCase,
......
......@@ -70,11 +70,11 @@ def _build_reservoir_type(
# TODO(b/181365504): relax this to allow `StructType` once a `Struct` can be
# returned from `tf.function` decorated methods.
def is_tesnor_or_struct_with_py_type(t: computation_types.Type) -> bool:
def is_tensor_or_struct_with_py_type(t: computation_types.Type) -> bool:
return t.is_tensor() or t.is_struct_with_python()
if not type_analysis.contains_only(sample_value_type,
is_tesnor_or_struct_with_py_type):
is_tensor_or_struct_with_py_type):
raise TypeError('Cannot create a reservoir for type structure. Sample type '
'must only contain `TensorType` or `StructWithPythonType`, '
f'got a {sample_value_type!r}.')
......
......@@ -260,7 +260,7 @@ def iter_elements(struct: Struct) -> Iterator[Tuple[Optional[str], Any]]:
# pylint: enable=protected-access
def to_odict(struct: Struct, recursive=False):
def to_odict(struct: Struct, recursive=False) -> collections.OrderedDict:
"""Returns `struct` as an `OrderedDict`, if possible.
Args:
......@@ -285,7 +285,9 @@ def to_odict(struct: Struct, recursive=False):
return _to_odict(to_elements(struct))
def to_odict_or_tuple(struct: Struct, recursive=True):
def to_odict_or_tuple(
struct: Struct,
recursive=True) -> Union[collections.OrderedDict, Tuple[Any, ...]]:
"""Returns `struct` as an `OrderedDict` or `tuple`, if possible.
If all elements of `struct` have names, convert `struct` to an
......@@ -304,16 +306,13 @@ def to_odict_or_tuple(struct: Struct, recursive=True):
def _to_odict_or_tuple(elements):
field_is_named = tuple(name is not None for name, _ in elements)
has_names = any(field_is_named)
is_all_named = all(field_is_named)
if is_all_named:
if any(field_is_named):
if not all(field_is_named):
raise ValueError(
'Cannot convert a `Struct` with both named and unnamed '
'entries to an OrderedDict or tuple: {!r}'.format(struct))
return collections.OrderedDict(elements)
elif not has_names:
return tuple(value for _, value in elements)
else:
raise ValueError(
'Cannot convert an `Struct` with both named and unnamed '
'entries to an OrderedDict or tuple: {!r}'.format(struct))
return tuple(value for _, value in elements)
if recursive:
return to_container_recursive(struct, _to_odict_or_tuple)
......@@ -657,16 +656,17 @@ def update_struct(structure, **kwargs):
# In Python 3.8 and later `_asdict` no longer return OrdereDict, rather a
# regular `dict`, so we wrap here to get consistent types across Python
# version.s
d = collections.OrderedDict(structure._asdict())
dictionary = collections.OrderedDict(structure._asdict())
elif py_typecheck.is_attrs(structure):
d = attr.asdict(structure, dict_factory=collections.OrderedDict)
dictionary = attr.asdict(structure, dict_factory=collections.OrderedDict)
else:
for key in kwargs:
if key not in structure:
raise KeyError(
'structure does not contain a field named "{!s}"'.format(key))
d = structure
d.update(kwargs)
# Create a copy to prevent mutation of the original `structure`
dictionary = type(structure)(**structure)
dictionary.update(kwargs)
if isinstance(structure, collections.abc.Mapping):
return d
return type(structure)(**d)
return dictionary
return type(structure)(**dictionary)
......@@ -14,13 +14,16 @@
import collections
from absl.testing import parameterized
import attr
import tensorflow as tf
from tensorflow_federated.python.common_libs import structure
ODict = collections.OrderedDict
class StructTest(tf.test.TestCase):
class StructTest(tf.test.TestCase, parameterized.TestCase):
def test_new_named(self):
x = structure.Struct.named(a=1, b=4)
......@@ -68,7 +71,7 @@ class StructTest(tf.test.TestCase):
self.assertNotEqual(x, structure.Struct([('foo', 10)]))
self.assertEqual(structure.to_elements(x), v)
self.assertEqual(structure.to_odict(x), collections.OrderedDict())
self.assertEqual(structure.to_odict_or_tuple(x), collections.OrderedDict())
self.assertEqual(structure.to_odict_or_tuple(x), ())
self.assertEqual(repr(x), 'Struct([])')
self.assertEqual(str(x), '<>')
......@@ -468,28 +471,17 @@ class StructTest(tf.test.TestCase):
tf.SparseTensor(indices=[[1]], values=[2], dense_shape=[5]))
self.assertEqual(str(x), '<indices=[[1]],values=[2],dense_shape=[5]>')
def test_to_container_recursive(self):
def odict(**kwargs):
return collections.OrderedDict(sorted(list(kwargs.items())))
# Nested OrderedDicts.
s = odict(a=1, b=2, c=odict(d=3, e=odict(f=4, g=5)))
x = structure.from_container(s, recursive=True)
s2 = x._asdict(recursive=True)
self.assertEqual(s, s2)
# Single OrderedDict.
s = odict(a=1, b=2)
x = structure.from_container(s)
self.assertEqual(x._asdict(recursive=True), s)
# Single empty OrderedDict.
s = odict()
x = structure.from_container(s)
self.assertEqual(x._asdict(recursive=True), s)
# Invalid argument.
@parameterized.named_parameters(
('empty', ODict()),
('flat', ODict(a=1, b=2)),
('nested', ODict(a=1, b=2, c=ODict(d=3, e=ODict(f=4, g=5)))),
)
def test_from_container_asdict_roundtrip(self, dict_in):
structure_repr = structure.from_container(dict_in, recursive=True)
dict_out = structure_repr._asdict(recursive=True)
self.assertEqual(dict_in, dict_out)
def test_from_container_raises_on_non_container_argument(self):
with self.assertRaises(TypeError):
structure.from_container(3)
......@@ -554,6 +546,12 @@ class StructTest(tf.test.TestCase):
state3 = structure.update_struct(state2, a=8)
self.assertEqual(state3, {'a': 8, 'b': 2, 'c': 7})
def test_update_struct_on_dict_does_not_mutate_original(self):
state = collections.OrderedDict(a=1, b=2, c=3)
state2 = structure.update_struct(state, c=7)
del state2
self.assertEqual(state, collections.OrderedDict(a=1, b=2, c=3))
def test_update_struct_ordereddict(self):
state = collections.OrderedDict([('a', 1), ('b', 2), ('c', 3)])
state2 = structure.update_struct(state, c=7)
......@@ -585,53 +583,30 @@ class StructTest(tf.test.TestCase):
with self.assertRaisesRegex(KeyError, 'does not contain a field'):
structure.update_struct({'z': 1}, a=8)
def test_to_ordered_dict_or_tuple(self):
def odict(**kwargs):
return collections.OrderedDict(sorted(list(kwargs.items())))
# Nested OrderedDicts.
s = odict(a=1, b=2, c=odict(d=3, e=odict(f=4, g=5)))
x = structure.from_container(s, recursive=True)
self.assertEqual(s, structure.to_odict_or_tuple(x))
# Single OrderedDict.
s = odict(a=1, b=2)
x = structure.from_container(s)
self.assertEqual(structure.to_odict_or_tuple(x), s)
# Single empty OrderedDict.
s = odict()
x = structure.from_container(s)
self.assertEqual(structure.to_odict_or_tuple(x), s)
# Nested tuples.
s = tuple([1, 2, tuple([3, tuple([4, 5])])])
x = structure.from_container(s, recursive=True)
self.assertEqual(s, structure.to_odict_or_tuple(x))
# Single tuple.
s = tuple([1, 2])
@parameterized.named_parameters(
('empty_tuple', ()),
('flat_tuple', (1, 2)),
('nested_tuple', (1, 2, (3, (4, 5)))),
('flat_dict', ODict(a=1, b=2)),
('nested_dict', ODict(a=1, b=2, c=ODict(d=3, e=ODict(f=4, g=5)))),
('mixed', ODict(a=1, b=2, c=(3, ODict(d=4, e=5)))),
)
def test_to_odict_or_tuple_from_container_roundtrip(self, original):
structure_repr = structure.from_container(original, recursive=True)
out = structure.to_odict_or_tuple(structure_repr)
self.assertEqual(original, out)
def test_to_odict_or_tuple_empty_dict_becomes_empty_tuple(self):
s = collections.OrderedDict()
x = structure.from_container(s)
self.assertEqual(structure.to_odict_or_tuple(x), s)
# Struct from a single empty tuple should be converted to an empty
# OrderedDict.
s = tuple()
x = structure.from_container(s)
self.assertEqual(structure.to_odict_or_tuple(x), collections.OrderedDict())
# Mixed OrderedDicts and tuples.
s = odict(a=1, b=2, c=tuple([3, odict(d=4, e=5)]))
x = structure.from_container(s, recursive=True)
self.assertEqual(s, structure.to_odict_or_tuple(x))
self.assertEqual(structure.to_odict_or_tuple(x), ())
# Mixed OrderedDicts and tuples with recursive=False.
s = odict(a=1, b=2, c=tuple([3, odict(d=4, e=5)]))
def test_to_odict_or_tuple_mixed_nonrecursive(self):
s = ODict(a=1, b=2, c=(3, ODict(d=4, e=5)))
x = structure.from_container(s, recursive=False)
self.assertEqual(s, structure.to_odict_or_tuple(x, recursive=False))
# Struct with named and unnamed elements should raise error.
def test_to_odict_or_tuple_raises_on_mixed_named_and_unnamed(self):
s = [(None, 10), ('foo', 20), ('bar', 30)]
x = structure.Struct(s)
with self.assertRaisesRegex(ValueError, 'named and unnamed'):
......
......@@ -85,7 +85,6 @@ py_test(
":form_utils",
":forms",
":test_utils",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/reference:reference_context",
......
......@@ -18,7 +18,6 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.mapreduce import form_utils
......@@ -638,10 +637,10 @@ class GetMapReduceFormForIterativeProcessTest(MapReduceFormTestCase,
mrf = form_utils.get_map_reduce_form_for_iterative_process(it)
new_it = form_utils.get_iterative_process_for_map_reduce_form(mrf)
state = new_it.initialize()
self.assertEqual(state.num_rounds, 0)
self.assertEqual(state['num_rounds'], 0)
state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
self.assertEqual(state.num_rounds, 1)
self.assertEqual(state['num_rounds'], 1)
self.assertAllClose(metrics,
collections.OrderedDict(ratio_over_threshold=0.5))
......@@ -743,18 +742,17 @@ class GetMapReduceFormForIterativeProcessTest(MapReduceFormTestCase,
def test_returns_map_reduce_form_with_secure_sum_bitwidth(self):
mrf = self.get_map_reduce_form_for_client_to_server_fn(
lambda data: intrinsics.federated_secure_sum_bitwidth(data, 7))
self.assertEqual(mrf.secure_sum_bitwidth(), structure.Struct.unnamed(7))
self.assertEqual(mrf.secure_sum_bitwidth(), (7,))
def test_returns_map_reduce_form_with_secure_sum_max_input(self):
mrf = self.get_map_reduce_form_for_client_to_server_fn(
lambda data: intrinsics.federated_secure_sum(data, 12))
self.assertEqual(mrf.secure_sum_max_input(), structure.Struct.unnamed(12))
self.assertEqual(mrf.secure_sum_max_input(), (12,))
def test_returns_map_reduce_form_with_secure_modular_sum_modulus(self):
mrf = self.get_map_reduce_form_for_client_to_server_fn(
lambda data: intrinsics.federated_secure_modular_sum(data, 22))
self.assertEqual(mrf.secure_modular_sum_modulus(),
structure.Struct.unnamed(22))
self.assertEqual(mrf.secure_modular_sum_modulus(), (22,))
class BroadcastFormTest(test_case.TestCase):
......
......@@ -1012,7 +1012,7 @@ class ReferenceContextTest(test_case.TestCase, parameterized.TestCase):
return zero_for([('A', tf.int32), ('B', tf.float32)])
self.assertEqual(str(foo.type_signature), '( -> <A=int32,B=float32>)')
self.assertEqual(str(foo()), '<A=0,B=0.0>')
self.assertEqual(foo(), collections.OrderedDict(A=0, B=0.0))
def test_generic_zero_with_federated_int_on_server(self):
......@@ -1046,10 +1046,10 @@ class ReferenceContextTest(test_case.TestCase, parameterized.TestCase):
'(<x=<A=int32,B=float32>,y=<A=int32,B=float32>> -> <A=int32,B=float32>)'
)
foo_result = foo([2, 0.1], [3, 0.2])
self.assertIsInstance(foo_result, structure.Struct)
self.assertSameElements(dir(foo_result), ['A', 'B'])
self.assertEqual(foo_result.A, 5)
self.assertAlmostEqual(foo_result.B, 0.3, places=2)
self.assertIsInstance(foo_result, collections.OrderedDict)
self.assertSameElements(foo_result.keys(), ['A', 'B'])
self.assertEqual(foo_result['A'], 5) # pylint: disable=invalid-sequence-index
self.assertAlmostEqual(foo_result['B'], 0.3, places=2) # pylint: disable=invalid-sequence-index
def test_sequence_map_with_list_of_integers(self):
......@@ -1123,17 +1123,16 @@ class ReferenceContextTest(test_case.TestCase, parameterized.TestCase):
str(foo.type_signature),
'({<A=float32,B=float32>}@CLIENTS -> <A=float32,B=float32>@SERVER)')
self.assertEqual(
str(
foo([{
'A': 1.0,
'B': 5.0
}, {
'A': 2.0,
'B': 6.0
}, {
'A': 3.0,
'B': 7.0
}])), '<A=2.0,B=6.0>')
foo([{
'A': 1.0,
'B': 5.0
}, {
'A': 2.0,
'B': 6.0
}, {
'A': 3.0,
'B': 7.0
}]), collections.OrderedDict(A=2.0, B=6.0))
def test_federated_zip_at_server(self):
......
......@@ -13,7 +13,7 @@
"""Utilities for type conversion, type checking, type inference, etc."""
import collections
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Type
import attr
import numpy as np
......@@ -348,6 +348,19 @@ def type_from_tensors(tensors):
return computation_types.to_type(type_spec)
def is_container_type_without_names(container_type: Type[Any]) -> bool:
"""Returns whether `container_type`'s elements are unnamed."""
return (issubclass(container_type, (list, tuple)) and
not py_typecheck.is_named_tuple(container_type))
def is_container_type_with_names(container_type: Type[Any]) -> bool:
"""Returns whether `container_type`'s elements are named."""
return (py_typecheck.is_named_tuple(container_type) or
py_typecheck.is_attrs(container_type) or
issubclass(container_type, dict))
def type_to_py_container(value, type_spec):
"""Recursively convert `structure.Struct`s to Python containers.
......@@ -366,7 +379,8 @@ def type_to_py_container(value, type_spec):
Raises:
ValueError: If the conversion is not possible due to a mix of named
and unnamed values.
and unnamed values, or if `value` contains names that are mismatched or
not present in the corresponding index of `type_spec`.
"""
if type_spec.is_federated():
if type_spec.all_equal:
......@@ -398,44 +412,50 @@ def type_to_py_container(value, type_spec):
return value
if not isinstance(value, structure.Struct):
# NOTE: When encountering non-anonymous tuples, we assume that
# NOTE: When encountering non-`structure.Struct`s, we assume that
# this means that we're attempting to re-convert a value that
# already has the proper containers, and we short-circuit to
# avoid re-converting. This is a possibly dangerous assumption.
return value
anon_tuple = value
def is_container_type_without_names(container_type):
return (issubclass(container_type, (list, tuple)) and
not py_typecheck.is_named_tuple(container_type))
def is_container_type_with_names(container_type):
return (py_typecheck.is_named_tuple(container_type) or
py_typecheck.is_attrs(container_type) or
issubclass(container_type, dict))
# TODO(b/133228705): Consider requiring StructWithPythonType.
container_type = structure_type_spec.python_container or structure.Struct
container_is_anon_tuple = structure_type_spec.python_container is None
container_type = structure_type_spec.python_container
# Ensure that names are only added, not mismatched or removed
names_from_value = structure.name_list_with_nones(value)
names_from_type_spec = structure.name_list_with_nones(structure_type_spec)
for value_name, type_name in zip(names_from_value, names_from_type_spec):
if value_name is not None:
if value_name != type_name:
raise ValueError(
f'Cannot convert value with field name `{value_name}` into a '
f'type with field name `{type_name}`.')
num_named_elements = len(dir(structure_type_spec))
num_unnamed_elements = len(structure_type_spec) - num_named_elements
if num_named_elements > 0 and num_unnamed_elements > 0:
raise ValueError(
f'Cannot represent value {value} with a Python container because it '
'contains a mix of named and unnamed elements.\n\nNote: this was '
'previously allowed when using the `tff.structure.Struct` container. '
'This support has been removed: please change to use structures with '
'either all-named or all-unnamed fields.')
if container_type is None:
if num_named_elements:
container_type = collections.OrderedDict
else:
container_type = tuple
# Avoid projecting the `structure.StructType`d TFF value into a Python
# container that is not supported.
if not container_is_anon_tuple:
num_named_elements = len(dir(anon_tuple))
num_unnamed_elements = len(anon_tuple) - num_named_elements
if num_named_elements > 0 and num_unnamed_elements > 0:
raise ValueError('Cannot represent value {} with container type {}, '
'because value contains a mix of named and unnamed '
'elements.'.format(anon_tuple, container_type))
if (num_named_elements > 0 and
is_container_type_without_names(container_type)):
raise ValueError(
'Cannot represent value {} with named elements '
'using container type {} which does not support names. In TFF\'s '
'typesystem, this corresponds to an implicit downcast'.format(
anon_tuple, container_type))
if (num_named_elements > 0 and
is_container_type_without_names(container_type)):
raise ValueError(
'Cannot represent value {} with named elements '
'using container type {} which does not support names. In TFF\'s '
'typesystem, this corresponds to an implicit downcast'.format(
value, container_type))
if (is_container_type_with_names(container_type) and
len(dir(structure_type_spec)) != len(anon_tuple)):
len(dir(structure_type_spec)) != len(value)):
# If the type specifies the names, we have all the information we need.
# Otherwise we must raise here.
raise ValueError('When packaging as a Python value which requires names, '
......@@ -443,17 +463,17 @@ def type_to_py_container(value, type_spec):
'{} names in type spec {} of length {}, with requested'
'python type {}.'.format(
len(dir(structure_type_spec)), structure_type_spec,
len(anon_tuple), container_type))
len(value), container_type))
elements = []
for index, (elem_name, elem_type) in enumerate(
structure.iter_elements(structure_type_spec)):
value = type_to_py_container(anon_tuple[index], elem_type)
element = type_to_py_container(value[index], elem_type)
if elem_name is None and not container_is_anon_tuple:
elements.append(value)
if elem_name is None:
elements.append(element)
else:
elements.append((elem_name, value))
elements.append((elem_name, element))
if (py_typecheck.is_named_tuple(container_type) or
py_typecheck.is_attrs(container_type) or
......
......@@ -516,7 +516,7 @@ class TypeFromTensorsTest(test_case.TestCase):
class TypeToPyContainerTest(test_case.TestCase):
def test_not_anon_tuple_passthrough(self):
def test_tuple_passthrough(self):
value = (1, 2.0)
result = type_conversions.type_to_py_container(
(1, 2.0),
......@@ -524,12 +524,27 @@ class TypeToPyContainerTest(test_case.TestCase):
container_type=list))
self.assertEqual(result, value)
def test_anon_tuple_return(self):
anon_tuple = structure.Struct([(None, 1), (None, 2.0)])
def test_represents_unnamed_fields_as_tuple(self):
input_value = structure.Struct([(None, 1), (None, 2.0)])
input_type = computation_types.StructType([tf.int32, tf.float32])
self.assertEqual(
type_conversions.type_to_py_container(
anon_tuple, computation_types.StructType([tf.int32, tf.float32])),
anon_tuple)
type_conversions.type_to_py_container(input_value, input_type),
(1, 2.0))
def test_represents_named_fields_as_odict(self):
input_value = structure.Struct([('a', 1), ('b', 2.0)])
input_type = computation_types.StructType([('a', tf.int32),
('b', tf.float32)])
self.assertEqual(
type_conversions.type_to_py_container(input_value, input_type),
collections.OrderedDict(a=1, b=2.0))
def test_raises_on_mixed_named_unnamed(self):
input_value = structure.Struct([('a', 1), (None, 2.0)])
input_type = computation_types.StructType([('a', tf.int32),
(None, tf.float32)])
with self.assertRaises(ValueError):
type_conversions.type_to_py_container(input_value, input_type)
def test_anon_tuple_without_names_to_container_without_names(self):
anon_tuple = structure.Struct([(None, 1), (None, 2.0)])
......@@ -575,11 +590,12 @@ class TypeToPyContainerTest(test_case.TestCase):