提交 befd5104 编辑于 作者: Zachary Garrett's avatar Zachary Garrett 提交者: tensorflow-copybara
浏览文件

Add a test for `tf.function` calls that return multiple datasets under the same parent op name.

This triggers a bug in the C++ runtime that fails to uniquely name the DatasetToGraphV2 JIT nodes.

PiperOrigin-RevId: 391626904
上级 139a123d
......@@ -142,6 +142,39 @@ class TensorFlowExecutorBindingsTest(parameterized.TestCase,
TensorType(sequence_type.element.dtype))
self.assertEqual(result, sum(range(5)))
def test_create_tuple_of_value_sequence(self):
self.skipTest('b/197147669')
datasets = (tf.data.Dataset.range(5), tf.data.Dataset.range(5))
executor = executor_bindings.create_tensorflow_executor()
struct_of_sequence_type = StructType([
(None, SequenceType(datasets[0].element_spec)),
(None, SequenceType(datasets[0].element_spec))
])
arg_value_pb, _ = value_serialization.serialize_value(
datasets, struct_of_sequence_type)
arg = executor.create_value(arg_value_pb)
@computations.tf_computation(struct_of_sequence_type)
def preprocess(datasets):
def double_value(x):
return 2 * x
@tf.function
def add_preprocessing(ds1, ds2):
return ds1.map(double_value), ds2.map(double_value)
return add_preprocessing(*datasets)
comp_pb = serialization_bindings.Value(
computation=preprocess.get_proto(preprocess))
comp = executor.create_value(comp_pb)
result = executor.create_call(comp.ref, arg.ref)
output_pb = executor.materialize(result.ref)
result, result_type_spec = value_serialization.deserialize_value(
output_pb, type_hint=struct_of_sequence_type)
self.assert_types_identical(result_type_spec, struct_of_sequence_type)
def test_create_struct(self):
executor = executor_bindings.create_tensorflow_executor()
expected_type_spec = TensorType(shape=[3], dtype=tf.int64)
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册