提交 b272adc4 编辑于 作者: Zachary Garrett's avatar Zachary Garrett 提交者: tensorflow-copybara
浏览文件

Uniquely name JIT'd ops added for DatasetToGraphV2 serialization.

This resolves a bug where the C++ will fail if a `tff.Computation` is returning
multiple variant tensors from the same node (e.g. a `PartitionedCall` op).

PiperOrigin-RevId: 391638014
上级 befd5104
......@@ -217,11 +217,11 @@ struct NamesForBindingRewrite {
// wrapping sequence bindings in datset serialization ops.
NamesForBindingRewrite GetVariantTensorNodeNameAndReplacement(
absl::string_view variant_tensor_name,
absl::string_view replace_node_suffix) {
absl::string_view replace_node_suffix, absl::string_view node_prefix) {
NamesForBindingRewrite names;
names.variant_node_name = GetNodeName(variant_tensor_name);
names.graph_def_node_name =
absl::StrCat(names.variant_node_name, "/", replace_node_suffix);
names.graph_def_node_name = absl::StrCat(
node_prefix, "/", names.variant_node_name, "/", replace_node_suffix);
names.graph_def_tensor_name = absl::StrCat(names.graph_def_node_name, ":0");
return names;
}
......@@ -257,7 +257,8 @@ NamesForBindingRewrite GetVariantTensorNodeNameAndReplacement(
// the reseverse of `AddSerializationOpsForResults`, which is used on the result
// bindings of the function.
absl::Status AddDeserializationOpsForParameters(
tensorflow::GraphDef& graphdef_pb, v0::TensorFlow::Binding& binding) {
tensorflow::GraphDef& graphdef_pb, v0::TensorFlow::Binding& binding,
absl::string_view prefix = "root") {
switch (binding.binding_case()) {
case v0::TensorFlow::Binding::kSequence: {
// Get the name of the placeholder we're operating on and create a name
......@@ -265,7 +266,7 @@ absl::Status AddDeserializationOpsForParameters(
const std::string& variant_tensor_name =
binding.sequence().variant_tensor_name();
auto graph_names = GetVariantTensorNodeNameAndReplacement(
variant_tensor_name, kDatasetFromGraphOp);
variant_tensor_name, kDatasetFromGraphOp, prefix);
for (tensorflow::NodeDef& node_pb : *graphdef_pb.mutable_node()) {
// Change the placeholder op from variant to string, this will now
// be a placeholder for a serialized graphdef bytes.
......@@ -308,8 +309,10 @@ absl::Status AddDeserializationOpsForParameters(
return absl::OkStatus();
}
case v0::TensorFlow::Binding::kStruct: {
for (auto& member : *binding.mutable_struct_()->mutable_element()) {
TFF_TRY(AddDeserializationOpsForParameters(graphdef_pb, member));
for (int i = 0; i < binding.struct_().element_size(); ++i) {
auto& member = *binding.mutable_struct_()->mutable_element(i);
TFF_TRY(AddDeserializationOpsForParameters(
graphdef_pb, member, absl::StrCat(prefix, "/", i)));
}
return absl::OkStatus();
}
......@@ -351,7 +354,8 @@ absl::Status AddDeserializationOpsForParameters(
// the reseverse of `AddDeserializationOpsForParameters`, which is used on the
// parameter bindings of the function.
absl::Status AddSerializationOpsForResults(tensorflow::GraphDef& graphdef_pb,
v0::TensorFlow::Binding& binding) {
v0::TensorFlow::Binding& binding,
absl::string_view prefix = "root") {
switch (binding.binding_case()) {
case v0::TensorFlow::Binding::kSequence: {
if (binding.sequence().binding_case() ==
......@@ -364,7 +368,7 @@ absl::Status AddSerializationOpsForResults(tensorflow::GraphDef& graphdef_pb,
const std::string& variant_tensor_name =
binding.sequence().variant_tensor_name();
auto graph_names = GetVariantTensorNodeNameAndReplacement(
variant_tensor_name, kDatasetToGraphOp);
variant_tensor_name, kDatasetToGraphOp, prefix);
// We only need to add a new node to the graph and update the binding,
// since we'll depend on what is already in the graph.
tensorflow::NodeDef* graph_from_dataset_node = graphdef_pb.add_node();
......@@ -388,8 +392,10 @@ absl::Status AddSerializationOpsForResults(tensorflow::GraphDef& graphdef_pb,
return absl::OkStatus();
}
case v0::TensorFlow::Binding::kStruct: {
for (auto& member : *binding.mutable_struct_()->mutable_element()) {
TFF_TRY(AddSerializationOpsForResults(graphdef_pb, member));
for (int i = 0; i < binding.struct_().element_size(); ++i) {
auto& member = *binding.mutable_struct_()->mutable_element(i);
TFF_TRY(AddSerializationOpsForResults(graphdef_pb, member,
absl::StrCat(prefix, "/", i)));
}
return absl::OkStatus();
}
......
......@@ -143,7 +143,6 @@ class TensorFlowExecutorBindingsTest(parameterized.TestCase,
self.assertEqual(result, sum(range(5)))
def test_create_tuple_of_value_sequence(self):
self.skipTest('b/197147669')
datasets = (tf.data.Dataset.range(5), tf.data.Dataset.range(5))
executor = executor_bindings.create_tensorflow_executor()
struct_of_sequence_type = StructType([
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册