提交 0fa0ed2c 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Adds federated_map_all_equal URI to MAP casing in CC execution.

PiperOrigin-RevId: 411925855
上级 86dfb208
...@@ -46,6 +46,7 @@ using testing::intrinsic::FederatedAggregateV; ...@@ -46,6 +46,7 @@ using testing::intrinsic::FederatedAggregateV;
using testing::intrinsic::FederatedBroadcastV; using testing::intrinsic::FederatedBroadcastV;
using testing::intrinsic::FederatedEvalAtClientsV; using testing::intrinsic::FederatedEvalAtClientsV;
using testing::intrinsic::FederatedEvalAtServerV; using testing::intrinsic::FederatedEvalAtServerV;
using testing::intrinsic::FederatedMapAllEqualV;
using testing::intrinsic::FederatedMapV; using testing::intrinsic::FederatedMapV;
using testing::intrinsic::FederatedValueAtClientsV; using testing::intrinsic::FederatedValueAtClientsV;
using testing::intrinsic::FederatedValueAtServerV; using testing::intrinsic::FederatedValueAtServerV;
...@@ -463,6 +464,43 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedMapAtClients) { ...@@ -463,6 +464,43 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedMapAtClients) {
ExpectMaterialize(res_id, ClientsV(client_vals_out)); ExpectMaterialize(res_id, ClientsV(client_vals_out));
} }
TEST_F(ComposingExecutorTest, CreateCallFederatedMapAllEqualAtClients) {
std::vector<v0::Value> client_vals_in;
std::vector<v0::Value> client_vals_out;
v0::Value fn = TensorV(24601);
for (uint32_t i = 0; i < mock_children_.size(); i++) {
const auto& child = mock_children_[i];
std::vector<v0::Value> in_vec;
std::vector<v0::Value> out_vec;
for (uint32_t j = 0; j < clients_per_child_[i]; j++) {
v0::Value in = TensorV(i * 10000 + j * 100);
v0::Value out = TensorV(i * 10000 + j * 100 + 1);
client_vals_in.push_back(in);
in_vec.push_back(in);
client_vals_out.push_back(out);
out_vec.push_back(out);
}
auto in_id = child->ExpectCreateValue(ClientsV(in_vec));
// We convert the all-equal map to a usual map in our children, relying on
// our callers to reinsert the all-equal information if desired.
auto map_id = child->ExpectCreateValue(FederatedMapV());
auto fn_id = child->ExpectCreateValue(fn);
auto args_id = child->ExpectCreateStruct({fn_id, in_id});
auto res_id = child->ExpectCreateCall(map_id, args_id);
child->ExpectMaterialize(res_id, ClientsV(out_vec));
}
TFF_ASSERT_OK_AND_ASSIGN(auto fn_id, test_executor_->CreateValue(fn));
TFF_ASSERT_OK_AND_ASSIGN(
auto input_id, test_executor_->CreateValue(ClientsV(client_vals_in)));
TFF_ASSERT_OK_AND_ASSIGN(
auto map_id, test_executor_->CreateValue(FederatedMapAllEqualV()));
TFF_ASSERT_OK_AND_ASSIGN(auto arg_id,
test_executor_->CreateStruct({fn_id, input_id}));
TFF_ASSERT_OK_AND_ASSIGN(auto res_id,
test_executor_->CreateCall(map_id, arg_id));
ExpectMaterialize(res_id, ClientsV(client_vals_out));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedMapAtServer) { TEST_F(ComposingExecutorTest, CreateCallFederatedMapAtServer) {
v0::Value tensor = TensorV(23); v0::Value tensor = TensorV(23);
ValueId tensor_child_id = mock_server_->ExpectCreateValue(tensor); ValueId tensor_child_id = mock_server_->ExpectCreateValue(tensor);
......
...@@ -26,7 +26,8 @@ namespace tensorflow_federated { ...@@ -26,7 +26,8 @@ namespace tensorflow_federated {
absl::StatusOr<FederatedIntrinsic> FederatedIntrinsicFromUri( absl::StatusOr<FederatedIntrinsic> FederatedIntrinsicFromUri(
const absl::string_view uri) { const absl::string_view uri) {
if (uri == kFederatedMapAtClientsUri || uri == "federated_apply") { if (uri == kFederatedMapAtClientsUri || uri == "federated_apply" ||
uri == "federated_map_all_equal") {
return FederatedIntrinsic::MAP; return FederatedIntrinsic::MAP;
} else if (uri == kFederatedZipAtClientsUri || } else if (uri == kFederatedZipAtClientsUri ||
uri == "federated_zip_at_server") { uri == "federated_zip_at_server") {
......
...@@ -48,6 +48,7 @@ using testing::intrinsic::FederatedAggregateV; ...@@ -48,6 +48,7 @@ using testing::intrinsic::FederatedAggregateV;
using testing::intrinsic::FederatedBroadcastV; using testing::intrinsic::FederatedBroadcastV;
using testing::intrinsic::FederatedEvalAtClientsV; using testing::intrinsic::FederatedEvalAtClientsV;
using testing::intrinsic::FederatedEvalAtServerV; using testing::intrinsic::FederatedEvalAtServerV;
using testing::intrinsic::FederatedMapAllEqualV;
using testing::intrinsic::FederatedMapV; using testing::intrinsic::FederatedMapV;
using testing::intrinsic::FederatedValueAtClientsV; using testing::intrinsic::FederatedValueAtClientsV;
using testing::intrinsic::FederatedValueAtServerV; using testing::intrinsic::FederatedValueAtServerV;
...@@ -416,6 +417,33 @@ TEST_F(FederatingExecutorTest, CreateCallFederatedMapAtClients) { ...@@ -416,6 +417,33 @@ TEST_F(FederatingExecutorTest, CreateCallFederatedMapAtClients) {
ExpectMaterialize(result_id, value); ExpectMaterialize(result_id, value);
} }
TEST_F(FederatingExecutorTest, CreateCallFederatedMapAllEqualAtClients) {
std::vector<v0::Value> client_vals;
std::vector<ValueId> client_vals_child_ids;
for (int i = 0; i < NUM_CLIENTS; i++) {
client_vals.emplace_back(TensorV(i));
client_vals_child_ids.emplace_back(ExpectCreateInChild(TensorV(i)));
}
v0::Value value = ClientsV(client_vals);
TFF_ASSERT_OK_AND_ASSIGN(auto input_id,
test_executor_->CreateValue(ClientsV(client_vals)));
v0::Value fn = TensorV(2);
ValueId fn_child_id = ExpectCreateInChild(fn);
TFF_ASSERT_OK_AND_ASSIGN(auto fn_id, test_executor_->CreateValue(fn));
for (int i = 0; i < NUM_CLIENTS; i++) {
ValueId result_child_id =
ExpectCreateCallInChild(fn_child_id, client_vals_child_ids[i]);
ExpectMaterializeInChild(result_child_id, client_vals[i]);
}
TFF_ASSERT_OK_AND_ASSIGN(
auto map_id, test_executor_->CreateValue(FederatedMapAllEqualV()));
TFF_ASSERT_OK_AND_ASSIGN(auto arg_id,
test_executor_->CreateStruct({fn_id, input_id}));
TFF_ASSERT_OK_AND_ASSIGN(auto result_id,
test_executor_->CreateCall(map_id, arg_id));
ExpectMaterialize(result_id, value);
}
TEST_F(FederatingExecutorTest, CreateCallFederatedMapAtServer) { TEST_F(FederatingExecutorTest, CreateCallFederatedMapAtServer) {
v0::Value tensor = TensorV(23); v0::Value tensor = TensorV(23);
ValueId tensor_child_id = ExpectCreateInChild(tensor); ValueId tensor_child_id = ExpectCreateInChild(tensor);
......
...@@ -141,6 +141,7 @@ namespace intrinsic { ...@@ -141,6 +141,7 @@ namespace intrinsic {
INTRINSIC_FUNC(FederatedAggregateV, federated_aggregate); INTRINSIC_FUNC(FederatedAggregateV, federated_aggregate);
INTRINSIC_FUNC(FederatedBroadcastV, federated_broadcast); INTRINSIC_FUNC(FederatedBroadcastV, federated_broadcast);
INTRINSIC_FUNC(FederatedMapV, federated_map); INTRINSIC_FUNC(FederatedMapV, federated_map);
INTRINSIC_FUNC(FederatedMapAllEqualV, federated_map_all_equal);
INTRINSIC_FUNC(FederatedEvalAtClientsV, federated_eval_at_clients); INTRINSIC_FUNC(FederatedEvalAtClientsV, federated_eval_at_clients);
INTRINSIC_FUNC(FederatedEvalAtServerV, federated_eval_at_server); INTRINSIC_FUNC(FederatedEvalAtServerV, federated_eval_at_server);
INTRINSIC_FUNC(FederatedValueAtClientsV, federated_value_at_clients); INTRINSIC_FUNC(FederatedValueAtClientsV, federated_value_at_clients);
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册