Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
0fd36cab
提交
0fd36cab
编辑于
11月 23, 2021
作者:
Zachary Charles
提交者:
tensorflow-copybara
11月 23, 2021
浏览文件
De-flake end-to-end learning tests involving aggregator randomness.
PiperOrigin-RevId: 411828023
上级
14f06878
变更
2
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/tests/federated_averaging_e2e_convergence_test.py
浏览文件 @
0fd36cab
...
...
@@ -105,7 +105,7 @@ class FederatedAveragingE2ETest(tff.test.TestCase, parameterized.TestCase):
client_optimizer_fn
=
_get_keras_optimizer_fn
(),
aggregator_factory
=
tff
.
learning
.
dp_aggregator
(
1e-8
,
10
))
self
.
assertLessEqual
(
loss
,
0.2
)
self
.
assertLessEqual
(
loss
,
0.2
2
)
self
.
assertGreater
(
accuracy
,
0.92
)
def
test_emnist10_cnn_convergence_dp_aggregator_high_noise
(
self
):
...
...
@@ -115,10 +115,7 @@ class FederatedAveragingE2ETest(tff.test.TestCase, parameterized.TestCase):
client_optimizer_fn
=
_get_keras_optimizer_fn
(),
aggregator_factory
=
tff
.
learning
.
dp_aggregator
(
2e-1
,
10
))
self
.
assertGreaterEqual
(
loss
,
0.2
)
self
.
assertLessEqual
(
loss
,
5
)
self
.
assertLess
(
accuracy
,
0.9
)
self
.
assertGreater
(
accuracy
,
0.15
)
...
...
tensorflow_federated/python/tests/federated_sgd_e2e_convergence_test.py
浏览文件 @
0fd36cab
...
...
@@ -29,7 +29,11 @@ def _get_keras_optimizer_fn(learning_rate=0.1):
class
FederatedSGDE2ETest
(
tff
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_run_process
(
self
,
process
,
client_selection_fn
):
def
_run_process
(
self
,
process
,
client_selection_fn
,
loss_threshold
=
0.4
,
accuracy_threshold
=
0.85
):
state
=
process
.
initialize
()
training_metrics
=
[]
for
round_num
in
range
(
200
):
...
...
@@ -44,8 +48,8 @@ class FederatedSGDE2ETest(tff.test.TestCase, parameterized.TestCase):
average_loss_last_10_rounds
=
np
.
mean
(
loss_last_10_rounds
)
average_accuracy_last_10_rounds
=
np
.
mean
(
accuracy_last_10_rounds
)
self
.
assertLessEqual
(
average_loss_last_10_rounds
,
0.4
)
self
.
assertGreater
(
average_accuracy_last_10_rounds
,
0.85
)
self
.
assertLessEqual
(
average_loss_last_10_rounds
,
loss_threshold
)
self
.
assertGreater
(
average_accuracy_last_10_rounds
,
accuracy_threshold
)
@
parameterized
.
named_parameters
([
(
'keras_opt'
,
_get_keras_optimizer_fn
()),
...
...
@@ -74,7 +78,6 @@ class FederatedSGDE2ETest(tff.test.TestCase, parameterized.TestCase):
@
parameterized
.
named_parameters
([
(
'robust_aggregator'
,
tff
.
learning
.
robust_aggregator
),
(
'compression_aggregator'
,
tff
.
learning
.
compression_aggregator
),
(
'secure_aggregator'
,
tff
.
learning
.
secure_aggregator
),
])
def
test_emnist10_cnn_convergence_with_aggregator
(
self
,
...
...
@@ -100,6 +103,28 @@ class FederatedSGDE2ETest(tff.test.TestCase, parameterized.TestCase):
model_update_aggregation_factory
=
aggregator_factory_fn
())
self
.
_run_process
(
process
,
client_selection_fn
)
def
test_emnist10_cnn_convergence_with_compression_aggregator
(
self
):
train_client_spec
=
tff
.
simulation
.
baselines
.
ClientSpec
(
num_epochs
=
1
,
batch_size
=
32
,
shuffle_buffer_size
=
1
)
task
=
tff
.
simulation
.
baselines
.
emnist
.
create_character_recognition_task
(
train_client_spec
,
model_id
=
'cnn'
,
only_digits
=
True
)
train_client_ids
=
sorted
(
task
.
datasets
.
train_data
.
client_ids
)
preprocessed_train_data
=
task
.
datasets
.
train_data
.
preprocess
(
task
.
datasets
.
train_preprocess_fn
)
def
client_selection_fn
(
round_num
):
random_state
=
np
.
random
.
RandomState
(
round_num
)
client_ids
=
random_state
.
choice
(
train_client_ids
,
size
=
10
,
replace
=
False
)
return
[
preprocessed_train_data
.
create_tf_dataset_for_client
(
a
)
for
a
in
client_ids
]
process
=
tff
.
learning
.
build_federated_sgd_process
(
model_fn
=
task
.
model_fn
,
model_update_aggregation_factory
=
tff
.
learning
.
compression_aggregator
())
self
.
_run_process
(
process
,
client_selection_fn
,
loss_threshold
=
0.42
)
if
__name__
==
'__main__'
:
# We must use the test execution context for the secure intrinsics introduced
...
...
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录