提交 f0aecb8f 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Preserve resulting all-equal information while merging intrinsics.

PiperOrigin-RevId: 393880481
上级 00a23d7b
......@@ -158,7 +158,6 @@ py_test(
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/impl/types:type_transformations",
"//tensorflow_federated/python/core/impl/wrappers:computation_wrapper_instances",
],
)
......@@ -395,11 +395,21 @@ def _compute_merged_intrinsics(
unpack_to_locals=[]))
else:
calls = [local[1] for local in locals_for_uri]
result_placement = calls[0].type_signature.placement
result_all_equal = calls[0].type_signature.all_equal
for call in calls:
if call.type_signature.all_equal != result_all_equal:
raise ValueError('Encountered intrinsics to be merged with '
f'mismatched all_equal bits. Intrinsic of URI {uri} '
f'first call had all_equal bit {result_all_equal}, '
'encountered call with all_equal value '
f'{call.type_signature.all_equal}')
return_type = computation_types.FederatedType(
computation_types.StructType([
(None, call.type_signature.member) for call in calls
]),
placement=calls[0].type_signature.placement)
placement=result_placement,
all_equal=result_all_equal)
abstract_parameter_type = default_call.function.intrinsic_def(
).type_signature.parameter
results.append(
......
......@@ -32,7 +32,6 @@ from tensorflow_federated.python.core.impl.executors import executor_stacks
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_transformations
from tensorflow_federated.python.core.impl.wrappers import computation_wrapper_instances
DEFAULT_GRAPPLER_CONFIG = tf.compat.v1.ConfigProto()
......@@ -237,19 +236,6 @@ class ConsolidateAndExtractTest(test_case.TestCase):
self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)
def _remove_client_all_equals_from_type(type_signature):
def _transform(inner_type):
if (inner_type.is_federated() and inner_type.placement.is_clients() and
inner_type.all_equal):
return computation_types.FederatedType(inner_type.member,
inner_type.placement, False), True
return inner_type, False
return type_transformations.transform_type_postorder(type_signature,
_transform)[0]
class ForceAlignAndSplitByIntrinsicTest(test_case.TestCase):
def assert_splits_on(self, comp, calls):
......@@ -291,13 +277,6 @@ class ForceAlignAndSplitByIntrinsicTest(test_case.TestCase):
before.type_signature.result[i],
after.parameter_type.intrinsic_results[i])
abstract_signature = calls[i].function.intrinsic_def().type_signature
# `force_align_and_split_by_intrinsics` loses all-equal data due to
# zipping and unzipping. This is okay because the resulting computations
# are not used together directly, but are compiled into unplaced TF code.
abstract_signature = _remove_client_all_equals_from_type(
abstract_signature)
concrete_signature = _remove_client_all_equals_from_type(
concrete_signature)
type_analysis.check_concrete_instance_of(concrete_signature,
abstract_signature)
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册