Commit 04ca748c authored by Zachary Charles, committed by tensorflow-copybara

Create a composer for computations mirroring the structure of federated evaluation.

This change also creates re-usable layers for client work and metrics aggregation.

PiperOrigin-RevId: 392952848
Parent b53d2518
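For orientation, the sketch below shows how the pieces added in this change are intended to fit together. It is illustrative only and not part of the diff: `model_fn` (a no-arg `tff.learning.Model` builder) and `batch_type` are assumed to be supplied by the caller, and the broadcast distributor is a hypothetical stand-in; only `build_eval_work`, `build_model_metrics_aggregator`, and `compose_eval_computation` come from this commit.

# Sketch: composing the new re-usable layers into a federated evaluation.
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.framework import evaluation

model = model_fn()  # `model_fn`: assumed no-arg `tff.learning.Model` builder.
model_weights_type = model_utils.weights_type_from_model(model)

# Re-usable client work and metrics aggregation layers added by this change.
client_work = evaluation.build_eval_work(
    model_fn, model_weights_type, batch_type)  # `batch_type`: assumed batch spec.
metrics_aggregator = evaluation.build_model_metrics_aggregator(
    model, client_work.type_signature.result)

# A hypothetical stateless distributor that simply broadcasts model weights.
@computations.federated_computation(
    computation_types.at_server(model_weights_type))
def distributor(weights):
  return intrinsics.federated_broadcast(weights)

# Mirrors federated evaluation: broadcast, per-client work, metrics aggregation.
eval_comp = evaluation.compose_eval_computation(
    distributor, client_work, metrics_aggregator)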
@@ -301,3 +301,39 @@ py_cpu_gpu_test(
srcs_version = "PY3",
deps = [":dataset_reduce"],
)
py_library(
name = "evaluation",
srcs = ["evaluation.py"],
srcs_version = "PY3",
deps = [
":dataset_reduce",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
],
)
py_test(
name = "evaluation_test",
srcs = ["evaluation_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":evaluation",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/learning:keras_utils",
"//tensorflow_federated/python/learning:model_utils",
],
)
tensorflow_federated/python/learning/framework/evaluation.py
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An implementation of federated evaluation using re-usable layers."""
import collections
from typing import Callable
import tensorflow as tf
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.framework import dataset_reduce
def build_eval_work(
model_fn: Callable[[], model_lib.Model],
model_weights_type: computation_types.Type,
data_type: computation_types.Type,
use_experimental_simulation_loop: bool = False
) -> computation_base.Computation:
"""Builds a `tff.Computation` for evaluating a model on a dataset.
  This function accepts model weights matching `model_weights_type` and data
  whose batch type matches `data_type`, and returns the metrics computed by the
  corresponding model over that data.
Args:
model_fn: A no-arg function that returns a `tff.learning.Model`.
model_weights_type: A `tff.Type` representing the type of the model weights.
data_type: A `tff.Type` representing the batch type of the data. This must
be compatible with the batch type expected by the forward pass of the
model returned by `model_fn`.
use_experimental_simulation_loop: A boolean controlling the reduce loop used
to iterate over client datasets. If set to `True`, an experimental reduce
loop is used.
Returns:
A `tff.Computation`.
"""
@computations.tf_computation(model_weights_type,
computation_types.SequenceType(data_type))
@tf.function
def client_eval(incoming_model_weights, dataset):
with tf.init_scope():
model = model_fn()
model_weights = model_utils.ModelWeights.from_model(model)
tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
incoming_model_weights)
def reduce_fn(num_examples, batch):
model_output = model.forward_pass(batch, training=False)
return num_examples + tf.cast(model_output.num_examples, tf.int64)
dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
use_experimental_simulation_loop)
num_examples = dataset_reduce_fn(
reduce_fn=reduce_fn,
dataset=dataset,
initial_state_fn=lambda: tf.zeros([], dtype=tf.int64))
return collections.OrderedDict(
local_outputs=model.report_local_outputs(), num_examples=num_examples)
return client_eval
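# Illustrative usage (not part of this change): the computation returned by
# `build_eval_work` takes unplaced values and can be invoked directly, e.g.
#   eval_fn = build_eval_work(model_fn, model_weights_type, batch_type)
#   metrics = eval_fn(model_weights, dataset)
# where `model_fn`, `model_weights_type`, `batch_type`, `model_weights` and
# `dataset` are assumed to be supplied by the caller; the result is an
# `OrderedDict` with keys `local_outputs` and `num_examples`.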
def build_model_metrics_aggregator(
model: model_lib.Model,
metrics_type: computation_types.Type) -> computation_base.Computation:
"""Creates a stateless aggregator for client metrics."""
@computations.federated_computation(
computation_types.at_clients(metrics_type))
def aggregate_metrics(client_metrics):
model_metrics = model.federated_output_computation(
client_metrics.local_outputs)
statistics = collections.OrderedDict(
num_examples=intrinsics.federated_sum(client_metrics.num_examples))
return intrinsics.federated_zip(
collections.OrderedDict(eval=model_metrics, stat=statistics))
return aggregate_metrics
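# Illustrative usage (not part of this change): with a `metrics_type` matching
# the structure produced by `build_eval_work`, e.g.
#   aggregator = build_model_metrics_aggregator(model, metrics_type)
#   aggregate = aggregator([client_metrics_1, client_metrics_2])
# the result is an `OrderedDict` with an `eval` key (the model's
# `federated_output_computation` applied to the clients' `local_outputs`) and a
# `stat` key holding the `federated_sum` of the clients' `num_examples`.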
class FederatedEvalTypeError(TypeError):
"""Raises evaluation components do not have expected type signatures."""
pass
class FederatedEvalInputOutputError(TypeError):
"""Raises when evaluation components have mismatched input/outputs."""
def check_federated_type_with_correct_placement(value_type, placement):
"""Checks that a `tff.Type` has the desired federated placement."""
if value_type is None:
return False
elif value_type.is_federated() and value_type.placement == placement:
return True
return False
def _validate_eval_types(stateless_distributor: computation_base.Computation,
client_eval_work: computation_base.Computation,
stateless_aggregator: computation_base.Computation):
"""Checks `compose_eval_computation` arguments meet documented constraints."""
py_typecheck.check_type(stateless_distributor, computation_base.Computation)
py_typecheck.check_type(client_eval_work, computation_base.Computation)
py_typecheck.check_type(stateless_aggregator, computation_base.Computation)
distributor_type = stateless_distributor.type_signature
distributor_parameter = distributor_type.parameter
if not check_federated_type_with_correct_placement(distributor_parameter,
placements.SERVER):
raise FederatedEvalTypeError(
f'The distributor must have a single input placed at `SERVER`, found '
f'input type signature {distributor_parameter}.')
distributor_result = distributor_type.result
if not check_federated_type_with_correct_placement(distributor_result,
placements.CLIENTS):
raise FederatedEvalTypeError(
f'The distributor must have a single output placed at `CLIENTS`, found '
f'output type signature {distributor_result}.')
client_work_type = client_eval_work.type_signature
if client_work_type.parameter.is_federated(
) or client_work_type.result.is_federated():
raise FederatedEvalTypeError(
        f'The `client_eval_work` must not be a federated computation, but '
f'found type signature {client_work_type}.')
aggregator_type = stateless_aggregator.type_signature
aggregator_parameter = aggregator_type.parameter
if not check_federated_type_with_correct_placement(aggregator_parameter,
placements.CLIENTS):
raise FederatedEvalTypeError(
f'The aggregator must have a single input placed at `CLIENTS`, found '
f'type signature {aggregator_parameter}.')
aggregator_result = aggregator_type.result
if not check_federated_type_with_correct_placement(aggregator_result,
placements.SERVER):
raise FederatedEvalTypeError(
f'The aggregator must have a single output placed at `SERVER`, found '
f'type signature {aggregator_result}.')
if not client_work_type.parameter[0].is_assignable_from(
distributor_result.member):
raise FederatedEvalInputOutputError(
f'The output of the distributor must be assignable to the first input '
f'of the client work. Found distributor result of type '
        f'{distributor_result.member}, but client work argument of type '
f'{client_work_type.parameter[0]}.')
if not aggregator_parameter.member.is_assignable_from(
client_work_type.result):
raise FederatedEvalInputOutputError(
f'The output of the client work must be assignable to the input of the '
f'aggregator. Found client work output of type '
f'{client_work_type.result}, but aggregator parameter of type '
f'{aggregator_parameter.member}.')
def compose_eval_computation(
stateless_distributor: computation_base.Computation,
client_eval_work: computation_base.Computation,
stateless_metrics_aggregator: computation_base.Computation):
"""Builds a TFF computation performing stateless evaluation across clients.
  The resulting computation has type signature
  `(<S@SERVER, T@CLIENTS> -> A@SERVER)`, where `S` represents some value held
  at the server, `T` represents the data held by the clients, and `A`
  represents the aggregated metrics across all clients.
Args:
stateless_distributor: A `tff.Computation` that broadcasts a value placed at
`tff.SERVER` to the clients. It must have type signature `(S@SERVER ->
R@CLIENTS)`, where the member of `R` matches the first type expected by
`client_eval_work`.
client_eval_work: A `tff.Computation` used to compute client metrics. Must
have type signature (<R, T> -> M) of unplaced types, where `R` is a type
representing a reference value broadcast from the server (such as model
weights), `T` represents values held by the client, and `M` is the type of
the metrics computed by the client.
stateless_metrics_aggregator: A `tff.Computation` that aggregates metrics
from the client. It must have type signature `(M@CLIENTS -> A@SERVER)`
where `M` has member matching the output of `client_eval_work`, and `A`
represents the aggregated metrics.
Returns:
A `tff.Computation`.
"""
_validate_eval_types(stateless_distributor, client_eval_work,
stateless_metrics_aggregator)
distributor_input_type = stateless_distributor.type_signature.parameter
client_data_type = client_eval_work.type_signature.parameter[1]
@computations.federated_computation(
distributor_input_type, computation_types.at_clients(client_data_type))
def evaluate(server_value, client_data):
distributor_output = stateless_distributor(server_value)
client_metrics = intrinsics.federated_map(client_eval_work,
(distributor_output, client_data))
return stateless_metrics_aggregator(client_metrics)
return evaluate
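# Illustrative end-to-end sketch (not part of this change), mirroring the unit
# tests: a distributor broadcasting a float, TF client work multiplying each
# dataset element by the broadcast value and summing, and a `federated_sum`
# aggregator compose into a computation of type
# `(<float32@SERVER, float32*@CLIENTS> -> float32@SERVER)`, e.g.
#   eval_comp = compose_eval_computation(distributor, client_work, aggregator)
#   eval_comp(1.0, [[1.0, 2.0, 3.0], [-1.0, -2.0, -5.0]])  # == -2.0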
tensorflow_federated/python/learning/framework/evaluation_test.py
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.framework import evaluation
# Convenience aliases.
StructType = computation_types.StructType
TensorType = computation_types.TensorType
def keras_model_builder():
# Create a simple linear regression model, single output.
# We initialize all weights to one.
return tf.keras.Sequential([
tf.keras.layers.Dense(
1,
kernel_initializer='ones',
bias_initializer='ones',
input_shape=(1,))
])
def create_dataset():
# Create data satisfying y = 2*x + 1
x = [[1.0], [2.0], [3.0]]
y = [[3.0], [5.0], [7.0]]
return tf.data.Dataset.from_tensor_slices((x, y)).batch(1)
def get_input_spec():
return create_dataset().element_spec
def tff_model_builder():
return keras_utils.from_keras_model(
keras_model=keras_model_builder(),
input_spec=get_input_spec(),
loss=tf.keras.losses.MeanSquaredError(),
metrics=[tf.keras.metrics.MeanSquaredError()])
class BuildEvalWorkTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('default_simulation_loop', False),
('experimental_simulation_loop', True),
)
def test_evaluation_types(self, use_experimental_simulation_loop):
model = tff_model_builder()
model_weights_type = model_utils.weights_type_from_model(model)
client_eval_work = evaluation.build_eval_work(
tff_model_builder, model_weights_type, get_input_spec(),
use_experimental_simulation_loop)
self.assertIsInstance(client_eval_work, computation_base.Computation)
type_signature = client_eval_work.type_signature
self.assertLen(type_signature.parameter, 2)
type_signature.parameter[0].check_assignable_from(model_weights_type)
type_signature.parameter[1].check_assignable_from(
computation_types.SequenceType(get_input_spec()))
@parameterized.named_parameters(
('default_simulation_loop', False),
('experimental_simulation_loop', True),
)
def test_evaluation_on_default_weights(self,
use_experimental_simulation_loop):
model = tff_model_builder()
model_weights_type = model_utils.weights_type_from_model(model)
model_weights = model_utils.ModelWeights.from_model(model)
client_eval_work = evaluation.build_eval_work(
tff_model_builder, model_weights_type, get_input_spec(),
use_experimental_simulation_loop)
# All weights are set to 1, so the model outputs f(x) = x + 1.
eval_metrics = client_eval_work(model_weights, create_dataset())
self.assertCountEqual(eval_metrics.keys(),
['local_outputs', 'num_examples'])
self.assertEqual(eval_metrics['num_examples'], 3)
local_outputs = eval_metrics['local_outputs']
self.assertCountEqual(local_outputs.keys(), ['loss', 'mean_squared_error'])
self.assertEqual(local_outputs['loss'], local_outputs['mean_squared_error'])
expected_loss_sum = (3.0 - 2.0)**2 + (5.0 - 3.0)**2 + (7.0 - 4.0)**2
self.assertAllClose(
local_outputs['loss'], [expected_loss_sum, 3.0], atol=1e-6)
def test_evaluation_on_input_weights(self):
model = tff_model_builder()
model_weights_type = model_utils.weights_type_from_model(model)
model_weights = model_utils.ModelWeights.from_model(model)
zero_weights = tf.nest.map_structure(tf.zeros_like, model_weights)
tf.nest.map_structure(lambda v, t: v.assign(t), model_weights, zero_weights)
client_eval_work = evaluation.build_eval_work(tff_model_builder,
model_weights_type,
get_input_spec())
# We compute metrics where all weights are set to 0, so the model should
# output f(x) = 0.
eval_metrics = client_eval_work(model_weights, create_dataset())
self.assertCountEqual(eval_metrics.keys(),
['local_outputs', 'num_examples'])
self.assertEqual(eval_metrics['num_examples'], 3)
local_outputs = eval_metrics['local_outputs']
self.assertCountEqual(local_outputs.keys(), ['loss', 'mean_squared_error'])
self.assertEqual(local_outputs['loss'], local_outputs['mean_squared_error'])
expected_loss_sum = 9.0 + 25.0 + 49.0
self.assertAllClose(
local_outputs['loss'], [expected_loss_sum, 3.0], atol=1e-6)
class BuildModelMetricsAggregatorTest(tf.test.TestCase):
def _get_metrics_type(self):
return StructType([
('local_outputs',
StructType([
('mean_squared_error', (TensorType(tf.float32),
TensorType(tf.float32))),
('loss', (TensorType(tf.float32), TensorType(tf.float32))),
])),
('num_examples', TensorType(tf.float32)),
])
def _get_aggregated_metrics_type(self):
return StructType([
('eval',
StructType([
('mean_squared_error', TensorType(tf.float32)),
('loss', TensorType(tf.float32)),
])),
('stat', StructType([
('num_examples', TensorType(tf.float32)),
])),
])
def test_metrics_aggregator_types(self):
model = tff_model_builder()
metrics_type = self._get_metrics_type()
model_metrics_aggregator = evaluation.build_model_metrics_aggregator(
model, metrics_type)
self.assertIsInstance(model_metrics_aggregator,
computation_base.Computation)
aggregator_parameter = model_metrics_aggregator.type_signature.parameter
aggregator_parameter.check_assignable_from(
computation_types.at_clients(metrics_type))
aggregator_result = model_metrics_aggregator.type_signature.result
aggregator_result.check_assignable_from(
computation_types.at_server(self._get_aggregated_metrics_type()))
def test_metrics_aggregator_correctness_with_one_client(self):
client_metrics = collections.OrderedDict(
local_outputs=collections.OrderedDict(
mean_squared_error=(4.0, 2.0), loss=(5.0, 1.0)),
num_examples=10.0)
model = tff_model_builder()
metrics_type = self._get_metrics_type()
model_metrics_aggregator = evaluation.build_model_metrics_aggregator(
model, metrics_type)
aggregate_metrics = model_metrics_aggregator([client_metrics])
expected_metrics = collections.OrderedDict(
eval=model.federated_output_computation(
[client_metrics['local_outputs']]),
stat=collections.OrderedDict(num_examples=10.0))
self.assertAllClose(aggregate_metrics, expected_metrics, atol=1e-6)
  def test_metrics_aggregator_correctness_with_three_clients(self):
client_metrics1 = collections.OrderedDict(
local_outputs=collections.OrderedDict(
mean_squared_error=(4.0, 2.0), loss=(5.0, 1.0)),
num_examples=10.0)
client_metrics2 = collections.OrderedDict(
local_outputs=collections.OrderedDict(
mean_squared_error=(4.0, 4.0), loss=(1.0, 5.0)),
num_examples=7.0)
client_metrics3 = collections.OrderedDict(
local_outputs=collections.OrderedDict(
mean_squared_error=(6.0, 2.0), loss=(5.0, 5.0)),
num_examples=3.0)
model = tff_model_builder()
metrics_type = self._get_metrics_type()
model_metrics_aggregator = evaluation.build_model_metrics_aggregator(
model, metrics_type)
federated_metrics = [client_metrics1, client_metrics2, client_metrics3]
federated_local_outputs = [x['local_outputs'] for x in federated_metrics]
aggregate_metrics = model_metrics_aggregator(federated_metrics)
expected_metrics = collections.OrderedDict(
eval=model.federated_output_computation(federated_local_outputs),
stat=collections.OrderedDict(num_examples=20.0))
self.assertAllClose(aggregate_metrics, expected_metrics, atol=1e-6)
class EvalComposerTest(tf.test.TestCase):
def create_test_distributor(self):
@computations.federated_computation(computation_types.at_server(tf.float32))
def basic_distribute(x):
return intrinsics.federated_broadcast(x)
return basic_distribute
def create_test_client_work(self):
@tf.function
def multiply_and_add(x, dataset):
total_sum = 0.0
for a in dataset:
total_sum = total_sum + x * a
return total_sum
@computations.tf_computation(tf.float32,
computation_types.SequenceType(tf.float32))
def basic_client_work(x, dataset):
return multiply_and_add(x, dataset)
return basic_client_work
def create_test_aggregator(self):
@computations.federated_computation(
computation_types.at_clients(tf.float32))
def basic_aggregate(x):
return intrinsics.federated_sum(x)
return basic_aggregate
def test_basic_composition_has_expected_types(self):
eval_computation = evaluation.compose_eval_computation(
self.create_test_distributor(), self.create_test_client_work(),
self.create_test_aggregator())
expected_parameter = computation_types.StructType([
computation_types.at_server(tf.float32),
computation_types.at_clients(
computation_types.SequenceType(tf.float32))
])
eval_computation.type_signature.parameter.check_assignable_from(
expected_parameter)
expected_result = computation_types.at_server(tf.float32)
eval_computation.type_signature.result.check_assignable_from(
expected_result)
def test_basic_composition_computes_expected_value(self):
eval_computation = evaluation.compose_eval_computation(
self.create_test_distributor(), self.create_test_client_work(),
self.create_test_aggregator())
client_data = [[1.0, 2.0, 3.0], [-1.0, -2.0, -5.0]]
actual_result = eval_computation(1.0, client_data)
self.assertEqual(actual_result, -2.0)
def test_basic_composition_with_struct_type(self):
distributor_struct = computation_types.at_server(StructType([tf.float32]))
@computations.federated_computation(distributor_struct)
def distributor_with_struct_parameter(x):
return intrinsics.federated_broadcast(x[0])
eval_computation = evaluation.compose_eval_computation(
distributor_with_struct_parameter, self.create_test_client_work(),
self.create_test_aggregator())
expected_parameter = computation_types.StructType([
distributor_struct,
computation_types.at_clients(
computation_types.SequenceType(tf.float32))
])
eval_computation.type_signature.parameter.check_assignable_from(
expected_parameter)
expected_result = computation_types.at_server(tf.float32)
eval_computation.type_signature.result.check_assignable_from(
expected_result)
def test_raises_on_python_callable_distributor(self):
def python_distributor(x):
return x
with self.assertRaises(TypeError):
evaluation.compose_eval_computation(python_distributor,