提交 d87edcc6 编辑于 作者: Taylor Cramer's avatar Taylor Cramer 提交者: tensorflow-copybara
浏览文件

Allow federated_zip intrinsic to take arbitrary structures

Previously, the federated_zip intrinsic only operated on exactly two
elements at a time. When users passed structures with more than two elementsto `tff.federated_zip`, a TFF created a tree of calls to the `federated_zip`
intrinsic. This resulted in a slightly simpler model for runtime
implementations of `federated_zip`, but greatly increased the complexity
within the compiler, often resulting in large complex trees of calls with no
functional purpose.

This CL modifies `federated_zip` and its runtime implementations in order to
provide zipping of large, uneven structures through a single call. This
results in reduced complexity both in the implementation and in the
generate ASTs.

PiperOrigin-RevId: 414008706
上级 27654541
......@@ -630,7 +630,7 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedValueAtServer) {
ExpectMaterialize(result_id, ServerV(tensor));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClients) {
TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClientsFlat) {
v0::Value v1 = ClientsV({TensorV(1)}, true);
v0::Value v2 = ClientsV({TensorV(2)}, true);
auto merged_struct = StructV({TensorV(1), TensorV(2)});
......@@ -646,14 +646,39 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClients) {
TFF_ASSERT_OK_AND_ASSIGN(auto arg_id,
test_executor_->CreateValue(StructV({v1, v2})));
TFF_ASSERT_OK_AND_ASSIGN(
auto zip_id, test_executor_->CreateValue(FederatedZipAtServerV()));
auto zip_id, test_executor_->CreateValue(FederatedZipAtClientsV()));
TFF_ASSERT_OK_AND_ASSIGN(auto res_id,
test_executor_->CreateCall(zip_id, arg_id));
ExpectMaterialize(
res_id, ClientsV(std::vector<v0::Value>(total_clients_, merged_struct)));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtServer) {
TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClientsNested) {
v0::Value v1 = ClientsV({TensorV(1)}, true);
v0::Value v2 = ClientsV({TensorV(2)}, true);
v0::Value v2_struct = StructV({v2});
auto merged_struct = StructV({TensorV(1), StructV({TensorV(2)})});
v0::Value merged = ClientsV({merged_struct}, true);
for (const auto& child : mock_children_) {
auto child_v1 = child->ExpectCreateValue(v1);
auto child_v2 = child->ExpectCreateValue(v2);
auto child_zip = child->ExpectCreateValue(FederatedZipAtClientsV());
auto child_v2_struct = child->ExpectCreateStruct({child_v2});
auto child_zip_arg = child->ExpectCreateStruct({child_v1, child_v2_struct});
auto child_res = child->ExpectCreateCall(child_zip, child_zip_arg);
child->ExpectMaterialize(child_res, merged);
}
TFF_ASSERT_OK_AND_ASSIGN(
auto arg_id, test_executor_->CreateValue(StructV({v1, v2_struct})));
TFF_ASSERT_OK_AND_ASSIGN(
auto zip_id, test_executor_->CreateValue(FederatedZipAtClientsV()));
TFF_ASSERT_OK_AND_ASSIGN(auto res_id,
test_executor_->CreateCall(zip_id, arg_id));
ExpectMaterialize(
res_id, ClientsV(std::vector<v0::Value>(total_clients_, merged_struct)));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtServerFlat) {
v0::Value v1 = TensorV(1);
v0::Value v2 = TensorV(2);
ValueId v1_child_id = mock_server_->ExpectCreateValue(v1);
......@@ -670,6 +695,26 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtServer) {
ExpectMaterialize(result_id, ServerV(StructV({v1, v2})));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtServerNested) {
v0::Value v1 = TensorV(1);
v0::Value v2 = TensorV(2);
ValueId v1_child_id = mock_server_->ExpectCreateValue(v1);
ValueId v2_child_id = mock_server_->ExpectCreateValue(v2);
TFF_ASSERT_OK_AND_ASSIGN(auto v_id,
test_executor_->CreateValue(
StructV({ServerV(v1), StructV({ServerV(v2)})})));
TFF_ASSERT_OK_AND_ASSIGN(
auto zip_id, test_executor_->CreateValue(FederatedZipAtServerV()));
ValueId v2_struct_child_id = mock_server_->ExpectCreateStruct({v2_child_id});
ValueId struct_child_id =
mock_server_->ExpectCreateStruct({v1_child_id, v2_struct_child_id});
TFF_ASSERT_OK_AND_ASSIGN(auto result_id,
test_executor_->CreateCall(zip_id, v_id));
mock_server_->ExpectMaterialize(struct_child_id,
StructV({v1, StructV({v2})}));
ExpectMaterialize(result_id, ServerV(StructV({v1, StructV({v2})})));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedZipDifferentPlacementsFails) {
v0::Value v1 = ClientsV({TensorV(1)}, true);
v0::Value v2_inner = TensorV(2);
......
......@@ -29,9 +29,10 @@ absl::StatusOr<FederatedIntrinsic> FederatedIntrinsicFromUri(
if (uri == kFederatedMapAtClientsUri || uri == "federated_apply" ||
uri == "federated_map_all_equal") {
return FederatedIntrinsic::MAP;
} else if (uri == kFederatedZipAtClientsUri ||
uri == "federated_zip_at_server") {
return FederatedIntrinsic::ZIP;
} else if (uri == kFederatedZipAtClientsUri) {
return FederatedIntrinsic::ZIP_AT_CLIENTS;
} else if (uri == kFederatedZipAtServerUri) {
return FederatedIntrinsic::ZIP_AT_SERVER;
} else if (uri == "federated_broadcast") {
return FederatedIntrinsic::BROADCAST;
} else if (uri == "federated_value_at_clients") {
......
......@@ -25,11 +25,13 @@ const absl::string_view kFederatedMapAtClientsUri = "federated_map";
const absl::string_view kFederatedEvalAtClientsUri =
"federated_eval_at_clients";
const absl::string_view kFederatedZipAtClientsUri = "federated_zip_at_clients";
const absl::string_view kFederatedZipAtServerUri = "federated_zip_at_server";
const absl::string_view kFederatedAggregateUri = "federated_aggregate";
enum class FederatedIntrinsic {
MAP,
ZIP,
ZIP_AT_CLIENTS,
ZIP_AT_SERVER,
BROADCAST,
VALUE_AT_CLIENTS,
VALUE_AT_SERVER,
......
......@@ -170,7 +170,7 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
return ::tensorflow_federated::NewClients(num_clients_);
}
absl::StatusOr<ExecutorValue> CreateFederatedValue_(
absl::StatusOr<ExecutorValue> CreateFederatedValue(
FederatedKind kind, const v0::Value_Federated& federated) {
switch (kind) {
case FederatedKind::SERVER: {
......@@ -230,7 +230,7 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
case v0::Value::kFederated: {
const v0::Value_Federated& federated = value_pb.federated();
auto kind = TFF_TRY(ValidateFederated(num_clients_, federated));
return CreateFederatedValue_(kind, federated);
return CreateFederatedValue(kind, federated);
}
case v0::Value::kStruct: {
auto elements = NewStructure();
......@@ -288,6 +288,66 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
}
}
// Embeds `arg` containing structures of server-placed values into the
// `child_` executor.
absl::StatusOr<std::shared_ptr<OwnedValueId>> ZipStructIntoServer(
const ExecutorValue& arg) {
switch (arg.type()) {
case ExecutorValue::ValueType::SERVER: {
return arg.server();
}
case ExecutorValue::ValueType::STRUCTURE: {
std::vector<std::shared_ptr<OwnedValueId>> owned_element_ids;
owned_element_ids.reserve(arg.structure()->size());
for (const auto& element : *arg.structure()) {
owned_element_ids.push_back(TFF_TRY(ZipStructIntoServer(element)));
}
std::vector<ValueId> element_ids;
element_ids.reserve(arg.structure()->size());
for (const auto& owned_id : owned_element_ids) {
element_ids.push_back(owned_id->ref());
}
return ShareValueId(TFF_TRY(child_->CreateStruct(element_ids)));
}
default: {
return absl::InvalidArgumentError(absl::StrCat(
"Cannot `", kFederatedZipAtServerUri,
"` a structure containing a value of kind ", arg.type()));
}
}
}
// Embeds `arg` containing structures of client-placed values into the
// `child_` executor. The resulting structure on `child_` will contain all
// values for the client corresponding to `client_index`.
absl::StatusOr<std::shared_ptr<OwnedValueId>> ZipStructIntoClient(
const ExecutorValue& arg, uint32_t client_index) {
switch (arg.type()) {
case ExecutorValue::ValueType::CLIENTS: {
return (*arg.clients())[client_index];
}
case ExecutorValue::ValueType::STRUCTURE: {
std::vector<std::shared_ptr<OwnedValueId>> owned_element_ids;
owned_element_ids.reserve(arg.structure()->size());
for (const auto& element : *arg.structure()) {
owned_element_ids.push_back(
TFF_TRY(ZipStructIntoClient(element, client_index)));
}
std::vector<ValueId> element_ids;
element_ids.reserve(arg.structure()->size());
for (const auto& owned_id : owned_element_ids) {
element_ids.push_back(owned_id->ref());
}
return ShareValueId(TFF_TRY(child_->CreateStruct(element_ids)));
}
default: {
return absl::InvalidArgumentError(absl::StrCat(
"Cannot `", kFederatedZipAtClientsUri,
"` a structure containing a value of kind ", arg.type()));
}
}
}
absl::StatusOr<ExecutorValue> CallFederatedIntrinsic(
FederatedIntrinsic function, ExecutorValue arg) {
switch (function) {
......@@ -370,34 +430,16 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
"Attempted to map non-federated value.");
}
}
case FederatedIntrinsic::ZIP: {
TFF_TRY(CheckLenForUseAsArgument(arg, "federated_zip", 2));
const auto& first = arg.structure()->at(0);
const auto& second = arg.structure()->at(1);
if (first.type() != second.type()) {
return absl::InvalidArgumentError(absl::StrCat(
"Attempted to zip values with different placements: ",
first.type(), " and ", second.type()));
}
if (first.type() == ExecutorValue::ValueType::CLIENTS) {
Clients pairs = NewClients();
for (int i = 0; i < num_clients_; i++) {
ValueId first_id = first.clients()->at(i)->ref();
ValueId second_id = second.clients()->at(i)->ref();
auto pair = TFF_TRY(child_->CreateStruct({first_id, second_id}));
pairs->emplace_back(ShareValueId(std::move(pair)));
}
return ExecutorValue::CreateClientsPlaced(std::move(pairs));
} else if (first.type() == ExecutorValue::ValueType::SERVER) {
ValueId first_id = first.server()->ref();
ValueId second_id = second.server()->ref();
auto pair = TFF_TRY(child_->CreateStruct({first_id, second_id}));
return ExecutorValue::CreateServerPlaced(
ShareValueId(std::move(pair)));
} else {
return absl::InvalidArgumentError(
"Attempted to zip non-federated value.");
case FederatedIntrinsic::ZIP_AT_CLIENTS: {
Clients results = NewClients();
for (uint32_t i = 0; i < num_clients_; i++) {
results->push_back(TFF_TRY(ZipStructIntoClient(arg, i)));
}
return ExecutorValue::CreateClientsPlaced(std::move(results));
}
case FederatedIntrinsic::ZIP_AT_SERVER: {
return ExecutorValue::CreateServerPlaced(
TFF_TRY(ZipStructIntoServer(arg)));
}
}
}
......
......@@ -520,7 +520,7 @@ TEST_F(FederatingExecutorTest, CreateCallFederatedValueAtServer) {
ExpectMaterialize(result_id, ServerV(tensor));
}
TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtClients) {
TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtClientsFlat) {
v0::Value v1 = TensorV(1);
v0::Value v2 = TensorV(2);
ValueId v1_child_id = ExpectCreateInChild(v1);
......@@ -539,7 +539,29 @@ TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtClients) {
NUM_CLIENTS, StructV({v1, v2}))));
}
TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtServer) {
TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtClientsNested) {
v0::Value v1 = TensorV(1);
v0::Value v2 = TensorV(2);
ValueId v1_child_id = ExpectCreateInChild(v1);
ValueId v2_child_id = ExpectCreateInChild(v2);
TFF_ASSERT_OK_AND_ASSIGN(
auto v_id, test_executor_->CreateValue(StructV(
{ClientsV({v1}, true), StructV({ClientsV({v2}, true)})})));
TFF_ASSERT_OK_AND_ASSIGN(
auto zip_id, test_executor_->CreateValue(FederatedZipAtClientsV()));
ValueId v2_struct_child_id =
ExpectCreateStructInChild({v2_child_id}, ONCE_PER_CLIENT);
ValueId struct_child_id = ExpectCreateStructInChild(
{v1_child_id, v2_struct_child_id}, ONCE_PER_CLIENT);
TFF_ASSERT_OK_AND_ASSIGN(auto result_id,
test_executor_->CreateCall(zip_id, v_id));
ExpectMaterializeInChild(struct_child_id, StructV({v1, StructV({v2})}),
ONCE_PER_CLIENT);
ExpectMaterialize(result_id, ClientsV(std::vector<v0::Value>(
NUM_CLIENTS, StructV({v1, StructV({v2})}))));
}
TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtServerFlat) {
v0::Value v1 = TensorV(1);
v0::Value v2 = TensorV(2);
ValueId v1_child_id = ExpectCreateInChild(v1);
......@@ -556,6 +578,25 @@ TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtServer) {
ExpectMaterialize(result_id, ServerV(StructV({v1, v2})));
}
TEST_F(FederatingExecutorTest, CreateCallFederatedZipAtServerNested) {
v0::Value v1 = TensorV(1);
v0::Value v2 = TensorV(2);
ValueId v1_child_id = ExpectCreateInChild(v1);
ValueId v2_child_id = ExpectCreateInChild(v2);
TFF_ASSERT_OK_AND_ASSIGN(auto v_id,
test_executor_->CreateValue(
StructV({ServerV(v1), StructV({ServerV(v2)})})));
TFF_ASSERT_OK_AND_ASSIGN(
auto zip_id, test_executor_->CreateValue(FederatedZipAtServerV()));
ValueId v2_struct_child_id = ExpectCreateStructInChild({v2_child_id});
ValueId struct_child_id =
ExpectCreateStructInChild({v1_child_id, v2_struct_child_id});
TFF_ASSERT_OK_AND_ASSIGN(auto result_id,
test_executor_->CreateCall(zip_id, v_id));
ExpectMaterializeInChild(struct_child_id, StructV({v1, StructV({v2})}));
ExpectMaterialize(result_id, ServerV(StructV({v1, StructV({v2})})));
}
TEST_F(FederatingExecutorTest, CreateCallFederatedZipMixedPlacementsFails) {
v0::Value v1 = TensorV(1);
v0::Value v2 = TensorV(2);
......
......@@ -84,15 +84,15 @@ py_test(
size = "small",
srcs = ["building_block_factory_test.py"],
args = [
"--golden",
"$(location building_block_factory_test_goldens/replaces_single_element.expected)",
"--golden",
"$(location building_block_factory_test_goldens/skips_unnamed_element.expected)",
"--golden",
"$(location building_block_factory_test_goldens/constructs_correct_computation_clients.expected)",
"--golden",
"$(location building_block_factory_test_goldens/constructs_correct_computation_server.expected)",
"--golden",
"$(location building_block_factory_test_goldens/replaces_single_element.expected)",
"--golden",
"$(location building_block_factory_test_goldens/skips_unnamed_element.expected)",
"--golden",
"$(location building_block_factory_test_goldens/tuple_federated_map_with_two_values_unnamed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/tuple_federated_map_with_two_values_named.expected)",
......@@ -105,40 +105,6 @@ py_test(
"--golden",
"$(location building_block_factory_test_goldens/tuple_federated_apply_with_two_values_different_typed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_unnamed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_unnamed_tuple.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_named.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_named_tuple.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_unnamed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_unnamed_tuple.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_named.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_named_tuple.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_different_typed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_different_typed_tuple.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_server_with_two_values_named.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_server_with_three_values_unnamed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_server_with_three_values_named.expected)",
"--golden",
"$(location building_block_factory_test_goldens/federated_zip_at_server_with_three_values_different_typed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/wide_zip_creates_minimum_depth_binary_tree.expected)",
"--golden",
"$(location building_block_factory_test_goldens/nested_returns_federated_zip_at_clients.expected)",
"--golden",
"$(location building_block_factory_test_goldens/nested_returns_federated_zip_at_server.expected)",
"--golden",
"$(location building_block_factory_test_goldens/zips_tuple_unnamed.expected)",
"--golden",
"$(location building_block_factory_test_goldens/zips_tuple_named.expected)",
......@@ -148,22 +114,6 @@ py_test(
data = [
"building_block_factory_test_goldens/constructs_correct_computation_clients.expected",
"building_block_factory_test_goldens/constructs_correct_computation_server.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_different_typed.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_different_typed_tuple.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_named.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_named_tuple.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_unnamed.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_three_values_unnamed_tuple.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_named.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_named_tuple.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_unnamed.expected",
"building_block_factory_test_goldens/federated_zip_at_clients_with_two_values_unnamed_tuple.expected",
"building_block_factory_test_goldens/federated_zip_at_server_with_three_values_different_typed.expected",
"building_block_factory_test_goldens/federated_zip_at_server_with_three_values_named.expected",
"building_block_factory_test_goldens/federated_zip_at_server_with_three_values_unnamed.expected",
"building_block_factory_test_goldens/federated_zip_at_server_with_two_values_named.expected",
"building_block_factory_test_goldens/nested_returns_federated_zip_at_clients.expected",
"building_block_factory_test_goldens/nested_returns_federated_zip_at_server.expected",
"building_block_factory_test_goldens/replaces_single_element.expected",
"building_block_factory_test_goldens/skips_unnamed_element.expected",
"building_block_factory_test_goldens/tuple_federated_apply_with_two_values_different_typed.expected",
......@@ -172,7 +122,6 @@ py_test(
"building_block_factory_test_goldens/tuple_federated_map_with_two_values_different_typed.expected",
"building_block_factory_test_goldens/tuple_federated_map_with_two_values_named.expected",
"building_block_factory_test_goldens/tuple_federated_map_with_two_values_unnamed.expected",
"building_block_factory_test_goldens/wide_zip_creates_minimum_depth_binary_tree.expected",
"building_block_factory_test_goldens/zips_reference.expected",
"building_block_factory_test_goldens/zips_tuple_named.expected",
"building_block_factory_test_goldens/zips_tuple_unnamed.expected",
......
......@@ -36,7 +36,6 @@ from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.core.impl.types import type_transformations
from tensorflow_federated.python.core.impl.utils import tensorflow_utils
Index = Union[str, int]
Path = Union[Index, Tuple[Index, ...]]
......@@ -1228,143 +1227,6 @@ def create_federated_value(
return building_blocks.Call(intrinsic, value)
def _create_flat_federated_zip(value):
r"""Private function to create a called federated zip for a non-nested type.
Call
/ \
Intrinsic Tuple
|
[Comp, Comp]
This function returns a federated tuple given a `value` with a tuple of
federated values type signature.
Args:
value: A `building_blocks.ComputationBuildingBlock` with a `type_signature`
of type `computation_types.StructType` containing at least one element.
Returns:
A `building_blocks.Call`.
Raises:
TypeError: If any of the types do not match.
ValueError: If `value` does not contain any elements.
"""
py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock)
named_type_signatures = structure.to_elements(value.type_signature)
container_type = value.type_signature.python_container
names_to_add = [name for name, _ in named_type_signatures]
length = len(named_type_signatures)
if length == 0:
raise ValueError('federated_zip is only supported on non-empty tuples.')
first_name, first_type_signature = named_type_signatures[0]
if first_type_signature.placement == placements.CLIENTS:
map_fn = create_federated_map
elif first_type_signature.placement == placements.SERVER:
map_fn = create_federated_apply
else:
raise TypeError('Unsupported placement {}.'.format(
first_type_signature.placement))
if length == 1:
ref = building_blocks.Reference('arg', first_type_signature.member)
values = building_blocks.Struct(((first_name, ref),), container_type)
fn = building_blocks.Lambda(ref.name, ref.type_signature, values)
sel = building_blocks.Selection(value, index=0)
return map_fn(fn, sel)
elif length == 2:
# No point building and tearing down a tree if we can just federated_zip
# Note: this branch is purely an optimization and is not necessary.
if any((name is not None for name in names_to_add)):
# Remove names if necessary
named_ref = building_blocks.Reference('named', value.type_signature)
value = building_blocks.Block(
[(named_ref.name, value)],
building_blocks.Struct((
building_blocks.Selection(named_ref, index=0),
building_blocks.Selection(named_ref, index=1),
)))
unnamed_zip = create_zip_two_values(value)
else:
# Build a binary tree of federated zips
args = building_blocks.Reference('value', value.type_signature)
zipped, paths = _build_tree_of_zips_and_paths_to_elements(
args, 0,
len(value.type_signature) - 1)
zipped_block = building_blocks.Block([(args.name, value)], zipped)
# Select the values out of the tree back into a flat tuple
zipped_tree_ref = building_blocks.Reference('zipped_tree',
zipped.type_signature.member)
flattened_tree = building_blocks.Struct(
[_selection_from_path(zipped_tree_ref, path) for path in paths])
flatten_fn = building_blocks.Lambda(zipped_tree_ref.name,
zipped_tree_ref.type_signature,
flattened_tree)
unnamed_zip = map_fn(flatten_fn, zipped_block)
return create_named_federated_tuple(unnamed_zip, names_to_add, container_type)
def _prepend_to_paths(paths: List[List[int]], element: int):
for path in paths:
path.insert(0, element)
def _build_tree_of_zips_and_paths_to_elements(
args: building_blocks.Reference,
start_index: int,
end_index: int,
) -> Tuple[building_blocks.ComputationBuildingBlock, List[List[int]]]:
"""Builds a binary tree of federated_zips and a list of paths to each element.
Args:
args: A reference to the values to be zipped.
start_index: The index of the first element of `args` to zip.
end_index: The index of the last element of `args` to zip.
Returns:
A tuple containing the tree of zips as well as a list of paths to the
element at each index. A single path is a list of indices that can be used
with `_selection_from_path` to select an element out of the result.
"""
py_typecheck.check_type(args, building_blocks.Reference)
py_typecheck.check_type(args.type_signature, computation_types.StructType)
if start_index == end_index:
# Base case for one element
tree = building_blocks.Selection(args, index=start_index)
paths = [[]]
elif start_index + 1 == end_index:
# Base case for two elements
first = building_blocks.Selection(args, index=start_index)
second = building_blocks.Selection(args, index=end_index)
values = building_blocks.Struct((first, second))
tree = create_zip_two_values(values)
paths = [[0], [1]]
else:
# Recursive case for three or more elements
split_point = int((start_index + end_index) / 2)
left_tree, left_paths = _build_tree_of_zips_and_paths_to_elements(
args, start_index, split_point)
right_tree, right_paths = _build_tree_of_zips_and_paths_to_elements(
args, split_point + 1, end_index)
values = building_blocks.Struct((left_tree, right_tree))
tree = create_zip_two_values(values)
_prepend_to_paths(left_paths, 0)
_prepend_to_paths(right_paths, 1)
paths = left_paths + right_paths
py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
assert len(paths) == (end_index - start_index + 1)
return (tree, paths)
def _selection_from_path(
selected: building_blocks.ComputationBuildingBlock,
path: List[int],
) -> building_blocks.ComputationBuildingBlock:
for path_element in path:
selected = building_blocks.Selection(selected, index=path_element)
return selected
def _check_placements(
placement_values: AbstractSet[placements.PlacementLiteral]):
"""Checks if the placements of the values being zipped are compatible."""
......@@ -1408,71 +1270,46 @@ def create_federated_zip(
py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock)
py_typecheck.check_type(value.type_signature, computation_types.StructType)
# If the type signature is flat, just call _create_flat_federated_zip.
elements