Skip to content
GitLab
菜单
项目
群组
代码片段
/
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
ddf8ecae
提交
ddf8ecae
编辑于
9月 01, 2021
作者:
Zachary Charles
提交者:
tensorflow-copybara
9月 01, 2021
浏览文件
Add an API symbol for adding learning-oriented debug measurements to an aggregator factory.
PiperOrigin-RevId: 394299699
上级
46594ce6
变更
5
Show whitespace changes
Inline
Side-by-side
tensorflow_federated/python/learning/BUILD
浏览文件 @
ddf8ecae
...
...
@@ -28,6 +28,7 @@ py_library(
visibility
=
[
"//tensorflow_federated:__pkg__"
],
deps
=
[
":client_weight_lib"
,
":debug_measurements"
,
":federated_averaging"
,
":federated_evaluation"
,
":federated_sgd"
,
...
...
@@ -50,6 +51,33 @@ py_library(
srcs_version
=
"PY3"
,
)
py_library
(
name
=
"debug_measurements"
,
srcs
=
[
"debug_measurements.py"
],
srcs_version
=
"PY3"
,
deps
=
[
"//tensorflow_federated/python/aggregators:factory"
,
"//tensorflow_federated/python/aggregators:measurements"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/impl/federated_context:intrinsics"
,
"//tensorflow_federated/python/core/impl/types:placements"
,
],
)
py_test
(
name
=
"debug_measurements_test"
,
srcs
=
[
"debug_measurements_test.py"
],
python_version
=
"PY3"
,
srcs_version
=
"PY3"
,
deps
=
[
":debug_measurements"
,
"//tensorflow_federated/python/aggregators:mean"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/backends/native:execution_contexts"
,
"//tensorflow_federated/python/core/impl/types:computation_types"
,
],
)
py_library
(
name
=
"federated_averaging"
,
srcs
=
[
"federated_averaging.py"
],
...
...
tensorflow_federated/python/learning/__init__.py
浏览文件 @
ddf8ecae
...
...
@@ -18,6 +18,7 @@ from tensorflow_federated.python.learning import models
from
tensorflow_federated.python.learning
import
optimizers
from
tensorflow_federated.python.learning
import
reconstruction
from
tensorflow_federated.python.learning.client_weight_lib
import
ClientWeighting
from
tensorflow_federated.python.learning.debug_measurements
import
add_debug_measurements
from
tensorflow_federated.python.learning.federated_averaging
import
build_federated_averaging_process
from
tensorflow_federated.python.learning.federated_averaging
import
ClientFedAvg
from
tensorflow_federated.python.learning.federated_evaluation
import
build_federated_evaluation
...
...
tensorflow_federated/python/learning/
framework/
debug_measurements.py
→
tensorflow_federated/python/learning/debug_measurements.py
浏览文件 @
ddf8ecae
...
...
@@ -17,6 +17,8 @@ import collections
import
tensorflow
as
tf
from
tensorflow_federated.python.aggregators
import
factory
from
tensorflow_federated.python.aggregators
import
measurements
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.impl.federated_context
import
intrinsics
from
tensorflow_federated.python.core.impl.types
import
placements
...
...
@@ -156,3 +158,58 @@ def build_aggregator_measurement_fns(weighted_aggregator: bool = True):
return
server_measurements
return
client_measurement_fn
,
server_measurement_fn
def
add_debug_measurements
(
aggregation_factory
:
factory
.
AggregationFactory
):
"""Adds measurements suitable for debugging learning processes.
This will wrap a `tff.aggregator.AggregationFactory` as a new factory that
will produce additional measurements useful for debugging learning processes.
The underlying aggregation of client values will remain unchanged.
These measurements generally concern the norm of the client updates, and the
norm of the aggregated server update. The implicit weighting will be
determined by `aggregation_factory`: If this is weighted, then the debugging
measurements will use this weighting when computing averages. If it is
unweighted, the debugging measurements will use uniform weighting.
The client measurements are:
* The average Euclidean norm of client updates.
* The standard deviation of these norms.
The standard deviation we report is the square root of the **unbiased**
variance. The server measurements are:
* The maximum entry of the aggregate client update.
* The Euclidean norm of the aggregate client update.
* The minimum entry of the aggregate client update.
In the above, an "entry" means any coordinate across all tensors in the
structure. For example, suppose that we have client structures before
aggregation:
* Client A: `[[-1, -3, -5], [2]]`
* Client B: `[[-1, -3, 1], [0]]`
If we use unweighted averaging, then the aggregate client update will be the
structure `[[-1, -3, -2], [1]]`. The maximum entry is `1`, the minimum entry
is `-3`, and the euclidean norm is `sqrt(15)`.
Args:
aggregation_factory: A `tff.aggregators.AggregationFactory`. Can be weighted
or unweighted.
Returns:
A `tff.aggregators.AggregationFactory`.
"""
is_weighted_aggregator
=
isinstance
(
aggregation_factory
,
factory
.
WeightedAggregationFactory
)
client_measurement_fn
,
server_measurement_fn
=
(
build_aggregator_measurement_fns
(
weighted_aggregator
=
is_weighted_aggregator
))
return
measurements
.
add_measurements
(
aggregation_factory
,
client_measurement_fn
=
client_measurement_fn
,
server_measurement_fn
=
server_measurement_fn
)
tensorflow_federated/python/learning/
framework/
debug_measurements_test.py
→
tensorflow_federated/python/learning/debug_measurements_test.py
浏览文件 @
ddf8ecae
...
...
@@ -17,10 +17,11 @@ import collections
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow_federated.python.aggregators
import
mean
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.backends.native
import
execution_contexts
from
tensorflow_federated.python.core.impl.types
import
computation_types
from
tensorflow_federated.python.learning
.framework
import
debug_measurements
from
tensorflow_federated.python.learning
import
debug_measurements
TensorType
=
computation_types
.
TensorType
FloatType
=
TensorType
(
tf
.
float32
)
...
...
@@ -216,6 +217,106 @@ class DebugMeasurementsTest(tf.test.TestCase, parameterized.TestCase):
average_client_norm
=
expected_norm
,
std_dev_client_norm
=
unbiased_std_dev
)
self
.
assertAllClose
(
actual_client_statistics
,
expected_client_statistics
)
def
test_add_measurements_to_weighted_aggregation_factory_types
(
self
):
mean_factory
=
mean
.
MeanFactory
()
debug_mean_factory
=
debug_measurements
.
add_debug_measurements
(
mean_factory
)
value_type
=
computation_types
.
TensorType
(
tf
.
float32
)
mean_aggregator
=
mean_factory
.
create
(
value_type
,
value_type
)
debug_aggregator
=
debug_mean_factory
.
create
(
value_type
,
value_type
)
self
.
assertTrue
(
debug_aggregator
.
is_weighted
)
self
.
assertEqual
(
mean_aggregator
.
initialize
.
type_signature
,
debug_aggregator
.
initialize
.
type_signature
)
self
.
assertEqual
(
mean_aggregator
.
next
.
type_signature
.
parameter
,
debug_aggregator
.
next
.
type_signature
.
parameter
)
self
.
assertEqual
(
mean_aggregator
.
next
.
type_signature
.
result
.
state
,
debug_aggregator
.
next
.
type_signature
.
result
.
state
)
self
.
assertEqual
(
mean_aggregator
.
next
.
type_signature
.
result
.
result
,
debug_aggregator
.
next
.
type_signature
.
result
.
result
)
def
test_add_measurements_to_weighted_aggregation_factory_output
(
self
):
mean_factory
=
mean
.
MeanFactory
()
debug_mean_factory
=
debug_measurements
.
add_debug_measurements
(
mean_factory
)
value_type
=
computation_types
.
TensorType
(
tf
.
float32
)
mean_aggregator
=
mean_factory
.
create
(
value_type
,
value_type
)
debug_aggregator
=
debug_mean_factory
.
create
(
value_type
,
value_type
)
state
=
mean_aggregator
.
initialize
()
mean_output
=
mean_aggregator
.
next
(
state
,
[
2.0
,
4.0
],
[
1.0
,
1.0
])
debug_output
=
debug_aggregator
.
next
(
state
,
[
2.0
,
4.0
],
[
1.0
,
1.0
])
self
.
assertEqual
(
mean_output
.
state
,
debug_output
.
state
)
self
.
assertNear
(
mean_output
.
result
,
debug_output
.
result
,
err
=
1e-6
)
mean_measurements
=
mean_output
.
measurements
expected_debugging_measurements
=
{
'average_client_norm'
:
3.0
,
'std_dev_client_norm'
:
tf
.
math
.
sqrt
(
2.0
),
'server_update_max'
:
3.0
,
'server_update_norm'
:
3.0
,
'server_update_min'
:
3.0
,
}
debugging_measurements
=
debug_output
.
measurements
self
.
assertCountEqual
(
list
(
debugging_measurements
.
keys
()),
list
(
mean_measurements
.
keys
())
+
list
(
expected_debugging_measurements
.
keys
()))
for
k
in
mean_output
.
measurements
:
self
.
assertEqual
(
mean_measurements
[
k
],
debugging_measurements
[
k
])
for
k
in
expected_debugging_measurements
:
self
.
assertNear
(
debugging_measurements
[
k
],
expected_debugging_measurements
[
k
],
err
=
1e-6
)
def
test_add_measurements_to_unweighted_aggregation_factory_types
(
self
):
mean_factory
=
mean
.
UnweightedMeanFactory
()
debug_mean_factory
=
debug_measurements
.
add_debug_measurements
(
mean_factory
)
value_type
=
computation_types
.
TensorType
(
tf
.
float32
)
mean_aggregator
=
mean_factory
.
create
(
value_type
)
debug_aggregator
=
debug_mean_factory
.
create
(
value_type
)
self
.
assertFalse
(
debug_aggregator
.
is_weighted
)
self
.
assertEqual
(
mean_aggregator
.
initialize
.
type_signature
,
debug_aggregator
.
initialize
.
type_signature
)
self
.
assertEqual
(
mean_aggregator
.
next
.
type_signature
.
parameter
,
debug_aggregator
.
next
.
type_signature
.
parameter
)
self
.
assertEqual
(
mean_aggregator
.
next
.
type_signature
.
result
.
state
,
debug_aggregator
.
next
.
type_signature
.
result
.
state
)
self
.
assertEqual
(
mean_aggregator
.
next
.
type_signature
.
result
.
result
,
debug_aggregator
.
next
.
type_signature
.
result
.
result
)
def
test_add_measurements_to_unweighted_aggregation_factory_output
(
self
):
mean_factory
=
mean
.
UnweightedMeanFactory
()
debug_mean_factory
=
debug_measurements
.
add_debug_measurements
(
mean_factory
)
value_type
=
computation_types
.
TensorType
(
tf
.
float32
)
mean_aggregator
=
mean_factory
.
create
(
value_type
)
debug_aggregator
=
debug_mean_factory
.
create
(
value_type
)
state
=
mean_aggregator
.
initialize
()
mean_output
=
mean_aggregator
.
next
(
state
,
[
2.0
,
4.0
])
debug_output
=
debug_aggregator
.
next
(
state
,
[
2.0
,
4.0
])
self
.
assertEqual
(
mean_output
.
state
,
debug_output
.
state
)
self
.
assertNear
(
mean_output
.
result
,
debug_output
.
result
,
err
=
1e-6
)
mean_measurements
=
mean_output
.
measurements
expected_debugging_measurements
=
{
'average_client_norm'
:
3.0
,
'std_dev_client_norm'
:
tf
.
math
.
sqrt
(
2.0
),
'server_update_max'
:
3.0
,
'server_update_norm'
:
3.0
,
'server_update_min'
:
3.0
,
}
debugging_measurements
=
debug_output
.
measurements
self
.
assertCountEqual
(
list
(
debugging_measurements
.
keys
()),
list
(
mean_measurements
.
keys
())
+
list
(
expected_debugging_measurements
.
keys
()))
for
k
in
mean_output
.
measurements
:
self
.
assertEqual
(
mean_measurements
[
k
],
debugging_measurements
[
k
])
for
k
in
expected_debugging_measurements
:
self
.
assertNear
(
debugging_measurements
[
k
],
expected_debugging_measurements
[
k
],
err
=
1e-6
)
if
__name__
==
'__main__'
:
execution_contexts
.
set_local_execution_context
()
...
...
tensorflow_federated/python/learning/framework/BUILD
浏览文件 @
ddf8ecae
...
...
@@ -138,30 +138,6 @@ py_cpu_gpu_test(
deps
=
[
":dataset_reduce"
],
)
py_library
(
name
=
"debug_measurements"
,
srcs
=
[
"debug_measurements.py"
],
srcs_version
=
"PY3"
,
deps
=
[
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/impl/federated_context:intrinsics"
,
"//tensorflow_federated/python/core/impl/types:placements"
,
],
)
py_test
(
name
=
"debug_measurements_test"
,
srcs
=
[
"debug_measurements_test.py"
],
python_version
=
"PY3"
,
srcs_version
=
"PY3"
,
deps
=
[
":debug_measurements"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/backends/native:execution_contexts"
,
"//tensorflow_federated/python/core/impl/types:computation_types"
,
],
)
py_library
(
name
=
"distributors"
,
srcs
=
[
"distributors.py"
],
...
...
编辑
预览
支持
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录