Commit 39f1f33a authored by Shanshan Wu's avatar Shanshan Wu Committed by tensorflow-copybara
Browse files

Remove `federated_output_computation` and `report_local_outputs` from base model class.

PiperOrigin-RevId: 411629757
parent bd9009e4
......@@ -157,7 +157,9 @@ def evaluate_fn(model: tff.learning.Model,
initial_state=0, reduce_func=reduce_fn)
eval_metrics = collections.OrderedDict()
eval_metrics['num_test_examples'] = num_examples_sum
local_outputs = model.report_local_outputs()
# TODO(b/202027089): Removes this pytype comment once this function does not
# use `report_local_outputs`.
local_outputs = model.report_local_outputs() # pytype: disable=attribute-error
# Postprocesses the metric values. This is needed because the values returned
# by `model.report_local_outputs()` are values of the state variables in each
# `tf.keras.metrics.Metric`. These values should be processed in the same way
......@@ -240,7 +240,6 @@ py_library(
name = "model",
srcs = [""],
srcs_version = "PY3",
deps = ["//tensorflow_federated/python/core/api:computation_base"],
......@@ -19,8 +19,6 @@ from typing import Any, Callable, OrderedDict, Sequence
import attr
import tensorflow as tf
from tensorflow_federated.python.core.api import computation_base
MetricFinalizersType = OrderedDict[str, Callable[[Any], Any]]
......@@ -190,65 +188,6 @@ class Model(object, metaclass=abc.ABCMeta):
function understands the result.
def report_local_outputs(self):
"""Returns tensors representing values aggregated over `forward_pass` calls.
In federated learning, the values returned by this method will typically
be further aggregated across clients and made available on the server.
This method returns results from aggregating across *all* previous calls
to `forward_pass`, most typically metrics like accuracy and loss. If needed,
we may add a `clear_aggregated_outputs` method, which would likely just
run the initializers on the `local_variables`.
In general, the tensors returned can be an arbitrary function of all
the `tf.Variables` of this model, not just the `local_variables`; for
example, this could return tensors measuring the total L2 norm of the model
(which might have been updated by training).
This method may return arbitrarily shaped tensors, not just scalar metrics.
For example, it could return the average feature vector or a count of
how many times each feature exceed a certain magnitude.
A structure of tensors (as supported by `tf.nest`)
to be aggregated across clients.
def federated_output_computation(self) -> computation_base.Computation:
"""Performs federated aggregation of the `Model's` `local_outputs`.
This is typically used to aggregate metrics across many clients, e.g. the
body of the computation might be:
return {
'num_examples': tff.federated_sum(local_outputs.num_examples),
'loss': tff.federated_mean(local_outputs.loss)
N.B. It is assumed all TensorFlow computation happens in the
`report_local_outputs` method, and this method only uses TFF constructs to
specify aggregations across clients.
Either a `tff.Computation`, or None if no federated aggregation is needed.
The `tff.Computation` should take as its single input a
`tff.CLIENTS`-placed `tff.Value` corresponding to the return value of
`Model.report_local_outputs`, and return an `OrderedDict` (possibly
nested) of `tff.SERVER`-placed values. The consumer of this
method should generally provide these server-placed values as outputs of
the overall computation consuming the model. Using an `OrderedDict`
allows the value returned by TFF executor to be converted back to an
`OrderedDict` via the `._asdict(recursive=True)` member function.
def report_local_unfinalized_metrics(self) -> OrderedDict[str, Any]:
"""Creates an `OrderedDict` of metric names to unfinalized values.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment