提交 5965c17c 编辑于 作者: Taylor Cramer's avatar Taylor Cramer 提交者: tensorflow-copybara
浏览文件

Refactor tensorflow_computation_factory to reuse graph creation

PiperOrigin-RevId: 343905109
上级 ba869ea1
......@@ -13,6 +13,7 @@
# limitations under the License.
"""A library of contruction functions for tensorflow computation structures."""
import functools
import types
from typing import Any, Callable, Optional, Tuple
......@@ -281,17 +282,7 @@ def create_empty_tuple() -> ProtoAndType:
The returned computation has the type signature `( -> <>)`.
"""
with tf.Graph().as_default() as graph:
result_type, result_binding = tensorflow_utils.capture_result_from_graph(
structure.Struct([]), graph)
type_signature = computation_types.FunctionType(None, result_type)
tensorflow = pb.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=None,
result=result_binding)
return _tensorflow_comp(tensorflow, type_signature)
return create_computation_for_py_fn(lambda: structure.Struct([]), None)
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
......@@ -315,23 +306,16 @@ def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
if parameter_type is None:
raise TypeError('TensorFlow identity cannot be created for NoneType.')
with tf.Graph().as_default() as graph:
parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
'x', parameter_type, graph)
# TF relies on feeds not-identical to fetches in certain circumstances.
if type_signature.is_tensor():
parameter_value = tf.identity(parameter_value)
elif type_signature.is_struct():
parameter_value = structure.map_structure(tf.identity, parameter_value)
result_type, result_binding = tensorflow_utils.capture_result_from_graph(
parameter_value, graph)
# TF relies on feeds not-identical to fetches in certain circumstances.
if type_signature.is_tensor() or type_signature.is_sequence():
identity_fn = tf.identity
elif type_signature.is_struct():
identity_fn = functools.partial(structure.map_structure, tf.identity)
else:
raise NotImplementedError(
f'TensorFlow identity cannot be created for type {type_signature}')
type_signature = computation_types.FunctionType(parameter_type, result_type)
tensorflow = pb.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=parameter_binding,
result=result_binding)
return _tensorflow_comp(tensorflow, type_signature)
return create_computation_for_py_fn(identity_fn, parameter_type)
def create_replicate_input(type_signature: computation_types.Type,
......@@ -352,20 +336,7 @@ def create_replicate_input(type_signature: computation_types.Type,
type_analysis.check_tensorflow_compatible_type(type_signature)
py_typecheck.check_type(count, int)
parameter_type = type_signature
with tf.Graph().as_default() as graph:
parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
'x', parameter_type, graph)
result = [parameter_value] * count
result_type, result_binding = tensorflow_utils.capture_result_from_graph(
result, graph)
type_signature = computation_types.FunctionType(parameter_type, result_type)
tensorflow = pb.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=parameter_binding,
result=result_binding)
return _tensorflow_comp(tensorflow, type_signature)
return create_computation_for_py_fn(lambda v: [v] * count, parameter_type)
def create_computation_for_py_fn(
......@@ -380,7 +351,6 @@ def create_computation_for_py_fn(
fn: A Python function.
parameter_type: A `computation_types.Type` or `None`.
"""
py_typecheck.check_type(fn, types.FunctionType)
if parameter_type is not None:
py_typecheck.check_type(parameter_type, computation_types.Type)
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册