Skip to content
Snippets Groups Projects
Commit 594460f8 authored by Keith Rush's avatar Keith Rush Committed by tensorflow-copybara
Browse files

Opens cardinality inference to non-list Python types for federated arguments.

Some standard use cases (e.g. zip) give tuples as results.

PiperOrigin-RevId: 268546571
parent 37927c96
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import six
from tensorflow_federated.python.common_libs import anonymous_tuple
......@@ -53,8 +55,9 @@ def merge_cardinalities(existing, to_add):
def infer_cardinalities(value, type_spec):
"""Infers cardinalities from Python `value`.
Codifies the TFF convention that federated types which are not declared to be
all-equal must be represented before ingestion at the Python level as a list.
Allows for any Python object to represent a federated value; enforcing
particular representations is not the job of this inference function, but
rather ingestion functions lower in the stack.
Args:
value: Python object from which to infer TFF placement cardinalities.
......@@ -76,7 +79,7 @@ def infer_cardinalities(value, type_spec):
if isinstance(type_spec, computation_types.FederatedType):
if type_spec.all_equal:
return {}
py_typecheck.check_type(value, list)
py_typecheck.check_type(value, collections.Sized)
return {type_spec.placement: len(value)}
elif isinstance(type_spec, computation_types.NamedTupleType):
anonymous_tuple_value = anonymous_tuple.from_container(
......
......@@ -38,12 +38,32 @@ class InferCardinalitiesTest(absltest.TestCase):
cardinalities = runtime_utils.infer_cardinalities(1, int_type)
self.assertEmpty(cardinalities)
def test_raises_federated_type_non_list(self):
def test_raises_federated_type_integer(self):
federated_type = computation_types.FederatedType(
tf.int32, placement_literals.CLIENTS, all_equal=False)
with self.assertRaises(TypeError):
runtime_utils.infer_cardinalities(1, federated_type)
def test_raises_federated_type_generator(self):
def generator_fn():
yield 1
generator = generator_fn()
federated_type = computation_types.FederatedType(
tf.int32, placement_literals.CLIENTS, all_equal=False)
with self.assertRaises(TypeError):
runtime_utils.infer_cardinalities(generator, federated_type)
def test_passes_federated_type_tuple(self):
tup = tuple(range(5))
federated_type = computation_types.FederatedType(
tf.int32, placement_literals.CLIENTS, all_equal=False)
runtime_utils.infer_cardinalities(tup, federated_type)
five_client_cardinalities = runtime_utils.infer_cardinalities(
tup, federated_type)
self.assertEqual(five_client_cardinalities[placement_literals.CLIENTS], 5)
def test_adds_list_length_as_cardinality_at_clients(self):
federated_type = computation_types.FederatedType(
tf.int32, placement_literals.CLIENTS, all_equal=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment