Skip to content
GitLab
菜单
项目
群组
代码片段
/
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
14f06878
提交
14f06878
编辑于
11月 22, 2021
作者:
Sean Augenstein
提交者:
tensorflow-copybara
11月 22, 2021
浏览文件
Automated rollback of commit
39f1f33a
PiperOrigin-RevId: 411664233
上级
39f1f33a
变更
3
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/examples/personalization/p13n_utils.py
浏览文件 @
14f06878
...
...
@@ -157,9 +157,7 @@ def evaluate_fn(model: tff.learning.Model,
initial_state
=
0
,
reduce_func
=
reduce_fn
)
eval_metrics
=
collections
.
OrderedDict
()
eval_metrics
[
'num_test_examples'
]
=
num_examples_sum
# TODO(b/202027089): Removes this pytype comment once this function does not
# use `report_local_outputs`.
local_outputs
=
model
.
report_local_outputs
()
# pytype: disable=attribute-error
local_outputs
=
model
.
report_local_outputs
()
# Postprocesses the metric values. This is needed because the values returned
# by `model.report_local_outputs()` are values of the state variables in each
# `tf.keras.metrics.Metric`. These values should be processed in the same way
...
...
tensorflow_federated/python/learning/BUILD
浏览文件 @
14f06878
...
...
@@ -240,6 +240,7 @@ py_library(
name
=
"model"
,
srcs
=
[
"model.py"
],
srcs_version
=
"PY3"
,
deps
=
[
"//tensorflow_federated/python/core/api:computation_base"
],
)
py_library
(
...
...
tensorflow_federated/python/learning/model.py
浏览文件 @
14f06878
...
...
@@ -19,6 +19,8 @@ from typing import Any, Callable, OrderedDict, Sequence
import
attr
import
tensorflow
as
tf
from
tensorflow_federated.python.core.api
import
computation_base
MODEL_ARG_NAME
=
'x'
MODEL_LABEL_NAME
=
'y'
MetricFinalizersType
=
OrderedDict
[
str
,
Callable
[[
Any
],
Any
]]
...
...
@@ -188,6 +190,65 @@ class Model(object, metaclass=abc.ABCMeta):
function understands the result.
"""
.
format
(
MODEL_ARG_NAME
)
@
abc
.
abstractmethod
def
report_local_outputs
(
self
):
"""Returns tensors representing values aggregated over `forward_pass` calls.
In federated learning, the values returned by this method will typically
be further aggregated across clients and made available on the server.
This method returns results from aggregating across *all* previous calls
to `forward_pass`, most typically metrics like accuracy and loss. If needed,
we may add a `clear_aggregated_outputs` method, which would likely just
run the initializers on the `local_variables`.
In general, the tensors returned can be an arbitrary function of all
the `tf.Variables` of this model, not just the `local_variables`; for
example, this could return tensors measuring the total L2 norm of the model
(which might have been updated by training).
This method may return arbitrarily shaped tensors, not just scalar metrics.
For example, it could return the average feature vector or a count of
how many times each feature exceed a certain magnitude.
Returns:
A structure of tensors (as supported by `tf.nest`)
to be aggregated across clients.
"""
pass
@
abc
.
abstractproperty
def
federated_output_computation
(
self
)
->
computation_base
.
Computation
:
"""Performs federated aggregation of the `Model's` `local_outputs`.
This is typically used to aggregate metrics across many clients, e.g. the
body of the computation might be:
```python
return {
'num_examples': tff.federated_sum(local_outputs.num_examples),
'loss': tff.federated_mean(local_outputs.loss)
}
```
N.B. It is assumed all TensorFlow computation happens in the
`report_local_outputs` method, and this method only uses TFF constructs to
specify aggregations across clients.
Returns:
Either a `tff.Computation`, or None if no federated aggregation is needed.
The `tff.Computation` should take as its single input a
`tff.CLIENTS`-placed `tff.Value` corresponding to the return value of
`Model.report_local_outputs`, and return an `OrderedDict` (possibly
nested) of `tff.SERVER`-placed values. The consumer of this
method should generally provide these server-placed values as outputs of
the overall computation consuming the model. Using an `OrderedDict`
allows the value returned by TFF executor to be converted back to an
`OrderedDict` via the `._asdict(recursive=True)` member function.
"""
pass
@
abc
.
abstractmethod
def
report_local_unfinalized_metrics
(
self
)
->
OrderedDict
[
str
,
Any
]:
"""Creates an `OrderedDict` of metric names to unfinalized values.
...
...
编辑
预览
支持
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录