提交 00a23d7b 编辑于 作者: Zachary Charles's avatar Zachary Charles 提交者: tensorflow-copybara
浏览文件

Create a library containing client and server measurements suitable for...

Create a library containing client and server measurements suitable for debugging learning processes, and are intended for use with tff.aggregators.add_measurements.

PiperOrigin-RevId: 393843919
上级 4bd1be79
......@@ -124,6 +124,44 @@ py_test(
],
)
py_library(
name = "dataset_reduce",
srcs = ["dataset_reduce.py"],
srcs_version = "PY3",
)
py_cpu_gpu_test(
name = "dataset_reduce_test",
srcs = ["dataset_reduce_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":dataset_reduce"],
)
py_library(
name = "debug_measurements",
srcs = ["debug_measurements.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:placements",
],
)
py_test(
name = "debug_measurements_test",
srcs = ["debug_measurements_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":debug_measurements",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/types:computation_types",
],
)
py_library(
name = "distributors",
srcs = ["distributors.py"],
......@@ -194,6 +232,42 @@ py_cpu_gpu_test(
],
)
py_library(
name = "evaluation",
srcs = ["evaluation.py"],
srcs_version = "PY3",
deps = [
":dataset_reduce",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
],
)
py_test(
name = "evaluation_test",
srcs = ["evaluation_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":evaluation",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/learning:keras_utils",
"//tensorflow_federated/python/learning:model_utils",
],
)
py_library(
name = "finalizers",
srcs = ["finalizers.py"],
......@@ -287,53 +361,3 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_library(
name = "dataset_reduce",
srcs = ["dataset_reduce.py"],
srcs_version = "PY3",
)
py_cpu_gpu_test(
name = "dataset_reduce_test",
srcs = ["dataset_reduce_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":dataset_reduce"],
)
py_library(
name = "evaluation",
srcs = ["evaluation.py"],
srcs_version = "PY3",
deps = [
":dataset_reduce",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
],
)
py_test(
name = "evaluation_test",
srcs = ["evaluation_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":evaluation",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/learning:keras_utils",
"//tensorflow_federated/python/learning:model_utils",
],
)
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library of aggregator measurements useful for debugging learning processes."""
import collections
import tensorflow as tf
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import placements
@computations.tf_computation
def calculate_global_norm(tensor_struct):
"""Calculate the Euclidean norm of a nested structure of tensors."""
return tf.linalg.global_norm(tf.nest.flatten(tensor_struct))
@computations.tf_computation
def square_value(tensor_value):
"""Computes the square of a tensor."""
return tensor_value**2
@computations.tf_computation
def calculate_server_update_statistics(server_update):
"""Calculate the L2 norm, and the max and min values of a server update."""
flattened_struct = tf.nest.flatten(server_update)
max_value = tf.math.reduce_max(
tf.nest.map_structure(tf.math.reduce_max, flattened_struct))
min_value = tf.math.reduce_min(
tf.nest.map_structure(tf.math.reduce_min, flattened_struct))
global_norm = tf.linalg.global_norm(flattened_struct)
return collections.OrderedDict(
server_update_max=max_value,
server_update_norm=global_norm,
server_update_min=min_value)
@computations.tf_computation
def calculate_unbiased_std_dev(expected_value, expected_squared_value,
sum_of_weights, sum_of_squared_weights):
"""Calculate the standard_deviation of a discrete distribution.
Here, we assume that we have some distribution that takes on values `x_1` up
through `x_n` with probabilities `w_1, ..., w_n`. We compute the standard
deviation of this distribution, relative to the unbiased variance.
This involves multipying the biased variance by a correction factor involving
sums of weights and weights squared. If `a` is the sum of the `w_i` and `b` is
the sum of the `w_i**2`, then the correction factor for the variance is
`a**2/(a**2-b)`. Note that when the weights are all equal, this reduces to the
standard Bessel correction factor of `n/(n-1)`. We then take a square root to
get the standard deviation.
Args:
expected_value: A float representing the weighted mean of the distribution.
expected_squared_value: A float representing the expected square value of
the distribution.
sum_of_weights: A float representing the sum of weights in the distribution.
sum_of_squared_weights: A float representing the sum of the squared weights
in the distribution.
Returns:
A float representing the standard deviation with respect to the unbiased
variance.
"""
biased_variance = expected_squared_value - expected_value**2
correction_factor = tf.math.divide_no_nan(
sum_of_weights**2, sum_of_weights**2 - sum_of_squared_weights)
return tf.math.sqrt(correction_factor * biased_variance)
def calculate_client_update_statistics(client_updates, client_weights):
"""Calculate the average and standard deviation of client updates."""
client_norms = intrinsics.federated_map(calculate_global_norm, client_updates)
client_norms_squared = intrinsics.federated_map(square_value, client_norms)
average_client_norm = intrinsics.federated_mean(client_norms, client_weights)
average_client_norm_squared = intrinsics.federated_mean(
client_norms_squared, client_weights)
# TODO(b/197972289): Add SecAgg compatibility to these measurements
sum_of_client_weights = intrinsics.federated_sum(client_weights)
client_weights_squared = intrinsics.federated_map(square_value,
client_weights)
sum_of_client_weights_squared = intrinsics.federated_sum(
client_weights_squared)
unbiased_std_dev = intrinsics.federated_map(
calculate_unbiased_std_dev,
(average_client_norm, average_client_norm_squared, sum_of_client_weights,
sum_of_client_weights_squared))
return intrinsics.federated_zip(
collections.OrderedDict(
average_client_norm=average_client_norm,
std_dev_client_norm=unbiased_std_dev))
def build_aggregator_measurement_fns(weighted_aggregator: bool = True):
"""Create measurement functions suitable for debugging learning processes.
These functions are intended for use with `tff.aggregators.add_measurements`.
This function creates client and server measurements functions. The client
measurement function computes:
* The (weighted) average Euclidean norm of client updates.
* The (weighted) standard deviation of these norms.
The standard deviation we report is the square root of the **unbiased**
variance. The server measurement function computes:
* The maximum entry of the aggregate client update.
* The Euclidean norm of the aggregate client update.
* The minimum entry of the aggregate client update.
Note that the `client_measurement_fn` will either have input arguments
`(client_value, client_weight)` or `client_value`, depending on whether
`weighted_aggregator = True` or `False`, respectively. The
`server_measurement_fn` will have input argument `server_value`.
Args:
weighted_aggregator: A boolean indicating whether the client measurement
function is intended for use with weighted aggregators (`True`) or not
(`False`).
Returns:
A tuple `(client_measurement_fn, server_measurement_fn)` of Python callables
matching the docstring above.
"""
if weighted_aggregator:
client_measurement_fn = calculate_client_update_statistics
else:
def client_measurement_fn(value):
client_weights = intrinsics.federated_value(1.0, placements.CLIENTS)
return calculate_client_update_statistics(value, client_weights)
def server_measurement_fn(value):
server_measurements = intrinsics.federated_map(
calculate_server_update_statistics, value)
return server_measurements
return client_measurement_fn, server_measurement_fn
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning.framework import debug_measurements
TensorType = computation_types.TensorType
FloatType = TensorType(tf.float32)
FloatAtServer = computation_types.at_server(FloatType)
FloatAtClients = computation_types.at_clients(FloatType)
SERVER_MEASUREMENTS_OUTPUT_TYPE = computation_types.at_server(
collections.OrderedDict([
('server_update_max', FloatType),
('server_update_norm', FloatType),
('server_update_min', FloatType),
]))
CLIENT_MEASUREMENTS_OUTPUT_TYPE = computation_types.at_server(
collections.OrderedDict([
('average_client_norm', FloatType),
('std_dev_client_norm', FloatType),
]))
class DebugMeasurementsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('scalar_type', FloatType),
('vector_type', TensorType(tf.float32, [3])),
('struct_type', [FloatType, FloatType]),
('nested_struct_type', [
[TensorType(tf.float32, [3])],
[FloatType, FloatType],
]),
)
def test_server_measurement_fn_traceable_by_federated_computation(
self, value_type):
_, server_measurement_fn = (
debug_measurements.build_aggregator_measurement_fns())
input_type = computation_types.at_server(value_type)
@computations.federated_computation(input_type)
def get_server_measurements(server_update):
return server_measurement_fn(server_update)
type_signature = get_server_measurements.type_signature
type_signature.parameter.check_assignable_from(input_type)
type_signature.result.check_assignable_from(SERVER_MEASUREMENTS_OUTPUT_TYPE)
@parameterized.named_parameters(
('scalar_type', FloatType),
('vector_type', TensorType(tf.float32, [3])),
('struct_type', [FloatType, FloatType]),
('nested_struct_type', [
[TensorType(tf.float32, [3])],
[FloatType, FloatType],
]),
)
def test_unweighted_client_measurement_fn_traceable_by_federated_computation(
self, value_type):
client_measurement_fn, _ = debug_measurements.build_aggregator_measurement_fns(
weighted_aggregator=False)
input_type = computation_types.at_clients(value_type)
@computations.federated_computation(input_type)
def get_client_measurements(client_update):
return client_measurement_fn(client_update)
type_signature = get_client_measurements.type_signature
type_signature.parameter.check_assignable_from(input_type)
type_signature.result.check_assignable_from(CLIENT_MEASUREMENTS_OUTPUT_TYPE)
@parameterized.named_parameters(
('scalar_type', FloatType),
('vector_type', TensorType(tf.float32, [3])),
('struct_type', [FloatType, FloatType]),
('nested_struct_type', [
[TensorType(tf.float32, [3])],
[FloatType, FloatType],
]),
)
def test_weighted_client_measurement_fn_traceable_by_federated_computation(
self, value_type):
client_measurement_fn, _ = debug_measurements.build_aggregator_measurement_fns(
weighted_aggregator=True)
input_type = computation_types.at_clients(value_type)
weights_type = computation_types.at_clients(tf.float32)
@computations.federated_computation(input_type, weights_type)
def get_client_measurements(client_update, client_weights):
return client_measurement_fn(client_update, client_weights)
type_signature = get_client_measurements.type_signature
type_signature.parameter[0].check_assignable_from(input_type)
type_signature.parameter[1].check_assignable_from(weights_type)
type_signature.result.check_assignable_from(CLIENT_MEASUREMENTS_OUTPUT_TYPE)
@parameterized.named_parameters(
('server_update1', [-3.0, 4.0, 0.0], 4.0, 5.0, -3.0),
('server_update2', [0.0], 0.0, 0.0, 0.0),
('server_update3', {
'a': tf.constant([1.0, -1.0]),
'b': tf.constant(2.0),
}, 2.0, tf.math.sqrt(6.0), -1.0),
)
def test_correctness_of_server_update_statistics(self, server_update,
expected_max, expected_norm,
expected_min):
actual_server_statistics = debug_measurements.calculate_server_update_statistics(
server_update)
expected_server_statistics = collections.OrderedDict(
server_update_max=expected_max,
server_update_norm=expected_norm,
server_update_min=expected_min)
self.assertAllClose(actual_server_statistics, expected_server_statistics)
@parameterized.named_parameters(
('distribution1', [1.0, 3.0, 0.0]),
('distribution2', [-1.0]),
('distribution3', [2.0, 2.0, 2.0]),
)
def test_correctness_unbiased_std_dev_unweighted(self, distribution):
n = tf.cast(len(distribution), dtype=tf.float32)
expected_value = tf.math.reduce_mean(distribution)
expected_value_squared = tf.math.reduce_mean(tf.constant(distribution)**2)
unbiased_std_dev = debug_measurements.calculate_unbiased_std_dev(
expected_value, expected_value_squared, n, n)
biased_std_dev = tf.math.reduce_std(distribution)
correct_unbiased_std_dev = tf.math.sqrt(tf.math.divide_no_nan(
n, n - 1)) * biased_std_dev
self.assertNear(unbiased_std_dev, correct_unbiased_std_dev, 1e-6)
@parameterized.named_parameters(
('client_updates1', [1.0, -2.0, 5.0]),
('client_updates2', [7.0]),
('client_updates3', [2.0, 2.0, 2.0]),
)
def test_correctness_of_unweighted_client_update_statistics(
self, client_updates):
client_weights = [1.0 for _ in client_updates]
@computations.federated_computation(
computation_types.at_clients(tf.float32),
computation_types.at_clients(tf.float32))
def compute_client_statistics(client_updates, client_weights):
return debug_measurements.calculate_client_update_statistics(
client_updates, client_weights)
actual_client_statistics = compute_client_statistics(
client_updates, client_weights)
client_norms = [tf.math.abs(a) for a in client_updates]
expected_average_norm = tf.math.reduce_mean(client_norms)
num_clients = tf.cast(len(client_updates), tf.float32)
expected_std_dev = tf.math.reduce_std(client_norms) * tf.math.sqrt(
tf.math.divide_no_nan(num_clients, num_clients - 1))
expected_client_statistics = collections.OrderedDict(
average_client_norm=expected_average_norm,
std_dev_client_norm=expected_std_dev)
self.assertAllClose(actual_client_statistics, expected_client_statistics)
@parameterized.named_parameters(
('distribution1', [1.0, 3.0, 0.0], [2.0, 3.0, 1.0]),
('distribution2', [-1.0], [5.0]),
('distribution3', [2.0, -2.0, 2.0], [6.0, 7.0, 4.0]),
('distribution4', [1.0, 2.0, -3.0], [1.0, 1.0, 0.0]),
)
def test_correctness_of_weighted_client_update_statistics(
self, client_updates, client_weights):
@computations.federated_computation(
computation_types.at_clients(tf.float32),
computation_types.at_clients(tf.float32))
def compute_client_statistics(client_updates, client_weights):
return debug_measurements.calculate_client_update_statistics(
client_updates, client_weights)
actual_client_statistics = compute_client_statistics(
client_updates, client_weights)
client_updates = tf.constant(client_updates)
client_weights = tf.constant(client_weights)
weights_sum = tf.math.reduce_sum(client_weights)
weights_squared_sum = tf.math.reduce_sum(client_weights**2)
expected_norm = tf.math.divide_no_nan(
tf.math.reduce_sum(tf.math.abs(client_updates) * client_weights),
weights_sum)
expected_norm_squared = tf.math.divide_no_nan(
tf.math.reduce_sum(client_updates**2 * client_weights), weights_sum)
biased_variance = expected_norm_squared - expected_norm**2
unbiased_variance = tf.math.divide_no_nan(
weights_sum**2, weights_sum**2 - weights_squared_sum) * biased_variance
unbiased_std_dev = tf.math.sqrt(unbiased_variance)
expected_client_statistics = collections.OrderedDict(
average_client_norm=expected_norm, std_dev_client_norm=unbiased_std_dev)
self.assertAllClose(actual_client_statistics, expected_client_statistics)
if __name__ == '__main__':
execution_contexts.set_local_execution_context()
tf.test.main()
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册