提交 0fd36cab 编辑于 作者: Zachary Charles's avatar Zachary Charles 提交者: tensorflow-copybara
浏览文件

De-flake end-to-end learning tests involving aggregator randomness.

PiperOrigin-RevId: 411828023
上级 14f06878
......@@ -105,7 +105,7 @@ class FederatedAveragingE2ETest(tff.test.TestCase, parameterized.TestCase):
client_optimizer_fn=_get_keras_optimizer_fn(),
aggregator_factory=tff.learning.dp_aggregator(1e-8, 10))
self.assertLessEqual(loss, 0.2)
self.assertLessEqual(loss, 0.22)
self.assertGreater(accuracy, 0.92)
def test_emnist10_cnn_convergence_dp_aggregator_high_noise(self):
......@@ -115,10 +115,7 @@ class FederatedAveragingE2ETest(tff.test.TestCase, parameterized.TestCase):
client_optimizer_fn=_get_keras_optimizer_fn(),
aggregator_factory=tff.learning.dp_aggregator(2e-1, 10))
self.assertGreaterEqual(loss, 0.2)
self.assertLessEqual(loss, 5)
self.assertLess(accuracy, 0.9)
self.assertGreater(accuracy, 0.15)
......
......@@ -29,7 +29,11 @@ def _get_keras_optimizer_fn(learning_rate=0.1):
class FederatedSGDE2ETest(tff.test.TestCase, parameterized.TestCase):
def _run_process(self, process, client_selection_fn):
def _run_process(self,
process,
client_selection_fn,
loss_threshold=0.4,
accuracy_threshold=0.85):
state = process.initialize()
training_metrics = []
for round_num in range(200):
......@@ -44,8 +48,8 @@ class FederatedSGDE2ETest(tff.test.TestCase, parameterized.TestCase):
average_loss_last_10_rounds = np.mean(loss_last_10_rounds)
average_accuracy_last_10_rounds = np.mean(accuracy_last_10_rounds)
self.assertLessEqual(average_loss_last_10_rounds, 0.4)
self.assertGreater(average_accuracy_last_10_rounds, 0.85)
self.assertLessEqual(average_loss_last_10_rounds, loss_threshold)
self.assertGreater(average_accuracy_last_10_rounds, accuracy_threshold)
@parameterized.named_parameters([
('keras_opt', _get_keras_optimizer_fn()),
......@@ -74,7 +78,6 @@ class FederatedSGDE2ETest(tff.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters([
('robust_aggregator', tff.learning.robust_aggregator),
('compression_aggregator', tff.learning.compression_aggregator),
('secure_aggregator', tff.learning.secure_aggregator),
])
def test_emnist10_cnn_convergence_with_aggregator(self,
......@@ -100,6 +103,28 @@ class FederatedSGDE2ETest(tff.test.TestCase, parameterized.TestCase):
model_update_aggregation_factory=aggregator_factory_fn())
self._run_process(process, client_selection_fn)
def test_emnist10_cnn_convergence_with_compression_aggregator(self):
train_client_spec = tff.simulation.baselines.ClientSpec(
num_epochs=1, batch_size=32, shuffle_buffer_size=1)
task = tff.simulation.baselines.emnist.create_character_recognition_task(
train_client_spec, model_id='cnn', only_digits=True)
train_client_ids = sorted(task.datasets.train_data.client_ids)
preprocessed_train_data = task.datasets.train_data.preprocess(
task.datasets.train_preprocess_fn)
def client_selection_fn(round_num):
random_state = np.random.RandomState(round_num)
client_ids = random_state.choice(train_client_ids, size=10, replace=False)
return [
preprocessed_train_data.create_tf_dataset_for_client(a)
for a in client_ids
]
process = tff.learning.build_federated_sgd_process(
model_fn=task.model_fn,
model_update_aggregation_factory=tff.learning.compression_aggregator())
self._run_process(process, client_selection_fn, loss_threshold=0.42)
if __name__ == '__main__':
# We must use the test execution context for the secure intrinsics introduced
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册