提交 ddf8ecae 编辑于 作者: Zachary Charles's avatar Zachary Charles 提交者: tensorflow-copybara
浏览文件

Add an API symbol for adding learning-oriented debug measurements to an aggregator factory.

PiperOrigin-RevId: 394299699
上级 46594ce6
...@@ -28,6 +28,7 @@ py_library( ...@@ -28,6 +28,7 @@ py_library(
visibility = ["//tensorflow_federated:__pkg__"], visibility = ["//tensorflow_federated:__pkg__"],
deps = [ deps = [
":client_weight_lib", ":client_weight_lib",
":debug_measurements",
":federated_averaging", ":federated_averaging",
":federated_evaluation", ":federated_evaluation",
":federated_sgd", ":federated_sgd",
...@@ -50,6 +51,33 @@ py_library( ...@@ -50,6 +51,33 @@ py_library(
srcs_version = "PY3", srcs_version = "PY3",
) )
py_library(
name = "debug_measurements",
srcs = ["debug_measurements.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/aggregators:measurements",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:placements",
],
)
py_test(
name = "debug_measurements_test",
srcs = ["debug_measurements_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":debug_measurements",
"//tensorflow_federated/python/aggregators:mean",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/types:computation_types",
],
)
py_library( py_library(
name = "federated_averaging", name = "federated_averaging",
srcs = ["federated_averaging.py"], srcs = ["federated_averaging.py"],
......
...@@ -18,6 +18,7 @@ from tensorflow_federated.python.learning import models ...@@ -18,6 +18,7 @@ from tensorflow_federated.python.learning import models
from tensorflow_federated.python.learning import optimizers from tensorflow_federated.python.learning import optimizers
from tensorflow_federated.python.learning import reconstruction from tensorflow_federated.python.learning import reconstruction
from tensorflow_federated.python.learning.client_weight_lib import ClientWeighting from tensorflow_federated.python.learning.client_weight_lib import ClientWeighting
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements
from tensorflow_federated.python.learning.federated_averaging import build_federated_averaging_process from tensorflow_federated.python.learning.federated_averaging import build_federated_averaging_process
from tensorflow_federated.python.learning.federated_averaging import ClientFedAvg from tensorflow_federated.python.learning.federated_averaging import ClientFedAvg
from tensorflow_federated.python.learning.federated_evaluation import build_federated_evaluation from tensorflow_federated.python.learning.federated_evaluation import build_federated_evaluation
......
...@@ -17,6 +17,8 @@ import collections ...@@ -17,6 +17,8 @@ import collections
import tensorflow as tf import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import measurements
from tensorflow_federated.python.core.api import computations from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.impl.types import placements
...@@ -156,3 +158,58 @@ def build_aggregator_measurement_fns(weighted_aggregator: bool = True): ...@@ -156,3 +158,58 @@ def build_aggregator_measurement_fns(weighted_aggregator: bool = True):
return server_measurements return server_measurements
return client_measurement_fn, server_measurement_fn return client_measurement_fn, server_measurement_fn
def add_debug_measurements(aggregation_factory: factory.AggregationFactory):
  """Wraps an aggregation factory so it also emits debugging measurements.

  The returned factory aggregates client values exactly as
  `aggregation_factory` does; it only reports extra measurements that are
  useful when debugging a learning process.

  The extra measurements describe the norms of the client updates and of the
  aggregated server update. Averages follow the weighting of
  `aggregation_factory`: a weighted factory yields weighted averages, and an
  unweighted factory yields uniformly weighted averages.

  Client measurements:

  *   The average Euclidean norm of client updates.
  *   The standard deviation of these norms, reported as the square root of
      the **unbiased** variance.

  Server measurements:

  *   The maximum entry of the aggregate client update.
  *   The Euclidean norm of the aggregate client update.
  *   The minimum entry of the aggregate client update.

  Here an "entry" is any coordinate across all tensors in the structure. For
  example, given these client structures before aggregation:

  *   Client A: `[[-1, -3, -5], [2]]`
  *   Client B: `[[-1, -3, 1], [0]]`

  unweighted averaging produces the aggregate `[[-1, -3, -2], [1]]`, whose
  maximum entry is `1`, minimum entry is `-3`, and Euclidean norm is
  `sqrt(15)`.

  Args:
    aggregation_factory: A `tff.aggregators.AggregationFactory`, either
      weighted or unweighted.

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
  # The measurement functions differ depending on whether client weights are
  # part of the aggregation, so detect the kind of factory first.
  uses_client_weights = isinstance(
      aggregation_factory, factory.WeightedAggregationFactory)
  measurement_fns = build_aggregator_measurement_fns(
      weighted_aggregator=uses_client_weights)
  client_fn, server_fn = measurement_fns
  return measurements.add_measurements(
      aggregation_factory,
      client_measurement_fn=client_fn,
      server_measurement_fn=server_fn)
...@@ -17,10 +17,11 @@ import collections ...@@ -17,10 +17,11 @@ import collections
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow_federated.python.aggregators import mean
from tensorflow_federated.python.core.api import computations from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning.framework import debug_measurements from tensorflow_federated.python.learning import debug_measurements
TensorType = computation_types.TensorType TensorType = computation_types.TensorType
FloatType = TensorType(tf.float32) FloatType = TensorType(tf.float32)
...@@ -216,6 +217,106 @@ class DebugMeasurementsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -216,6 +217,106 @@ class DebugMeasurementsTest(tf.test.TestCase, parameterized.TestCase):
average_client_norm=expected_norm, std_dev_client_norm=unbiased_std_dev) average_client_norm=expected_norm, std_dev_client_norm=unbiased_std_dev)
self.assertAllClose(actual_client_statistics, expected_client_statistics) self.assertAllClose(actual_client_statistics, expected_client_statistics)
def test_add_measurements_to_weighted_aggregation_factory_types(self):
  """Wrapping a weighted mean factory must not change its type signatures."""
  base_factory = mean.MeanFactory()
  wrapped_factory = debug_measurements.add_debug_measurements(base_factory)
  float_type = computation_types.TensorType(tf.float32)
  # A weighted factory takes both a value type and a weight type.
  base_process = base_factory.create(float_type, float_type)
  wrapped_process = wrapped_factory.create(float_type, float_type)

  self.assertTrue(wrapped_process.is_weighted)
  self.assertEqual(base_process.initialize.type_signature,
                   wrapped_process.initialize.type_signature)

  # Only the measurements part of `next`'s result may differ, so compare the
  # parameter, state and result pieces individually.
  base_next_type = base_process.next.type_signature
  wrapped_next_type = wrapped_process.next.type_signature
  self.assertEqual(base_next_type.parameter, wrapped_next_type.parameter)
  self.assertEqual(base_next_type.result.state,
                   wrapped_next_type.result.state)
  self.assertEqual(base_next_type.result.result,
                   wrapped_next_type.result.result)
def test_add_measurements_to_weighted_aggregation_factory_output(self):
  """Wrapped weighted mean keeps its output and gains debug measurements."""
  base_factory = mean.MeanFactory()
  wrapped_factory = debug_measurements.add_debug_measurements(base_factory)
  float_type = computation_types.TensorType(tf.float32)
  base_process = base_factory.create(float_type, float_type)
  wrapped_process = wrapped_factory.create(float_type, float_type)

  initial_state = base_process.initialize()
  client_values = [2.0, 4.0]
  client_weights = [1.0, 1.0]
  base_output = base_process.next(initial_state, client_values, client_weights)
  wrapped_output = wrapped_process.next(initial_state, client_values,
                                        client_weights)

  self.assertEqual(base_output.state, wrapped_output.state)
  self.assertNear(base_output.result, wrapped_output.result, err=1e-6)

  base_metrics = base_output.measurements
  wrapped_metrics = wrapped_output.measurements
  # Values 2.0 and 4.0 average to 3.0; the unbiased variance of the norms
  # {2, 4} is 2, hence a standard deviation of sqrt(2).
  expected_debug_metrics = {
      'average_client_norm': 3.0,
      'std_dev_client_norm': tf.math.sqrt(2.0),
      'server_update_max': 3.0,
      'server_update_norm': 3.0,
      'server_update_min': 3.0,
  }

  # The wrapped measurements are the originals plus the debugging ones.
  self.assertCountEqual(
      list(wrapped_metrics), [*base_metrics, *expected_debug_metrics])
  for name in base_metrics:
    self.assertEqual(base_metrics[name], wrapped_metrics[name])
  for name, expected_value in expected_debug_metrics.items():
    self.assertNear(wrapped_metrics[name], expected_value, err=1e-6)
def test_add_measurements_to_unweighted_aggregation_factory_types(self):
  """Wrapping an unweighted mean factory must not change its type signatures."""
  base_factory = mean.UnweightedMeanFactory()
  wrapped_factory = debug_measurements.add_debug_measurements(base_factory)
  float_type = computation_types.TensorType(tf.float32)
  # An unweighted factory is created from the value type alone.
  base_process = base_factory.create(float_type)
  wrapped_process = wrapped_factory.create(float_type)

  self.assertFalse(wrapped_process.is_weighted)
  self.assertEqual(base_process.initialize.type_signature,
                   wrapped_process.initialize.type_signature)

  # Only the measurements part of `next`'s result may differ, so compare the
  # parameter, state and result pieces individually.
  base_next_type = base_process.next.type_signature
  wrapped_next_type = wrapped_process.next.type_signature
  self.assertEqual(base_next_type.parameter, wrapped_next_type.parameter)
  self.assertEqual(base_next_type.result.state,
                   wrapped_next_type.result.state)
  self.assertEqual(base_next_type.result.result,
                   wrapped_next_type.result.result)
def test_add_measurements_to_unweighted_aggregation_factory_output(self):
  """Wrapped unweighted mean keeps its output and gains debug measurements."""
  base_factory = mean.UnweightedMeanFactory()
  wrapped_factory = debug_measurements.add_debug_measurements(base_factory)
  float_type = computation_types.TensorType(tf.float32)
  base_process = base_factory.create(float_type)
  wrapped_process = wrapped_factory.create(float_type)

  initial_state = base_process.initialize()
  client_values = [2.0, 4.0]
  base_output = base_process.next(initial_state, client_values)
  wrapped_output = wrapped_process.next(initial_state, client_values)

  self.assertEqual(base_output.state, wrapped_output.state)
  self.assertNear(base_output.result, wrapped_output.result, err=1e-6)

  base_metrics = base_output.measurements
  wrapped_metrics = wrapped_output.measurements
  # Values 2.0 and 4.0 average to 3.0; the unbiased variance of the norms
  # {2, 4} is 2, hence a standard deviation of sqrt(2).
  expected_debug_metrics = {
      'average_client_norm': 3.0,
      'std_dev_client_norm': tf.math.sqrt(2.0),
      'server_update_max': 3.0,
      'server_update_norm': 3.0,
      'server_update_min': 3.0,
  }

  # The wrapped measurements are the originals plus the debugging ones.
  self.assertCountEqual(
      list(wrapped_metrics), [*base_metrics, *expected_debug_metrics])
  for name in base_metrics:
    self.assertEqual(base_metrics[name], wrapped_metrics[name])
  for name, expected_value in expected_debug_metrics.items():
    self.assertNear(wrapped_metrics[name], expected_value, err=1e-6)
if __name__ == '__main__': if __name__ == '__main__':
execution_contexts.set_local_execution_context() execution_contexts.set_local_execution_context()
......
...@@ -138,30 +138,6 @@ py_cpu_gpu_test( ...@@ -138,30 +138,6 @@ py_cpu_gpu_test(
deps = [":dataset_reduce"], deps = [":dataset_reduce"],
) )
py_library(
name = "debug_measurements",
srcs = ["debug_measurements.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:placements",
],
)
py_test(
name = "debug_measurements_test",
srcs = ["debug_measurements_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":debug_measurements",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/types:computation_types",
],
)
py_library( py_library(
name = "distributors", name = "distributors",
srcs = ["distributors.py"], srcs = ["distributors.py"],
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册