Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
0b5abd0d
提交
0b5abd0d
编辑于
11月 23, 2020
作者:
Zachary Charles
提交者:
tensorflow-copybara
11月 23, 2020
浏览文件
Add `get_model_weights` tf_computation to FedAvg and FedSGD processes created by tff.learning.
PiperOrigin-RevId: 343939365
上级
5965c17c
变更
5
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/learning/BUILD
浏览文件 @
0b5abd0d
...
...
@@ -44,6 +44,7 @@ py_library(
":model_utils"
,
"//tensorflow_federated/python/aggregators:factory"
,
"//tensorflow_federated/python/common_libs:py_typecheck"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/templates:iterative_process"
,
"//tensorflow_federated/python/core/templates:measured_process"
,
"//tensorflow_federated/python/learning/framework:dataset_reduce"
,
...
...
@@ -113,6 +114,7 @@ py_library(
":model_utils"
,
"//tensorflow_federated/python/aggregators:factory"
,
"//tensorflow_federated/python/common_libs:py_typecheck"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/templates:iterative_process"
,
"//tensorflow_federated/python/core/templates:measured_process"
,
"//tensorflow_federated/python/learning/framework:dataset_reduce"
,
...
...
tensorflow_federated/python/learning/federated_averaging.py
浏览文件 @
0b5abd0d
...
...
@@ -28,6 +28,7 @@ import tensorflow as tf
from
tensorflow_federated.python.aggregators
import
factory
from
tensorflow_federated.python.common_libs
import
py_typecheck
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.templates
import
iterative_process
from
tensorflow_federated.python.core.templates
import
measured_process
from
tensorflow_federated.python.learning
import
model
as
model_lib
...
...
@@ -151,7 +152,7 @@ def build_federated_averaging_process(
This function creates a `tff.templates.IterativeProcess` that performs
federated averaging on client models. The iterative process has the following
methods:
methods
inherited from `tff.templates.IterativeProcess`
:
* `initialize`: A `tff.Computation` with the functional type signature
`( -> S@SERVER)`, where `S` is a `tff.learning.framework.ServerState`
...
...
@@ -166,6 +167,13 @@ def build_federated_averaging_process(
`tff.learning.Model.federated_output_computation` during client training
and any other metrics from broadcast and aggregation processes.
The iterative process also has the following method not inherited from
`tff.templates.IterativeProcess`:
* `get_model_weights`: A `tff.Computation` that takes as input the
a `tff.learning.framework.ServerState`, and returns a
`tff.learning.ModelWeights` containing the state's model weights.
Each time the `next` method is called, the server model is broadcast to each
client using a broadcast function. For each client, one epoch of local
training is performed via the `tf.keras.optimizers.Optimizer.apply_gradients`
...
...
@@ -220,10 +228,19 @@ def build_federated_averaging_process(
return
ClientFedAvg
(
model_fn
(),
client_optimizer_fn
(),
client_weight_fn
,
use_experimental_simulation_loop
)
return
optimizer_utils
.
build_model_delta_optimizer_process
(
iter_proc
=
optimizer_utils
.
build_model_delta_optimizer_process
(
model_fn
,
model_to_client_delta_fn
=
client_fed_avg
,
server_optimizer_fn
=
server_optimizer_fn
,
broadcast_process
=
broadcast_process
,
aggregation_process
=
aggregation_process
,
model_update_aggregation_factory
=
model_update_aggregation_factory
)
server_state_type
=
iter_proc
.
state_type
.
member
@
computations
.
tf_computation
(
server_state_type
)
def
get_model_weights
(
server_state
):
return
server_state
.
model
iter_proc
.
get_model_weights
=
get_model_weights
return
iter_proc
tensorflow_federated/python/learning/federated_averaging_test.py
浏览文件 @
0b5abd0d
...
...
@@ -261,6 +261,33 @@ class FederatedAveragingModelTffTest(test_case.TestCase,
self
.
assertEqual
(
metric_outputs
[
'train'
][
'num_examples'
],
0
)
self
.
assertTrue
(
tf
.
math
.
is_nan
(
metric_outputs
[
'train'
][
'loss'
]))
@
test_utils
.
skip_test_for_multi_gpu
def
test_get_model_weights
(
self
):
iterative_process
=
federated_averaging
.
build_federated_averaging_process
(
model_fn
=
model_examples
.
LinearRegression
,
client_optimizer_fn
=
lambda
:
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
0.1
))
num_clients
=
3
ds
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
collections
.
OrderedDict
(
x
=
[[
1.0
,
2.0
],
[
3.0
,
4.0
]],
y
=
[[
5.0
],
[
6.0
]],
)).
batch
(
2
)
datasets
=
[
ds
]
*
num_clients
state
=
iterative_process
.
initialize
()
self
.
assertIsInstance
(
iterative_process
.
get_model_weights
(
state
),
model_utils
.
ModelWeights
)
self
.
assertAllClose
(
state
.
model
.
trainable
,
iterative_process
.
get_model_weights
(
state
).
trainable
)
for
_
in
range
(
3
):
state
,
_
=
iterative_process
.
next
(
state
,
datasets
)
self
.
assertIsInstance
(
iterative_process
.
get_model_weights
(
state
),
model_utils
.
ModelWeights
)
self
.
assertAllClose
(
state
.
model
.
trainable
,
iterative_process
.
get_model_weights
(
state
).
trainable
)
if
__name__
==
'__main__'
:
execution_contexts
.
set_local_execution_context
()
...
...
tensorflow_federated/python/learning/federated_sgd.py
浏览文件 @
0b5abd0d
...
...
@@ -27,6 +27,7 @@ import tensorflow as tf
from
tensorflow_federated.python.aggregators
import
factory
from
tensorflow_federated.python.common_libs
import
py_typecheck
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.templates
import
iterative_process
from
tensorflow_federated.python.core.templates
import
measured_process
from
tensorflow_federated.python.learning
import
model
as
model_lib
...
...
@@ -159,8 +160,8 @@ def build_federated_sgd_process(
"""Builds the TFF computations for optimization using federated SGD.
This function creates a `tff.templates.IterativeProcess` that performs
federated
averaging
on client models. The iterative process has the following
methods:
federated
SGD
on client models. The iterative process has the following
methods
inherited from `tff.templates.IterativeProcess`
:
* `initialize`: A `tff.Computation` with the functional type signature
`( -> S@SERVER)`, where `S` is a `tff.learning.framework.ServerState`
...
...
@@ -175,6 +176,13 @@ def build_federated_sgd_process(
`tff.learning.Model.federated_output_computation` during client training
and any other metrics from broadcast and aggregation processes.
The iterative process also has the following method not inherited from
`tff.templates.IterativeProcess`:
* `get_model_weights`: A `tff.Computation` that takes as input the
a `tff.learning.framework.ServerState`, and returns a
`tff.learning.ModelWeights` containing the state's model weights.
Each time the `next` method is called, the server model is broadcast to each
client using a broadcast function. Each client sums the gradients at each
batch in the client's local dataset. These gradient sums are then aggregated
...
...
@@ -228,10 +236,19 @@ def build_federated_sgd_process(
client_weight_fn
,
use_experimental_simulation_loop
=
use_experimental_simulation_loop
)
return
optimizer_utils
.
build_model_delta_optimizer_process
(
iter_proc
=
optimizer_utils
.
build_model_delta_optimizer_process
(
model_fn
,
model_to_client_delta_fn
=
client_sgd_avg
,
server_optimizer_fn
=
server_optimizer_fn
,
broadcast_process
=
broadcast_process
,
aggregation_process
=
aggregation_process
,
model_update_aggregation_factory
=
model_update_aggregation_factory
)
server_state_type
=
iter_proc
.
state_type
.
member
@
computations
.
tf_computation
(
server_state_type
)
def
get_model_weights
(
server_state
):
return
server_state
.
model
iter_proc
.
get_model_weights
=
get_model_weights
return
iter_proc
tensorflow_federated/python/learning/federated_sgd_test.py
浏览文件 @
0b5abd0d
...
...
@@ -224,6 +224,32 @@ class FederatedSGDTffTest(test_case.TestCase, parameterized.TestCase):
self
.
assertEqual
(
metric_outputs
[
'train'
][
'num_examples'
],
0
)
self
.
assertTrue
(
tf
.
math
.
is_nan
(
metric_outputs
[
'train'
][
'loss'
]))
@
test_utils
.
skip_test_for_multi_gpu
def
test_get_model_weights
(
self
):
iterative_process
=
federated_sgd
.
build_federated_sgd_process
(
model_fn
=
model_examples
.
LinearRegression
)
num_clients
=
3
ds
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
collections
.
OrderedDict
(
x
=
[[
1.0
,
2.0
],
[
3.0
,
4.0
]],
y
=
[[
5.0
],
[
6.0
]],
)).
batch
(
2
)
datasets
=
[
ds
]
*
num_clients
state
=
iterative_process
.
initialize
()
self
.
assertIsInstance
(
iterative_process
.
get_model_weights
(
state
),
model_utils
.
ModelWeights
)
self
.
assertAllClose
(
state
.
model
.
trainable
,
iterative_process
.
get_model_weights
(
state
).
trainable
)
for
_
in
range
(
3
):
state
,
_
=
iterative_process
.
next
(
state
,
datasets
)
self
.
assertIsInstance
(
iterative_process
.
get_model_weights
(
state
),
model_utils
.
ModelWeights
)
self
.
assertAllClose
(
state
.
model
.
trainable
,
iterative_process
.
get_model_weights
(
state
).
trainable
)
if
__name__
==
'__main__'
:
execution_contexts
.
set_local_execution_context
()
...
...
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录