提交 dc585975 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Refactors handling of references in remove duplicate building blocks.

PiperOrigin-RevId: 345530794
上级 99036cc9
......@@ -447,8 +447,8 @@ class SymbolTree(object):
comp = comp.parent
return None
def get_all_payloads_with_value(self, value, equal_fn=None):
"""Returns all the payloads whose `value` attribute is equal to `value`.
def get_higher_payloads_with_value(self, value, equal_fn=None):
"""Returns payloads above `active_node` whose `value` is equal to `value`.
Args:
value: The value to test.
......@@ -458,14 +458,14 @@ class SymbolTree(object):
payloads = []
if equal_fn is None:
equal_fn = operator.is_
comp = typing.cast(SequentialBindingNode, self.active_node)
while comp.parent is not None or comp.older_sibling is not None:
if comp.payload.value is not None and equal_fn(value, comp.payload.value):
payloads.append(comp.payload)
if comp.older_sibling is not None:
comp = comp.older_sibling
elif comp.parent is not None:
comp = comp.parent
node = typing.cast(SequentialBindingNode, self.active_node)
while node.parent is not None or node.older_sibling is not None:
if node.payload.value is not None and equal_fn(value, node.payload.value):
payloads.append(node.payload)
if node.older_sibling is not None:
node = node.older_sibling
elif node.parent is not None:
node = node.parent
return payloads
def update_payload_with_name(self, name):
......
......@@ -1077,6 +1077,58 @@ def remove_duplicate_block_locals(comp):
"""Returns `True` if `comp` should be transformed."""
return comp.is_block() or comp.is_reference()
def _resolve_reference_to_concrete(
ref: building_blocks.Reference,
symbol_tree: transformation_utils.SymbolTree
) -> building_blocks.ComputationBuildingBlock:
"""Resolves `value` to a concrete building block, as far as possible.
Args:
ref: Instance of `building_blocks.Reference` to resolve in `symbol_tree`.
symbol_tree: Instance of `transformation_utils.SymbolTree` which contains
variable bindings to be used when resolving `value`.
Returns:
The resolution of `value` in symbol tree. If this resolution is
itself a reference, this indicates that the reference chain terminates in
either an unbound reference or a parameter binding, and thus cannot be
resolved any further.
"""
comp = ref
while comp.is_reference():
payload = symbol_tree.get_payload_with_name(comp.name)
if payload is None:
# We've resolved this reference to an unbound comp; we cannot alter the
# unbound comp, so return it in place of `ref`.
return comp
new_comp = payload.value
if new_comp is None:
# `comp` is bound by a lambda; we cannot alter this either.
return comp
else:
comp = new_comp
return comp
def _remove_reference_chain(ref, symbol_tree):
value = _resolve_reference_to_concrete(ref, symbol_tree)
if value.is_reference():
return value, True
payloads_with_value = symbol_tree.get_higher_payloads_with_value(
value, tree_analysis.trees_equal)
if not payloads_with_value:
# In this case, the current binding is the only visible binding with value
# `value`. We don't need to update anything, or replace the current
# reference.
return ref, False
else:
highest_payload = payloads_with_value[-1]
lower_payloads = payloads_with_value[:-1]
for payload in lower_payloads:
symbol_tree.update_payload_with_name(payload.name)
highest_building_block = building_blocks.Reference(
highest_payload.name, highest_payload.value.type_signature)
return highest_building_block, True
def _transform(comp, symbol_tree):
"""Returns a new transformed computation or `comp`."""
if not _should_transform(comp):
......@@ -1094,34 +1146,7 @@ def remove_duplicate_block_locals(comp):
comp = building_blocks.Block(variables, comp.result)
return comp, True
elif comp.is_reference():
payload = symbol_tree.get_payload_with_name(comp.name)
if payload is None:
# Comp is unbound; we cannot alter it.
return comp, False
value = payload.value
if value is None:
return comp, False
while value.is_reference():
payload = symbol_tree.get_payload_with_name(value.name)
if payload is None:
# value is unbound; we cannot alter it.
return value, True
new_value = payload.value
if new_value is None:
comp = building_blocks.Reference(value.name, value.type_signature)
return comp, True
else:
value = new_value
payloads_with_value = symbol_tree.get_all_payloads_with_value(
value, tree_analysis.trees_equal)
if payloads_with_value:
highest_payload = payloads_with_value[-1]
lower_payloads = payloads_with_value[:-1]
for payload in lower_payloads:
symbol_tree.update_payload_with_name(payload.name)
comp = building_blocks.Reference(highest_payload.name,
highest_payload.value.type_signature)
return comp, True
return _remove_reference_chain(comp, symbol_tree)
return comp, False
symbol_tree = transformation_utils.SymbolTree(
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册