提交 14f06878 编辑于 作者: Sean Augenstein's avatar Sean Augenstein 提交者: tensorflow-copybara
浏览文件

Automated rollback of commit 39f1f33a

PiperOrigin-RevId: 411664233
上级 39f1f33a
...@@ -157,9 +157,7 @@ def evaluate_fn(model: tff.learning.Model, ...@@ -157,9 +157,7 @@ def evaluate_fn(model: tff.learning.Model,
initial_state=0, reduce_func=reduce_fn) initial_state=0, reduce_func=reduce_fn)
eval_metrics = collections.OrderedDict() eval_metrics = collections.OrderedDict()
eval_metrics['num_test_examples'] = num_examples_sum eval_metrics['num_test_examples'] = num_examples_sum
# TODO(b/202027089): Removes this pytype comment once this function does not local_outputs = model.report_local_outputs()
# use `report_local_outputs`.
local_outputs = model.report_local_outputs() # pytype: disable=attribute-error
# Postprocesses the metric values. This is needed because the values returned # Postprocesses the metric values. This is needed because the values returned
# by `model.report_local_outputs()` are values of the state variables in each # by `model.report_local_outputs()` are values of the state variables in each
# `tf.keras.metrics.Metric`. These values should be processed in the same way # `tf.keras.metrics.Metric`. These values should be processed in the same way
......
...@@ -240,6 +240,7 @@ py_library( ...@@ -240,6 +240,7 @@ py_library(
name = "model", name = "model",
srcs = ["model.py"], srcs = ["model.py"],
srcs_version = "PY3", srcs_version = "PY3",
deps = ["//tensorflow_federated/python/core/api:computation_base"],
) )
py_library( py_library(
......
...@@ -19,6 +19,8 @@ from typing import Any, Callable, OrderedDict, Sequence ...@@ -19,6 +19,8 @@ from typing import Any, Callable, OrderedDict, Sequence
import attr import attr
import tensorflow as tf import tensorflow as tf
from tensorflow_federated.python.core.api import computation_base
MODEL_ARG_NAME = 'x' MODEL_ARG_NAME = 'x'
MODEL_LABEL_NAME = 'y' MODEL_LABEL_NAME = 'y'
MetricFinalizersType = OrderedDict[str, Callable[[Any], Any]] MetricFinalizersType = OrderedDict[str, Callable[[Any], Any]]
...@@ -188,6 +190,65 @@ class Model(object, metaclass=abc.ABCMeta): ...@@ -188,6 +190,65 @@ class Model(object, metaclass=abc.ABCMeta):
function understands the result. function understands the result.
""".format(MODEL_ARG_NAME) """.format(MODEL_ARG_NAME)
@abc.abstractmethod
def report_local_outputs(self):
"""Returns tensors representing values aggregated over `forward_pass` calls.
In federated learning, the values returned by this method will typically
be further aggregated across clients and made available on the server.
This method returns results from aggregating across *all* previous calls
to `forward_pass`, most typically metrics like accuracy and loss. If needed,
we may add a `clear_aggregated_outputs` method, which would likely just
run the initializers on the `local_variables`.
In general, the tensors returned can be an arbitrary function of all
the `tf.Variables` of this model, not just the `local_variables`; for
example, this could return tensors measuring the total L2 norm of the model
(which might have been updated by training).
This method may return arbitrarily shaped tensors, not just scalar metrics.
For example, it could return the average feature vector or a count of
how many times each feature exceed a certain magnitude.
Returns:
A structure of tensors (as supported by `tf.nest`)
to be aggregated across clients.
"""
pass
@abc.abstractproperty
def federated_output_computation(self) -> computation_base.Computation:
"""Performs federated aggregation of the `Model's` `local_outputs`.
This is typically used to aggregate metrics across many clients, e.g. the
body of the computation might be:
```python
return {
'num_examples': tff.federated_sum(local_outputs.num_examples),
'loss': tff.federated_mean(local_outputs.loss)
}
```
N.B. It is assumed all TensorFlow computation happens in the
`report_local_outputs` method, and this method only uses TFF constructs to
specify aggregations across clients.
Returns:
Either a `tff.Computation`, or None if no federated aggregation is needed.
The `tff.Computation` should take as its single input a
`tff.CLIENTS`-placed `tff.Value` corresponding to the return value of
`Model.report_local_outputs`, and return an `OrderedDict` (possibly
nested) of `tff.SERVER`-placed values. The consumer of this
method should generally provide these server-placed values as outputs of
the overall computation consuming the model. Using an `OrderedDict`
allows the value returned by TFF executor to be converted back to an
`OrderedDict` via the `._asdict(recursive=True)` member function.
"""
pass
@abc.abstractmethod @abc.abstractmethod
def report_local_unfinalized_metrics(self) -> OrderedDict[str, Any]: def report_local_unfinalized_metrics(self) -> OrderedDict[str, Any]:
"""Creates an `OrderedDict` of metric names to unfinalized values. """Creates an `OrderedDict` of metric names to unfinalized values.
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册