Commit a9e68ee5 authored by Keith Rush's avatar Keith Rush Committed by tensorflow-copybara
Browse files

Adds failing test for pulling out unbound reference across lambda when using sequence executor.

PiperOrigin-RevId: 414515103
parent 3a2e85c7
......@@ -487,6 +487,49 @@ class NonDeterministicTest(parameterized.TestCase):
self.assertNotEqual(first_random, second_random)
class SequenceExecutorIntegrationTest(parameterized.TestCase):
@test_contexts.with_context(
test_contexts.create_sequence_op_supporting_context)
def test_inlined_value_in_sequence_reduce(self):
@tff.tf_computation(tf.float32, tf.float32)
def add_floats(x, y):
return x + y
@tff.federated_computation(tff.SequenceType(tf.float32))
def sum_floats(sequence):
return tff.sequence_reduce(sequence, 0., add_floats)
ds = tf.data.Dataset.range(10).map(lambda x: tf.cast(x, tf.float32))
value = sum_floats(ds)
self.assertEqual(value, 10. * 9 / 2)
@test_contexts.with_context(
test_contexts.create_sequence_op_supporting_context)
def test_inlined_value_in_mapped_sequence_reduce(self):
@tff.tf_computation(tf.float32, tf.float32)
def add_floats(x, y):
return x + y
@tff.federated_computation(tff.SequenceType(tf.float32))
def sum_floats(sequence):
return tff.sequence_reduce(sequence, 0., add_floats)
@tff.federated_computation(
tff.FederatedType(sum_floats.type_signature.parameter, tff.SERVER))
def map_reduction(placed_sequence):
return tff.federated_map(sum_floats, placed_sequence)
ds = tf.data.Dataset.range(10).map(lambda x: tf.cast(x, tf.float32))
with self.assertRaises(RuntimeError):
# Raises due to b/208736990
map_reduction(ds)
class SizingExecutionContextTest(parameterized.TestCase):
@test_contexts.with_context(
......
......@@ -31,6 +31,14 @@ def _create_local_mergeable_comp_context():
return tff.backends.native.create_mergeable_comp_execution_context([factory])
def create_sequence_op_supporting_context():
executor_factory = tff.framework.local_executor_factory(
support_sequence_ops=True)
return tff.framework.ExecutionContext(
executor_fn=executor_factory,
compiler_fn=tff.backends.native.compiler.transform_to_native_form) # pytype: disable=wrong-arg-types
def _get_all_contexts():
"""Returns a list containing a (name, context_fn) tuple for each context."""
# pyformat: disable
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment