Commit 6a24b11d authored by Keith Rush's avatar Keith Rush Committed by tensorflow-copybara
Browse files

Adds check that computation arguments in TF executor are in fact TensorFlow.

PiperOrigin-RevId: 413813975
parent 9a828c2e
......@@ -493,6 +493,11 @@ def to_representation_for_type(
computation_impl.ConcreteComputation.get_proto(value),
tf_function_cache, type_spec, device)
elif isinstance(value, pb.Computation):
computation_oneof = value.WhichOneof('computation')
if computation_oneof != 'tensorflow':
raise ValueError('Eager TF Executor can only execute computations of '
'TensorFlow flavor; encountered a computation of type '
f'{computation_oneof}')
return _to_computation_internal_rep(
value=value,
tf_function_cache=tf_function_cache,
......
......@@ -432,6 +432,17 @@ class EagerTFExecutorTest(test_case.TestCase, parameterized.TestCase):
self.assertEqual(str(val.type_signature), 'int32')
self.assertEqual(val.internal_representation, 10)
def test_executor_create_value_raises_on_lambda(self):
ex = eager_tf_executor.EagerTFExecutor()
@computations.federated_computation(tf.int32)
def comp(x):
return x
with self.assertRaisesRegex(ValueError, 'computation of type lambda'):
asyncio.get_event_loop().run_until_complete(
ex.create_value(comp.to_building_block().proto, comp.type_signature))
def test_executor_create_value_struct_mismatched_type(self):
ex = eager_tf_executor.EagerTFExecutor()
with self.assertRaises(TypeError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment