Commit 7681cec7 authored by Keith Rush's avatar Keith Rush Committed by tensorflow-copybara
Browse files

Automatically zips if necessary on ingestion into federated context.

PiperOrigin-RevId: 413745813
parent 0d736a45
...@@ -74,6 +74,7 @@ py_library( ...@@ -74,6 +74,7 @@ py_library(
"//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/impl/types:type_conversions", "//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/impl/types:type_serialization", "//tensorflow_federated/python/core/impl/types:type_serialization",
"//tensorflow_federated/python/core/impl/types:type_transformations",
"//tensorflow_federated/python/core/impl/utils:tensorflow_utils", "//tensorflow_federated/python/core/impl/utils:tensorflow_utils",
], ],
) )
......
...@@ -33,6 +33,7 @@ from tensorflow_federated.python.core.impl.types import placements ...@@ -33,6 +33,7 @@ from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_conversions from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.impl.types import type_serialization from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.core.impl.types import type_transformations
from tensorflow_federated.python.core.impl.utils import tensorflow_utils from tensorflow_federated.python.core.impl.utils import tensorflow_utils
...@@ -1997,3 +1998,115 @@ def apply_binary_operator_with_upcast( ...@@ -1997,3 +1998,115 @@ def apply_binary_operator_with_upcast(
called = building_blocks.Call(tf_representing_op, arg) called = building_blocks.Call(tf_representing_op, arg)
return called return called
def zip_to_match_type(
*, comp_to_zip: building_blocks.ComputationBuildingBlock,
target_type: computation_types.Type
) -> Optional[building_blocks.ComputationBuildingBlock]:
"""Zips computation argument to match target type.
This function will apply the appropriate federated zips to match `comp_to_zip`
to the requested type `target_type`, subject to a few caveats. We will
traverse `computation_types.StructTypes` to match types, so for example we
would zip `<<T@P, R@P>>` to match `<<T, R>@P>`, but we will not traverse
`computation_types.FunctionTypes`. Therefore we would not apply a zip to the
parameter of `(<<T@P, R@P>> -> Q)` to match (<<T, R>@P> -> Q).
If zipping in this manner cannot match the type of `comp_to_zip` to
`target_type`, `None` will be returned.
Args:
comp_to_zip: Instance of `building_blocks.ComputationBuildingBlock` to
traverse and attempt to zip to match `target_type`.
target_type: The type to target when traversing and zipping `comp_to_zip`.
Returns:
Either a potentially transformed version of `comp_to_zip` or `None`,
depending on whether inserting a zip according to the semantics above
can transformed `comp_to_zip` to the requested type.
"""
py_typecheck.check_type(comp_to_zip, building_blocks.ComputationBuildingBlock)
py_typecheck.check_type(target_type, computation_types.Type)
def _can_be_zipped_into(source_type: computation_types.Type,
target_type: computation_types.Type) -> bool:
"""Indicates possibility of the transformation `zip_to_match_type`."""
def _struct_can_be_zipped_to_federated(
struct_type: computation_types.StructType,
federated_type: computation_types.FederatedType) -> bool:
placements_encountered = set()
def _remove_placement(
subtype: computation_types.Type
) -> Tuple[computation_types.Type, bool]:
if subtype.is_federated():
placements_encountered.add(subtype.placement)
return subtype.member, True
return subtype, False
unplaced_struct, _ = type_transformations.transform_type_postorder(
struct_type, _remove_placement)
if not (all(
x is federated_type.placement for x in placements_encountered)):
return False
if (federated_type.placement is placements.CLIENTS and
federated_type.all_equal):
# There is no all-equal clients zip; return false.
return False
return federated_type.member.is_assignable_from(unplaced_struct)
def _struct_elem_zippable(source_name, source_element, target_name,
target_element):
return _can_be_zipped_into(
source_element, target_element) and source_name in (target_name, None)
if source_type.is_struct():
if target_type.is_federated():
return _struct_can_be_zipped_to_federated(source_type, target_type)
elif target_type.is_struct():
elements_zippable = []
for (s_name, s_el), (t_name, t_el) in zip(
structure.iter_elements(source_type),
structure.iter_elements(target_type)):
elements_zippable.append(
_struct_elem_zippable(s_name, s_el, t_name, t_el))
return all(elements_zippable)
else:
return target_type.is_assignable_from(source_type)
def _zip_to_match(
*, source: building_blocks.ComputationBuildingBlock,
target_type: computation_types.Type
) -> building_blocks.ComputationBuildingBlock:
if target_type.is_federated() and source.type_signature.is_struct():
return create_federated_zip(source)
elif target_type.is_struct() and source.type_signature.is_struct():
zipped_elements = []
# Bind a reference to the source to prevent duplication in the AST.
ref_name = next(unique_name_generator(source))
ref_to_source = building_blocks.Reference(ref_name, source.type_signature)
for idx, ((_, t_el), (s_name, _)) in enumerate(
zip(
structure.iter_elements(target_type),
structure.iter_elements(source.type_signature))):
s_selection = building_blocks.Selection(ref_to_source, index=idx)
zipped_elements.append(
(s_name, _zip_to_match(source=s_selection, target_type=t_el)))
# Insert binding above the constructed structure.
return building_blocks.Block([(ref_name, source)],
building_blocks.Struct(zipped_elements))
else:
# No zipping to be done here.
return source
if target_type.is_assignable_from(comp_to_zip.type_signature):
# No zipping needs to be done; return directly.
return comp_to_zip
elif _can_be_zipped_into(comp_to_zip.type_signature, target_type):
return _zip_to_match(source=comp_to_zip, target_type=target_type)
else:
# Zipping cannot be performed here.
return None
...@@ -1932,5 +1932,117 @@ class SelectOutputFromLambdaTest(test_case.TestCase): ...@@ -1932,5 +1932,117 @@ class SelectOutputFromLambdaTest(test_case.TestCase):
self.assertEqual(str(tuple_selected), '(x -> <x.a.inner,x.b>)') self.assertEqual(str(tuple_selected), '(x -> <x.a.inner,x.b>)')
class ZipUpToTest(test_case.TestCase):
def test_zips_struct_of_federated_values(self):
comp = building_blocks.Struct([
building_blocks.Reference(
'x', computation_types.FederatedType(tf.int32, placements.CLIENTS)),
building_blocks.Reference(
'y', computation_types.FederatedType(tf.int32, placements.CLIENTS))
])
zippable_type = computation_types.FederatedType(
computation_types.StructType([(None, tf.int32), (None, tf.int32)]),
placements.CLIENTS)
zipped = building_block_factory.zip_to_match_type(
comp_to_zip=comp, target_type=zippable_type)
self.assert_types_equivalent(zipped.type_signature, zippable_type)
def test_does_not_zip_different_placement_target(self):
comp = building_blocks.Struct([
building_blocks.Reference(
'x', computation_types.FederatedType(tf.int32, placements.CLIENTS)),
building_blocks.Reference(
'y', computation_types.FederatedType(tf.int32, placements.CLIENTS))
])
non_zippable_type = computation_types.FederatedType(
computation_types.StructType([(None, tf.int32), (None, tf.int32)]),
placements.SERVER)
zipped = building_block_factory.zip_to_match_type(
comp_to_zip=comp, target_type=non_zippable_type)
self.assertIsNone(zipped)
def test_zips_struct_of_federated_values_under_struct(self):
comp = building_blocks.Struct([
building_blocks.Struct([
building_blocks.Reference(
'x',
computation_types.FederatedType(tf.int32, placements.CLIENTS)),
building_blocks.Reference(
'y',
computation_types.FederatedType(tf.int32, placements.CLIENTS))
])
])
zippable_type = computation_types.StructType([
(None,
computation_types.FederatedType(
computation_types.StructType([(None, tf.int32), (None, tf.int32)]),
placements.CLIENTS))
])
zipped = building_block_factory.zip_to_match_type(
comp_to_zip=comp, target_type=zippable_type)
self.assert_types_equivalent(zipped.type_signature, zippable_type)
def test_assignability_with_names(self):
# This would correspond to an implicit downcast in TFF's typesystem; the
# result would not be assignable to the requested type.
comp = building_blocks.Struct([
building_blocks.Struct([
('a',
building_blocks.Reference(
'x',
computation_types.FederatedType(tf.int32,
placements.CLIENTS))),
('b',
building_blocks.Reference(
'y',
computation_types.FederatedType(tf.int32, placements.CLIENTS)))
])
])
unnamed_zippable_type = computation_types.StructType([
(None,
computation_types.FederatedType(
computation_types.StructType([(None, tf.int32), (None, tf.int32)]),
placements.CLIENTS))
])
named_zippable_type = computation_types.StructType([
(None,
computation_types.FederatedType(
computation_types.StructType([('a', tf.int32), ('b', tf.int32)]),
placements.CLIENTS))
])
not_zipped = building_block_factory.zip_to_match_type(
comp_to_zip=comp, target_type=unnamed_zippable_type)
zipped = building_block_factory.zip_to_match_type(
comp_to_zip=comp, target_type=named_zippable_type)
self.assertFalse(
unnamed_zippable_type.is_assignable_from(named_zippable_type))
self.assertIsNone(not_zipped)
self.assert_types_equivalent(zipped.type_signature, named_zippable_type)
def test_does_not_zip_under_function(self):
result_comp = building_blocks.Struct([
building_blocks.Reference(
'x', computation_types.FederatedType(tf.int32, placements.CLIENTS)),
building_blocks.Reference(
'y', computation_types.FederatedType(tf.int32, placements.CLIENTS))
])
lam = building_blocks.Lambda(None, None, result_comp)
zippable_function_type = computation_types.FunctionType(
None,
computation_types.FederatedType(
computation_types.StructType([(None, tf.int32), (None, tf.int32)]),
placements.CLIENTS))
zipped = building_block_factory.zip_to_match_type(
comp_to_zip=lam, target_type=zippable_function_type)
self.assertIsNone(zipped)
if __name__ == '__main__': if __name__ == '__main__':
test_case.main() test_case.main()
...@@ -74,8 +74,11 @@ py_test( ...@@ -74,8 +74,11 @@ py_test(
":federated_computation_context", ":federated_computation_context",
":value_impl", ":value_impl",
"//tensorflow_federated/python/core/api:computations", "//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/impl/compiler:building_blocks", "//tensorflow_federated/python/core/impl/compiler:building_blocks",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
], ],
) )
......
...@@ -98,12 +98,11 @@ class FederatedComputationContext(symbol_binding_context.SymbolBindingContext): ...@@ -98,12 +98,11 @@ class FederatedComputationContext(symbol_binding_context.SymbolBindingContext):
return self._symbol_bindings return self._symbol_bindings
def ingest(self, val, type_spec): def ingest(self, val, type_spec):
val = value_impl.to_value(val, type_spec, self._context_stack) val = value_impl.to_value(val, type_spec, zip_if_needed=True)
type_analysis.check_type(val, type_spec)
return val return val
def invoke(self, comp, arg): def invoke(self, comp, arg):
fn = value_impl.to_value(comp, None, self._context_stack) fn = value_impl.to_value(comp, None)
tys = fn.type_signature tys = fn.type_signature
py_typecheck.check_type(tys, computation_types.FunctionType) py_typecheck.check_type(tys, computation_types.FunctionType)
if arg is not None: if arg is not None:
......
...@@ -12,17 +12,19 @@ ...@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from absl.testing import absltest
import tensorflow as tf import tensorflow as tf
from tensorflow_federated.python.core.api import computations from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
from tensorflow_federated.python.core.impl.federated_context import federated_computation_context from tensorflow_federated.python.core.impl.federated_context import federated_computation_context
from tensorflow_federated.python.core.impl.federated_context import value_impl from tensorflow_federated.python.core.impl.federated_context import value_impl
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
class FederatedComputationContextTest(absltest.TestCase): class FederatedComputationContextTest(test_case.TestCase):
def test_invoke_returns_value_with_correct_type(self): def test_invoke_returns_value_with_correct_type(self):
context = federated_computation_context.FederatedComputationContext( context = federated_computation_context.FederatedComputationContext(
...@@ -33,6 +35,46 @@ class FederatedComputationContextTest(absltest.TestCase): ...@@ -33,6 +35,46 @@ class FederatedComputationContextTest(absltest.TestCase):
self.assertIsInstance(result, value_impl.Value) self.assertIsInstance(result, value_impl.Value)
self.assertEqual(str(result.type_signature), 'int32') self.assertEqual(str(result.type_signature), 'int32')
def test_ingest_zips_value_when_necessary_to_match_federated_type(self):
context = federated_computation_context.FederatedComputationContext(
context_stack_impl.context_stack)
# This thing will be <{int}@C, {int}@C>
comp = building_blocks.Struct([
building_blocks.Reference(
'x', computation_types.FederatedType(tf.int32, placements.CLIENTS)),
building_blocks.Reference(
'y', computation_types.FederatedType(tf.int32, placements.CLIENTS))
])
# The type of comp can be zipped to the below.
zippable_type = computation_types.FederatedType(
computation_types.StructType([(None, tf.int32), (None, tf.int32)]),
placements.CLIENTS)
ingested = context.ingest(comp, type_spec=zippable_type)
self.assert_types_equivalent(ingested.type_signature, zippable_type)
def test_ingest_zips_federated_under_struct(self):
context = federated_computation_context.FederatedComputationContext(
context_stack_impl.context_stack)
comp = building_blocks.Struct([
building_blocks.Struct([
building_blocks.Reference(
'x',
computation_types.FederatedType(tf.int32, placements.CLIENTS)),
building_blocks.Reference(
'y',
computation_types.FederatedType(tf.int32, placements.CLIENTS))
])
])
# The type of comp can be zipped to the below.
zippable_type = computation_types.StructType([
(None,
computation_types.FederatedType(
computation_types.StructType([(None, tf.int32), (None, tf.int32)]),
placements.CLIENTS))
])
ingested = context.ingest(comp, type_spec=zippable_type)
self.assert_types_equivalent(ingested.type_signature, zippable_type)
def test_construction_populates_name(self): def test_construction_populates_name(self):
context = federated_computation_context.FederatedComputationContext( context = federated_computation_context.FederatedComputationContext(
context_stack_impl.context_stack) context_stack_impl.context_stack)
...@@ -94,4 +136,4 @@ class FederatedComputationContextTest(absltest.TestCase): ...@@ -94,4 +136,4 @@ class FederatedComputationContextTest(absltest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() test_case.main()
...@@ -298,6 +298,8 @@ def to_value( ...@@ -298,6 +298,8 @@ def to_value(
arg: Any, arg: Any,
type_spec, type_spec,
parameter_type_hint=None, parameter_type_hint=None,
*,
zip_if_needed: bool = False,
) -> Value: ) -> Value:
"""Converts the argument into an instance of the abstract class `tff.Value`. """Converts the argument into an instance of the abstract class `tff.Value`.
...@@ -339,6 +341,9 @@ def to_value( ...@@ -339,6 +341,9 @@ def to_value(
parameter_type_hint: An optional `tff.Type` or value convertible to it by parameter_type_hint: An optional `tff.Type` or value convertible to it by
`tff.to_type()` which specifies an argument type to use in the case that `tff.to_type()` which specifies an argument type to use in the case that
`arg` is a `function_utils.PolymorphicComputation`. `arg` is a `function_utils.PolymorphicComputation`.
zip_if_needed: If `True`, attempt to coerce the result of `to_value` to
match `type_spec` by applying `intrinsics.federated_zip` to appropriate
elements.
Returns: Returns:
An instance of `tff.Value` as described above. An instance of `tff.Value` as described above.
...@@ -409,6 +414,12 @@ def to_value( ...@@ -409,6 +414,12 @@ def to_value(
py_typecheck.check_type(result, Value) py_typecheck.check_type(result, Value)
if (type_spec is not None and if (type_spec is not None and
not type_spec.is_assignable_from(result.type_signature)): not type_spec.is_assignable_from(result.type_signature)):
if zip_if_needed:
# Returns `None` if such a zip can't be performed.
zipped_comp = building_block_factory.zip_to_match_type(
comp_to_zip=result.comp, target_type=type_spec)
if zipped_comp is not None:
return Value(zipped_comp)
raise TypeError( raise TypeError(
'The supplied argument maps to TFF type {}, which is incompatible with ' 'The supplied argument maps to TFF type {}, which is incompatible with '
'the requested type {}.'.format(result.type_signature, type_spec)) 'the requested type {}.'.format(result.type_signature, type_spec))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment