Skip to content
Snippets Groups Projects
Commit 42c367a6 authored by Keith Rush's avatar Keith Rush Committed by tensorflow-copybara
Browse files

Binds references to remaining call construction sites in `ValueImpl`, cleans...

Binds references to remaining call construction sites in `ValueImpl`, cleans up tests depending on old behavior.

PiperOrigin-RevId: 322394328
parent 12ab02e8
No related branches found
No related tags found
No related merge requests found
......@@ -154,5 +154,7 @@ py_test(
deps = [
":value_base",
":values",
"//tensorflow_federated/python/core/impl:federated_computation_context",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
],
)
......@@ -16,10 +16,18 @@ from absl.testing import absltest
from tensorflow_federated.python.core.api import value_base
from tensorflow_federated.python.core.api import values
from tensorflow_federated.python.core.impl import federated_computation_context
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
class ValuesTest(absltest.TestCase):
def run(self, result=None):
fc_context = federated_computation_context.FederatedComputationContext(
context_stack_impl.context_stack)
with context_stack_impl.context_stack.install(fc_context):
super(ValuesTest, self).run(result=result)
# Note: No need to test all supported types, as those are already tested in
# the test of the underlying implementation (`value_impl_test.py`).
def test_to_value_with_int_constant(self):
......
......@@ -47,7 +47,7 @@ class ZeroOrOneArgFnToBuildingBlockTest(parameterized.TestCase):
lambda x: (x[1], x[0]),
computation_types.NamedTupleType([tf.int32, tf.int32]),
'(FEDERATED_foo -> <FEDERATED_foo[1],FEDERATED_foo[0]>)'),
('constant', lambda: 'stuff', None, '( -> comp#'))
('constant', lambda: 'stuff', None, '( -> (let fc_FEDERATED_symbol_0=comp#'))
# pyformat: enable
def test_zero_or_one_arg_fn_to_building_block(self, fn, parameter_type,
fn_str):
......
......@@ -146,9 +146,10 @@ class ValueImpl(value_base.Value, metaclass=abc.ABCMeta):
return
named_tuple_setattr_lambda = building_block_factory.create_named_tuple_setattr_lambda(
self._comp.type_signature, name, value_comp)
# TODO(b/159281959): Follow up and bind a reference here.
new_comp = building_blocks.Call(named_tuple_setattr_lambda, self._comp)
super().__setattr__('_comp', new_comp)
fc_context = self._context_stack.current
ref = fc_context.bind_computation_to_reference(new_comp)
super().__setattr__('_comp', ref)
def __bool__(self):
raise TypeError(
......@@ -236,17 +237,16 @@ class ValueImpl(value_base.Value, metaclass=abc.ABCMeta):
if not self.type_signature.is_equivalent_to(other.type_signature):
raise TypeError('Cannot add {} and {}.'.format(self.type_signature,
other.type_signature))
# TODO(b/159281959): Follow up and bind a reference here.
return ValueImpl(
building_blocks.Call(
building_blocks.Intrinsic(
intrinsic_defs.GENERIC_PLUS.uri,
computation_types.FunctionType(
[self.type_signature, self.type_signature],
self.type_signature)),
ValueImpl.get_comp(
to_value([self, other], None, self._context_stack))),
self._context_stack)
call = building_blocks.Call(
building_blocks.Intrinsic(
intrinsic_defs.GENERIC_PLUS.uri,
computation_types.FunctionType(
[self.type_signature, self.type_signature],
self.type_signature)),
ValueImpl.get_comp(to_value([self, other], None, self._context_stack)))
fc_context = self._context_stack.current
ref = fc_context.bind_computation_to_reference(call)
return ValueImpl(ref, self._context_stack)
def _wrap_constant_as_value(const, context_stack):
......@@ -264,9 +264,10 @@ def _wrap_constant_as_value(const, context_stack):
tf_comp, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
lambda: tf.constant(const), None, context_stack)
compiled_comp = building_blocks.CompiledComputation(tf_comp)
# TODO(b/159281959): Follow up and bind a reference here.
called_comp = building_blocks.Call(compiled_comp)
return ValueImpl(called_comp, context_stack)
fc_context = context_stack.current
ref = fc_context.bind_computation_to_reference(called_comp)
return ValueImpl(ref, context_stack)
def _wrap_sequence_as_value(elements, element_type, context_stack):
......@@ -305,10 +306,10 @@ def _wrap_sequence_as_value(elements, element_type, context_stack):
# Wraps the dataset as a value backed by a no-argument TensorFlow computation.
tf_comp, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
_create_dataset_from_elements, None, context_stack)
# TODO(b/159281959): Follow up and bind a reference here.
return ValueImpl(
building_blocks.Call(building_blocks.CompiledComputation(tf_comp)),
context_stack)
call = building_blocks.Call(building_blocks.CompiledComputation(tf_comp))
fc_context = context_stack.current
ref = fc_context.bind_computation_to_reference(call)
return ValueImpl(ref, context_stack)
def _dictlike_items_to_value(items, context_stack, container_type) -> ValueImpl:
......@@ -362,7 +363,6 @@ def to_value(
are encountered, as TensorFlow code should be sealed away from TFF
federated context.
"""
# TODO(b/159281959): Follow up and bind references here where appropriate.
py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
if type_spec is not None:
type_spec = computation_types.to_type(type_spec)
......
......@@ -161,7 +161,12 @@ class ValueImplTest(parameterized.TestCase):
z = x + y
self.assertIsInstance(z, value_base.Value)
self.assertEqual(str(z.type_signature), 'int32')
self.assertEqual(str(z), 'generic_plus(<x,y>)')
self.assertEqual(str(z), 'fc_FEDERATED_symbol_0')
bindings = value_impl.ValueImpl.get_context_stack(z).current.symbol_bindings
self.assertLen(bindings, 1)
name, comp = bindings[0]
self.assertEqual(name, 'fc_FEDERATED_symbol_0')
self.assertEqual(comp.compact_representation(), 'generic_plus(<x,y>)')
def test_to_value_for_tuple(self):
x = value_impl.ValueImpl(
......@@ -387,7 +392,7 @@ class ValueImplTest(parameterized.TestCase):
_ = v[2:4:-1]
@parameterized.named_parameters(('list', list), ('tuple', tuple))
def test_slicing_tuple_values(self, sequence_type):
def test_slicing_tuple_values_from_front(self, sequence_type):
def _to_value(cbb):
return value_impl.to_value(cbb, None, context_stack_impl.context_stack)
......@@ -402,15 +407,89 @@ class ValueImplTest(parameterized.TestCase):
sliced = v[:2]
self.assertEqual((str(sliced.type_signature)), '<int32,int32>')
self.assertEqual(str(sliced), '<comp#1(),comp#2()>')
self.assertEqual(
str(sliced), '<fc_FEDERATED_symbol_0,fc_FEDERATED_symbol_1>')
expected_symbol_bindings = [
('fc_FEDERATED_symbol_0', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_1', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_2', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_3', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_4', [r'comp#[a-zA-Z0-9]*()']),
]
bindings = value_impl.ValueImpl.get_context_stack(
sliced).current.symbol_bindings
for (bound_name, comp), (expected_name,
expected_regex) in zip(bindings,
expected_symbol_bindings):
self.assertEqual(bound_name, expected_name)
self.assertRegexMatch(comp.compact_representation(), expected_regex)
@parameterized.named_parameters(('list', list), ('tuple', tuple))
def test_slicing_tuple_values_from_back(self, sequence_type):
def _to_value(cbb):
return value_impl.to_value(cbb, None, context_stack_impl.context_stack)
t = sequence_type(range(0, 50, 10))
v = _to_value(t)
self.assertEqual((str(v.type_signature)), '<int32,int32,int32,int32,int32>')
self.assertEqual(str(v[:]), str(v))
sliced = v[-3:]
self.assertEqual((str(sliced.type_signature)), '<int32,int32,int32>')
self.assertEqual(str(sliced), '<comp#3(),comp#4(),comp#5()>')
self.assertEqual(
str(sliced),
'<fc_FEDERATED_symbol_2,fc_FEDERATED_symbol_3,fc_FEDERATED_symbol_4>')
expected_symbol_bindings = [
('fc_FEDERATED_symbol_0', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_1', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_2', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_3', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_4', [r'comp#[a-zA-Z0-9]*()']),
]
bindings = value_impl.ValueImpl.get_context_stack(
sliced).current.symbol_bindings
for (bound_name, comp), (expected_name,
expected_regex) in zip(bindings,
expected_symbol_bindings):
self.assertEqual(bound_name, expected_name)
self.assertRegexMatch(comp.compact_representation(), expected_regex)
@parameterized.named_parameters(('list', list), ('tuple', tuple))
def test_slicing_tuple_values_skipping_steps(self, sequence_type):
def _to_value(val):
return value_impl.to_value(val, None, context_stack_impl.context_stack)
t = sequence_type(range(0, 50, 10))
v = _to_value(t)
sliced = v[::2]
self.assertEqual((str(sliced.type_signature)), '<int32,int32,int32>')
self.assertEqual(str(sliced), '<comp#1(),comp#3(),comp#5()>')
self.assertEqual(
str(sliced),
'<fc_FEDERATED_symbol_0,fc_FEDERATED_symbol_2,fc_FEDERATED_symbol_4>')
expected_symbol_bindings = [
('fc_FEDERATED_symbol_0', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_1', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_2', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_3', [r'comp#[a-zA-Z0-9]*()']),
('fc_FEDERATED_symbol_4', [r'comp#[a-zA-Z0-9]*()']),
]
bindings = value_impl.ValueImpl.get_context_stack(
sliced).current.symbol_bindings
for (bound_name, comp), (expected_name,
expected_regex) in zip(bindings,
expected_symbol_bindings):
self.assertEqual(bound_name, expected_name)
self.assertRegexMatch(comp.compact_representation(), expected_regex)
def test_getitem_resolution_federated_value_clients(self):
federated_value = value_impl.to_value(
......@@ -472,7 +551,6 @@ class ValueImplTest(parameterized.TestCase):
str(federated_value.type_signature), '<a=int32,b=bool>@SERVER')
federated_attribute = federated_value['a']
self.assertEqual(str(federated_attribute.type_signature), 'int32@SERVER')
print(repr(federated_value))
with self.assertRaises(ValueError):
_ = federated_value['badkey']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment