提交 0fa0ed2c 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Adds federated_map_all_equal URI to MAP casing in CC execution.

PiperOrigin-RevId: 411925855
上级 86dfb208
......@@ -46,6 +46,7 @@ using testing::intrinsic::FederatedAggregateV;
using testing::intrinsic::FederatedBroadcastV;
using testing::intrinsic::FederatedEvalAtClientsV;
using testing::intrinsic::FederatedEvalAtServerV;
using testing::intrinsic::FederatedMapAllEqualV;
using testing::intrinsic::FederatedMapV;
using testing::intrinsic::FederatedValueAtClientsV;
using testing::intrinsic::FederatedValueAtServerV;
......@@ -463,6 +464,43 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedMapAtClients) {
ExpectMaterialize(res_id, ClientsV(client_vals_out));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedMapAllEqualAtClients) {
std::vector<v0::Value> client_vals_in;
std::vector<v0::Value> client_vals_out;
v0::Value fn = TensorV(24601);
for (uint32_t i = 0; i < mock_children_.size(); i++) {
const auto& child = mock_children_[i];
std::vector<v0::Value> in_vec;
std::vector<v0::Value> out_vec;
for (uint32_t j = 0; j < clients_per_child_[i]; j++) {
v0::Value in = TensorV(i * 10000 + j * 100);
v0::Value out = TensorV(i * 10000 + j * 100 + 1);
client_vals_in.push_back(in);
in_vec.push_back(in);
client_vals_out.push_back(out);
out_vec.push_back(out);
}
auto in_id = child->ExpectCreateValue(ClientsV(in_vec));
// We convert the all-equal map to a usual map in our children, relying on
// our callers to reinsert the all-equal information if desired.
auto map_id = child->ExpectCreateValue(FederatedMapV());
auto fn_id = child->ExpectCreateValue(fn);
auto args_id = child->ExpectCreateStruct({fn_id, in_id});
auto res_id = child->ExpectCreateCall(map_id, args_id);
child->ExpectMaterialize(res_id, ClientsV(out_vec));
}
TFF_ASSERT_OK_AND_ASSIGN(auto fn_id, test_executor_->CreateValue(fn));
TFF_ASSERT_OK_AND_ASSIGN(
auto input_id, test_executor_->CreateValue(ClientsV(client_vals_in)));
TFF_ASSERT_OK_AND_ASSIGN(
auto map_id, test_executor_->CreateValue(FederatedMapAllEqualV()));
TFF_ASSERT_OK_AND_ASSIGN(auto arg_id,
test_executor_->CreateStruct({fn_id, input_id}));
TFF_ASSERT_OK_AND_ASSIGN(auto res_id,
test_executor_->CreateCall(map_id, arg_id));
ExpectMaterialize(res_id, ClientsV(client_vals_out));
}
TEST_F(ComposingExecutorTest, CreateCallFederatedMapAtServer) {
v0::Value tensor = TensorV(23);
ValueId tensor_child_id = mock_server_->ExpectCreateValue(tensor);
......
......@@ -26,7 +26,8 @@ namespace tensorflow_federated {
absl::StatusOr<FederatedIntrinsic> FederatedIntrinsicFromUri(
const absl::string_view uri) {
if (uri == kFederatedMapAtClientsUri || uri == "federated_apply") {
if (uri == kFederatedMapAtClientsUri || uri == "federated_apply" ||
uri == "federated_map_all_equal") {
return FederatedIntrinsic::MAP;
} else if (uri == kFederatedZipAtClientsUri ||
uri == "federated_zip_at_server") {
......
......@@ -48,6 +48,7 @@ using testing::intrinsic::FederatedAggregateV;
using testing::intrinsic::FederatedBroadcastV;
using testing::intrinsic::FederatedEvalAtClientsV;
using testing::intrinsic::FederatedEvalAtServerV;
using testing::intrinsic::FederatedMapAllEqualV;
using testing::intrinsic::FederatedMapV;
using testing::intrinsic::FederatedValueAtClientsV;
using testing::intrinsic::FederatedValueAtServerV;
......@@ -416,6 +417,33 @@ TEST_F(FederatingExecutorTest, CreateCallFederatedMapAtClients) {
ExpectMaterialize(result_id, value);
}
TEST_F(FederatingExecutorTest, CreateCallFederatedMapAllEqualAtClients) {
std::vector<v0::Value> client_vals;
std::vector<ValueId> client_vals_child_ids;
for (int i = 0; i < NUM_CLIENTS; i++) {
client_vals.emplace_back(TensorV(i));
client_vals_child_ids.emplace_back(ExpectCreateInChild(TensorV(i)));
}
v0::Value value = ClientsV(client_vals);
TFF_ASSERT_OK_AND_ASSIGN(auto input_id,
test_executor_->CreateValue(ClientsV(client_vals)));
v0::Value fn = TensorV(2);
ValueId fn_child_id = ExpectCreateInChild(fn);
TFF_ASSERT_OK_AND_ASSIGN(auto fn_id, test_executor_->CreateValue(fn));
for (int i = 0; i < NUM_CLIENTS; i++) {
ValueId result_child_id =
ExpectCreateCallInChild(fn_child_id, client_vals_child_ids[i]);
ExpectMaterializeInChild(result_child_id, client_vals[i]);
}
TFF_ASSERT_OK_AND_ASSIGN(
auto map_id, test_executor_->CreateValue(FederatedMapAllEqualV()));
TFF_ASSERT_OK_AND_ASSIGN(auto arg_id,
test_executor_->CreateStruct({fn_id, input_id}));
TFF_ASSERT_OK_AND_ASSIGN(auto result_id,
test_executor_->CreateCall(map_id, arg_id));
ExpectMaterialize(result_id, value);
}
TEST_F(FederatingExecutorTest, CreateCallFederatedMapAtServer) {
v0::Value tensor = TensorV(23);
ValueId tensor_child_id = ExpectCreateInChild(tensor);
......
......@@ -141,6 +141,7 @@ namespace intrinsic {
INTRINSIC_FUNC(FederatedAggregateV, federated_aggregate);
INTRINSIC_FUNC(FederatedBroadcastV, federated_broadcast);
INTRINSIC_FUNC(FederatedMapV, federated_map);
INTRINSIC_FUNC(FederatedMapAllEqualV, federated_map_all_equal);
INTRINSIC_FUNC(FederatedEvalAtClientsV, federated_eval_at_clients);
INTRINSIC_FUNC(FederatedEvalAtServerV, federated_eval_at_server);
INTRINSIC_FUNC(FederatedValueAtClientsV, federated_value_at_clients);
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册