Skip to content
Snippets Groups Projects
Commit 0b53db27 authored by Keith Rush's avatar Keith Rush Committed by tensorflow-copybara
Browse files

Renames remove_lambdas_and_blocks to remove_called_lambdas_and_blocks.

PiperOrigin-RevId: 321265603
parent 187f78aa
Branches
Tags
No related merge requests found
......@@ -218,7 +218,7 @@ def consolidate_and_extract_local_processing(comp):
produced by this extraction step, as described above.
"""
py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
comp, _ = transformations.remove_lambdas_and_blocks(comp)
comp, _ = transformations.remove_called_lambdas_and_blocks(comp)
if comp.type_signature.is_function():
if comp.is_compiled_computation():
return comp
......
......@@ -35,7 +35,7 @@ from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERA
from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_MAP_ALL_EQUAL
from tensorflow_federated.python.core.impl.compiler.transformation_utils import get_map_of_unbound_references
from tensorflow_federated.python.core.impl.compiler.transformation_utils import transform_postorder
from tensorflow_federated.python.core.impl.compiler.transformations import remove_lambdas_and_blocks
from tensorflow_federated.python.core.impl.compiler.transformations import remove_called_lambdas_and_blocks
from tensorflow_federated.python.core.impl.compiler.tree_analysis import check_broadcast_not_dependent_on_aggregate
from tensorflow_federated.python.core.impl.compiler.tree_analysis import check_has_unique_names
from tensorflow_federated.python.core.impl.compiler.tree_analysis import check_intrinsics_whitelisted_for_reduction
......
......@@ -55,7 +55,8 @@ def prepare_for_rebinding(comp):
selections from tuples collapsed.
"""
# TODO(b/146430051): Follow up here and consider removing or enforcing more
# strict output invariants when `remove_lambdas_and_blocks` is moved in here.
# strict output invariants when `remove_called_lambdas_and_blocks` is moved
# in here.
py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
comp, _ = tree_transformations.uniquify_reference_names(comp)
comp, _ = tree_transformations.replace_called_lambda_with_block(comp)
......@@ -80,7 +81,7 @@ def prepare_for_rebinding(comp):
comp, _transform_fn, symbol_tree)
def remove_lambdas_and_blocks(comp):
def remove_called_lambdas_and_blocks(comp):
"""Removes any called lambdas and blocks from `comp`.
This function will rename all the variables in `comp` in a single walk of the
......@@ -620,17 +621,17 @@ class TensorFlowGenerator(transformation_utils.TransformSpec):
def transform(self, local_function):
if not self.should_transform(local_function):
return local_function, False
refs_removed, _ = remove_lambdas_and_blocks(local_function)
refs_removed, _ = remove_called_lambdas_and_blocks(local_function)
parsed_to_tf, _ = remove_duplicate_called_graphs(refs_removed)
if parsed_to_tf.is_compiled_computation() or (
parsed_to_tf.is_call() and
parsed_to_tf.function.is_compiled_computation()):
return parsed_to_tf, True
# TODO(b/146430051): We should only end up in this case if
# `remove_lambdas_and_blocks` above is in its failure mode, IE, failing to
# resolve references due to too-deep indirection; we should remove
# this extra case and simply raise if we fail here when we fix the attached
# bug.
# `remove_called_lambdas_and_blocks` above is in its failure mode, IE,
# failing to resolve references due to too-deep indirection; we should
# remove this extra case and simply raise if we fail here when we fix the
# attached bug.
called_graphs_inserted, _ = tree_transformations.insert_called_tf_identity_at_leaves(
parsed_to_tf)
compiled_comp, _ = transformation_utils.transform_postorder(
......
......@@ -46,7 +46,7 @@ class RemoveLambdasAndBlocksTest(test.TestCase):
'x', tf.int32, building_blocks.Reference('x', tf.int32))
called_lambda = building_blocks.Call(identity_lam,
building_blocks.Data('a', tf.int32))
lambdas_and_blocks_removed, modified = transformations.remove_lambdas_and_blocks(
lambdas_and_blocks_removed, modified = transformations.remove_called_lambdas_and_blocks(
called_lambda)
self.assertTrue(modified)
self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
......@@ -57,7 +57,7 @@ class RemoveLambdasAndBlocksTest(test.TestCase):
simple_block = building_blocks.Block([('x', data)],
building_blocks.Reference(
'x', tf.int32))
lambdas_and_blocks_removed, modified = transformations.remove_lambdas_and_blocks(
lambdas_and_blocks_removed, modified = transformations.remove_called_lambdas_and_blocks(
simple_block)
self.assertTrue(modified)
self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
......@@ -77,7 +77,7 @@ class RemoveLambdasAndBlocksTest(test.TestCase):
concrete_arg = building_blocks.Data('a', tf.int32)
arg_tuple = building_blocks.Tuple([concrete_fn, concrete_arg])
generated_structure = building_blocks.Block([('arg', arg_tuple)], called_fn)
lambdas_and_blocks_removed, modified = transformations.remove_lambdas_and_blocks(
lambdas_and_blocks_removed, modified = transformations.remove_called_lambdas_and_blocks(
generated_structure)
self.assertTrue(modified)
self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
......@@ -91,7 +91,7 @@ class RemoveLambdasAndBlocksTest(test.TestCase):
zipped = building_block_factory.create_federated_zip(unzipped)
placement_unwrapped, _ = tree_transformations.unwrap_placement(zipped)
placement_gone = placement_unwrapped.argument
lambdas_and_blocks_removed, modified = transformations.remove_lambdas_and_blocks(
lambdas_and_blocks_removed, modified = transformations.remove_called_lambdas_and_blocks(
placement_gone)
self.assertTrue(modified)
self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
......@@ -105,7 +105,7 @@ class RemoveLambdasAndBlocksTest(test.TestCase):
higher_level_lambda = building_blocks.Lambda('fn',
identity_lam.type_signature,
called_inner_lambda)
lambdas_and_blocks_removed, modified = transformations.remove_lambdas_and_blocks(
lambdas_and_blocks_removed, modified = transformations.remove_called_lambdas_and_blocks(
higher_level_lambda)
self.assertTrue(modified)
self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
......@@ -126,7 +126,7 @@ class RemoveLambdasAndBlocksTest(test.TestCase):
('b', tuple_wrapping_ref),
('c', selection_from_ref),
], called_lambda_with_indirection)
lambdas_and_blocks_removed, modified = transformations.remove_lambdas_and_blocks(
lambdas_and_blocks_removed, modified = transformations.remove_called_lambdas_and_blocks(
blk)
self.assertTrue(modified)
self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
......@@ -145,7 +145,7 @@ class RemoveLambdasAndBlocksTest(test.TestCase):
left_lambda = building_blocks.Lambda('x', middle_lambda.type_signature, rez)
higher_call = building_blocks.Call(left_lambda, middle_lambda)
high_call = building_blocks.Call(higher_call, data)
lambdas_and_blocks_removed, modified = transformations.remove_lambdas_and_blocks(
lambdas_and_blocks_removed, modified = transformations.remove_called_lambdas_and_blocks(
high_call)
self.assertTrue(modified)
self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment