提交 826b66f3 编辑于 作者: Galen Andrew's avatar Galen Andrew 提交者: tensorflow-copybara
浏览文件

Adds functions for building useful compositions of aggregators. Also removes...

Adds functions for building useful compositions of aggregators. Also removes ZeroingClippingFactory because it was determined that composing Zeroing and Clipping is superior (since it allows independent estimators for the norms).

PiperOrigin-RevId: 341517003
上级 8d2f989f
......@@ -27,6 +27,35 @@ py_library(
],
)
# Library of recommended compositions of zeroing/clipping aggregators.
py_library(
    name = "clipping_compositions",
    srcs = ["clipping_compositions.py"],
    srcs_version = "PY3",
    deps = [
        ":clipping_factory",
        ":factory",
        ":mean_factory",
        ":quantile_estimation",
        "//tensorflow_federated/python/core/api:computation_types",
        "//tensorflow_federated/python/core/api:computations",
        "//tensorflow_federated/python/core/api:intrinsics",
    ],
)
# Tests for clipping_compositions; runs under the native execution context.
py_test(
    name = "clipping_compositions_test",
    srcs = ["clipping_compositions_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":clipping_compositions",
        ":factory",
        "//tensorflow_federated/python/core/api:computation_types",
        "//tensorflow_federated/python/core/api:test_case",
        "//tensorflow_federated/python/core/backends/native:execution_contexts",
    ],
)
py_library(
name = "clipping_factory",
srcs = ["clipping_factory.py"],
......@@ -36,7 +65,6 @@ py_library(
":sum_factory",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
......
......@@ -14,7 +14,6 @@
"""Libraries for constructing federated aggregation."""
from tensorflow_federated.python.aggregators.clipping_factory import ClippingFactory
from tensorflow_federated.python.aggregators.clipping_factory import ZeroingClippingFactory
from tensorflow_federated.python.aggregators.clipping_factory import ZeroingFactory
from tensorflow_federated.python.aggregators.dp_factory import DifferentiallyPrivateFactory
from tensorflow_federated.python.aggregators.factory import AggregationProcessFactory
......
# Copyright 2020, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions providing recommended aggregator compositions."""
import tensorflow as tf
import tensorflow_privacy
from tensorflow_federated.python.aggregators import clipping_factory
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import mean_factory
from tensorflow_federated.python.aggregators import quantile_estimation
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
def _affine_transform(multiplier, increment):
  """Returns a federated computation mapping a server float32 `v` to `multiplier * v + increment`."""
  scale_and_shift = computations.tf_computation(
      lambda value: multiplier * value + increment, tf.float32)

  def _apply_at_server(value):
    return intrinsics.federated_map(scale_and_shift, value)

  return computations.federated_computation(
      _apply_at_server, computation_types.at_server(tf.float32))
def _make_quantile_estimation_process(initial_estimate: float,
                                      target_quantile: float,
                                      learning_rate: float):
  """Builds a geometric-update quantile estimation process with no added noise."""
  quantile_query = tensorflow_privacy.NoPrivacyQuantileEstimatorQuery(
      initial_estimate=initial_estimate,
      target_quantile=target_quantile,
      learning_rate=learning_rate,
      geometric_update=True)
  return quantile_estimation.PrivateQuantileEstimationProcess(quantile_query)
def adaptive_zeroing_mean(
    initial_quantile_estimate: float,
    target_quantile: float,
    multiplier: float,
    increment: float,
    learning_rate: float,
    norm_order: float,
    no_nan_mean: bool = False) -> factory.AggregationProcessFactory:
  """Creates a factory for mean with adaptive zeroing.

  Estimates value at quantile `Z` of value norm distribution and zeroes out
  values whose norm is greater than `rZ + i` for multiplier `r` and increment
  `i`. The quantile `Z` is estimated using the geometric method described in
  Thakkar et al. 2019, "Differentially Private Learning with Adaptive Clipping"
  (https://arxiv.org/abs/1905.03871) without noise added (so not differentially
  private).

  Args:
    initial_quantile_estimate: The initial estimate of the target quantile `Z`.
    target_quantile: Which quantile to match, as a float in [0, 1]. For example,
      0.5 for median, or 0.98 to zero out only the largest 2% of updates (if
      multiplier=1 and increment=0).
    multiplier: Factor `r` in zeroing norm formula `rZ + i`.
    increment: Increment `i` in zeroing norm formula `rZ + i`.
    learning_rate: Learning rate for quantile matching algorithm.
    norm_order: A float for the order of the norm. Must be 1, 2, or np.inf.
    no_nan_mean: A bool. If True, the computed mean is 0 if sum of weights is
      equal to 0.

  Returns:
    A factory that performs mean after adaptive zeroing.
  """
  # NOTE: annotation fixed from `bool` to `float` for `norm_order`; the
  # docstring (and ZeroingFactory) expect a norm order of 1, 2, or np.inf.
  zeroing_quantile = _make_quantile_estimation_process(
      initial_estimate=initial_quantile_estimate,
      target_quantile=target_quantile,
      learning_rate=learning_rate)
  # The zeroing threshold is the affine transform `rZ + i` of the estimate.
  zeroing_norm = zeroing_quantile.map(_affine_transform(multiplier, increment))
  mean = mean_factory.MeanFactory(no_nan_division=no_nan_mean)
  return clipping_factory.ZeroingFactory(zeroing_norm, mean, norm_order)
def adaptive_zeroing_clipping_mean(
    initial_zeroing_quantile_estimate: float,
    target_zeroing_quantile: float,
    zeroing_multiplier: float,
    zeroing_increment: float,
    zeroing_learning_rate: float,
    zeroing_norm_order: float,
    initial_clipping_quantile_estimate: float,
    target_clipping_quantile: float,
    clipping_learning_rate: float,
    no_nan_mean: bool = False) -> factory.AggregationProcessFactory:
  """Makes a factory for mean with adaptive zeroing and clipping.

  Estimates value at quantile `Z` of value norm distribution and zeroes out
  values whose norm is greater than `rZ + i` for multiplier `r` and increment
  `i`. Also estimates value at quantile `C` and clips values whose L2 norm is
  greater than `C` (without any multiplier or increment). The quantiles are
  estimated using the geometric method described in Thakkar et al. 2019,
  "Differentially Private Learning with Adaptive Clipping"
  (https://arxiv.org/abs/1905.03871) without noise added (so not differentially
  private). Zeroing occurs before clipping, so the estimation process for `C`
  uses already zeroed values.

  Note while the zeroing_norm_order may be 1.0 or np.inf, only L2 norm is used
  for clipping.

  Args:
    initial_zeroing_quantile_estimate: The initial estimate of the target
      quantile `Z` for zeroing.
    target_zeroing_quantile: Which quantile to match for zeroing, as a float in
      [0, 1]. For example, 0.5 for median, or 0.98 to zero out only the largest
      2% of updates (if multiplier=1 and increment=0).
    zeroing_multiplier: Factor `r` in zeroing norm formula `rZ + i`.
    zeroing_increment: Increment `i` in zeroing norm formula `rZ + i`.
    zeroing_learning_rate: Learning rate for zeroing quantile estimate.
    zeroing_norm_order: A float for the order of the norm for zeroing. Must be
      1, 2, or np.inf.
    initial_clipping_quantile_estimate: The initial estimate of the target
      quantile `C` for clipping. (Multiplier and increment are not used for
      clipping.)
    target_clipping_quantile: Which quantile to match for clipping, as a float
      in [0, 1].
    clipping_learning_rate: Learning rate for clipping quantile estimate.
    no_nan_mean: A bool. If True, the computed mean is 0 if sum of weights is
      equal to 0.

  Returns:
    A factory that performs mean after adaptive zeroing and clipping.
  """
  # NOTE: annotation fixed from `bool` to `float` for `zeroing_norm_order`;
  # the docstring requires a norm order of 1, 2, or np.inf.
  zeroing_quantile = _make_quantile_estimation_process(
      initial_estimate=initial_zeroing_quantile_estimate,
      target_quantile=target_zeroing_quantile,
      learning_rate=zeroing_learning_rate)
  zeroing_norm = zeroing_quantile.map(
      _affine_transform(zeroing_multiplier, zeroing_increment))

  # The clipping quantile estimate is used directly (no affine transform).
  clipping_norm = _make_quantile_estimation_process(
      initial_estimate=initial_clipping_quantile_estimate,
      target_quantile=target_clipping_quantile,
      learning_rate=clipping_learning_rate)

  mean = mean_factory.MeanFactory(no_nan_division=no_nan_mean)
  # Zeroing wraps clipping so that zeroing happens first; the clipping norm
  # estimator therefore only sees already-zeroed values.
  clip = clipping_factory.ClippingFactory(clipping_norm, mean)
  return clipping_factory.ZeroingFactory(zeroing_norm, clip, zeroing_norm_order)
# Copyright 2020, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory compositions."""
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.aggregators import clipping_compositions as compositions
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.native import execution_contexts
class ClippingCompositionsTest(test_case.TestCase):
  """End-to-end tests for the composed zeroing/clipping mean aggregators."""

  def test_adaptive_zeroing_mean(self):
    # norm_order=np.inf: zeroing uses the L-inf norm of each client value.
    factory_ = compositions.adaptive_zeroing_mean(
        initial_quantile_estimate=1.0,
        target_quantile=0.5,
        multiplier=2.0,
        increment=1.0,
        learning_rate=np.log(4.0),
        norm_order=np.inf)
    self.assertIsInstance(factory_, factory.AggregationProcessFactory)
    process = factory_.create(computation_types.to_type(tf.float32))
    state = process.initialize()

    # Quantile estimate is 1.0, zeroing norm is 3.0.
    client_data = [1.5, 3.5]
    client_weight = [1.0, 1.0]
    # 3.5 exceeds the zeroing norm 3.0, so only 1.5 survives the mean.
    output = process.next(state, client_data, client_weight)
    self.assertAllClose(1.5 / 2.0, output.result)
    self.assertAllClose(3.0, output.measurements['zeroing_norm'])
    self.assertAllClose(1.0, output.measurements['zeroed_count'])

    # New quantile estimate is 1 * exp(0.5 ln(4)) = 2, zeroing norm is 5.0.
    # Both values now survive: (1.5 + 3.5) / 2 = 5.0 / 2.0.
    output = process.next(output.state, client_data, client_weight)
    self.assertAllClose(5.0 / 2.0, output.result)
    self.assertAllClose(5.0, output.measurements['zeroing_norm'])
    self.assertAllClose(0.0, output.measurements['zeroed_count'])

  def test_adaptive_zeroing_clipping_mean(self):
    factory_ = compositions.adaptive_zeroing_clipping_mean(
        initial_zeroing_quantile_estimate=1.0,
        target_zeroing_quantile=0.5,
        zeroing_multiplier=2.0,
        zeroing_increment=2.0,
        zeroing_learning_rate=np.log(4.0),
        zeroing_norm_order=np.inf,
        initial_clipping_quantile_estimate=2.0,
        target_clipping_quantile=0.0,
        clipping_learning_rate=np.log(4.0))
    self.assertIsInstance(factory_, factory.AggregationProcessFactory)
    process = factory_.create(computation_types.to_type(tf.float32))
    state = process.initialize()

    client_data = [3.0, 4.5]
    client_weight = [1.0, 1.0]
    # Zero quantile: 1.0, zero norm: 4.0, clip quantile (norm): 2.0.
    # 4.5 is zeroed; 3.0 is clipped to 2.0; mean is 2.0 / 2.
    output = process.next(state, client_data, client_weight)
    self.assertAllClose(2.0 / 2.0, output.result)
    self.assertAllClose(4.0, output.measurements['zeroing_norm'])
    self.assertAllClose(1.0, output.measurements['zeroed_count'])
    # Clipping measurements are nested under the inner aggregation process.
    clip_measurements = output.measurements['agg_process']
    self.assertAllClose(2.0, clip_measurements['clipping_norm'])
    self.assertAllClose(1.0, clip_measurements['clipped_count'])

    # New zero quantile: 1 * exp(0.5 ln(4)) = 2
    # New zero norm is 6.0
    # New clip quantile (norm) is 2 * exp(-0.5 ln(4)) = 1
    # Now both values are clipped to 1.0; mean is (1 + 1) / 2.
    output = process.next(output.state, client_data, client_weight)
    self.assertAllClose(2.0 / 2.0, output.result)
    self.assertAllClose(6.0, output.measurements['zeroing_norm'])
    self.assertAllClose(0.0, output.measurements['zeroed_count'])
    clip_measurements = output.measurements['agg_process']
    self.assertAllClose(1.0, clip_measurements['clipping_norm'])
    self.assertAllClose(2.0, clip_measurements['clipped_count'])
if __name__ == '__main__':
  # Federated computations in these tests need a local execution context.
  execution_contexts.set_local_execution_context()
  test_case.main()
......@@ -23,7 +23,6 @@ from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import sum_factory
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
......@@ -44,7 +43,7 @@ def _constant_process(value):
lambda state, value: state, init_fn.type_signature.result,
computation_types.at_clients(NORM_TF_TYPE))
report_fn = computations.federated_computation(
lambda s: intrinsics.federated_value(value, placements.SERVER),
lambda state: intrinsics.federated_value(value, placements.SERVER),
init_fn.type_signature.result)
return estimation_process.EstimationProcess(init_fn, next_fn, report_fn)
......@@ -329,166 +328,6 @@ class ZeroingFactory(factory.AggregationProcessFactory):
return aggregation_process.AggregationProcess(init_fn, next_fn)
class ZeroingClippingFactory(factory.AggregationProcessFactory):
  """`AggregationProcessFactory` for zeroing and clipping large values.

  The created `tff.templates.AggregationProcess` zeroes out any values whose
  norm is greater than a given value, and further projects the values onto
  an L2 ball (also referred to as "clipping") before aggregating the values as
  specified by `inner_agg_factory`. The clipping norm is determined by the
  `clipping_norm`, and the zeroing norm is computed as a function
  (`zeroing_norm_fn`) applied to the clipping norm.

  This is intended to be used when it is preferred (for privacy reasons perhaps)
  to use only a single estimation process. If it is acceptable to use multiple
  estimation processes it would be more flexible to compose a `ZeroingFactory`
  with a `ClippingFactory`. For example, a `ZeroingFactory` allows zeroing
  values with high L-inf norm, whereas this class supports only L2 norm.

  The provided `clipping_norm` can either be a constant (for fixed norm), or an
  instance of `tff.templates.EstimationProcess` (for adaptive norm). If it is an
  estimation process, the value returned by its `report` method will be used as
  the clipping norm. Its `next` method needs to accept a scalar float32 at
  clients, corresponding to the norm of value being aggregated. The process can
  thus adaptively determine the clipping norm based on the set of aggregated
  values. For example if a `tff.aggregators.PrivateQuantileEstimationProcess` is
  used, the clipping norm will be an estimate of a quantile of the norms of the
  values being aggregated.
  """

  def __init__(self, clipping_norm: Union[float,
                                          estimation_process.EstimationProcess],
               zeroing_norm_fn: computation_base.Computation,
               inner_agg_factory: factory.AggregationProcessFactory):
    """Initializes a ZeroingClippingFactory.

    Args:
      clipping_norm: Either a float (for fixed norm) or an `EstimationProcess`
        (for adaptive norm) that specifies the norm over which values should be
        clipped. If an `EstimationProcess` is passed, value norms will be passed
        to the process and its `report` function will be used as the clipping
        norm.
      zeroing_norm_fn: A `tff.Computation` to apply to the clipping norm to
        produce the zeroing norm.
      inner_agg_factory: A factory specifying the type of aggregation to be done
        after zeroing and clipping.

    Raises:
      TypeError: If the types of the arguments are incompatible, e.g. if the
        argument type of `zeroing_norm_fn` does not accept the result of
        `clipping_norm`.
    """
    py_typecheck.check_type(inner_agg_factory,
                            factory.AggregationProcessFactory)
    self._inner_agg_factory = inner_agg_factory

    py_typecheck.check_type(clipping_norm,
                            (float, estimation_process.EstimationProcess))
    # A constant norm is wrapped in a trivial estimation process so fixed and
    # adaptive norms can be treated uniformly below.
    if isinstance(clipping_norm, float):
      clipping_norm = _constant_process(clipping_norm)
    _check_norm_process(clipping_norm, 'clipping_norm')
    self._clipping_norm_process = clipping_norm

    py_typecheck.check_type(zeroing_norm_fn, computation_base.Computation)
    # Validate that zeroing_norm_fn can consume the clipping norm and produces
    # a value assignable to the norm type.
    zeroing_norm_arg_type = zeroing_norm_fn.type_signature.parameter
    norm_type = clipping_norm.report.type_signature.result.member
    if not zeroing_norm_arg_type.is_assignable_from(norm_type):
      raise TypeError(
          f'Argument of `zeroing_norm_fn` must be assignable from result of '
          f'`clipping_norm`, but `clipping_norm` outputs {norm_type}\n '
          f'and the argument of `zeroing_norm_fn` is {zeroing_norm_arg_type}.')

    zeroing_norm_result_type = zeroing_norm_fn.type_signature.result
    float_type = computation_types.to_type(NORM_TF_TYPE)
    if not float_type.is_assignable_from(zeroing_norm_result_type):
      raise TypeError(f'Result of `zeroing_norm_fn` must be assignable to '
                      f'NORM_TF_TYPE but found {zeroing_norm_result_type}.')
    self._zeroing_norm_fn = zeroing_norm_fn

    # The aggregation factories that will be used to count the number of zeroed
    # and clipped values at each iteration. For now we are just creating them
    # here, but soon we will make this customizable to allow DP measurements.
    self._clipped_count_agg_factory = sum_factory.SumFactory()
    self._zeroed_count_agg_factory = sum_factory.SumFactory()

  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Creates the zeroing-and-clipping aggregation process for `value_type`."""
    py_typecheck.check_type(value_type, factory.ValueType.__args__)

    # Norm-based zeroing/clipping is only meaningful for floating-point values.
    if not all([t.dtype.is_floating for t in structure.flatten(value_type)]):
      raise TypeError(f'All values in provided value_type must be of floating '
                      f'dtype. Provided value_type: {value_type}')

    inner_agg_process = self._inner_agg_factory.create(value_type)

    count_type = computation_types.to_type(COUNT_TF_TYPE)
    clipped_count_agg_process = self._clipped_count_agg_factory.create(
        count_type)
    zeroed_count_agg_process = self._zeroed_count_agg_factory.create(count_type)

    @computations.federated_computation()
    def init_fn():
      # State combines the norm estimator, inner aggregator, and both counters.
      return intrinsics.federated_zip(
          collections.OrderedDict(
              clipping_norm=self._clipping_norm_process.initialize(),
              inner_agg=inner_agg_process.initialize(),
              clipped_count_agg=clipped_count_agg_process.initialize(),
              zeroed_count_agg=zeroed_count_agg_process.initialize()))

    @computations.tf_computation(value_type, NORM_TF_TYPE, NORM_TF_TYPE)
    def clip_and_zero(value, clipping_norm, zeroing_norm):
      # Clip by global L2 norm first; tf.clip_by_global_norm also returns the
      # pre-clipping global norm, which feeds the norm estimation process.
      clipped_value_as_list, global_norm = tf.clip_by_global_norm(
          tf.nest.flatten(value), clipping_norm)
      clipped_value = tf.nest.pack_sequence_as(value, clipped_value_as_list)
      was_clipped = tf.cast((global_norm > clipping_norm), COUNT_TF_TYPE)
      # If the (pre-clipping) norm also exceeds the zeroing norm, drop the
      # value entirely by replacing it with zeros.
      should_zero = (global_norm > zeroing_norm)
      zeroed_and_clipped = tf.cond(
          should_zero, lambda: tf.nest.map_structure(tf.zeros_like, value),
          lambda: clipped_value)
      was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE)
      return zeroed_and_clipped, global_norm, was_clipped, was_zeroed

    @computations.federated_computation(
        init_fn.type_signature.result, computation_types.at_clients(value_type),
        computation_types.at_clients(tf.float32))
    def next_fn(state, value, weight):
      (clipping_norm_state, agg_state, clipped_count_state,
       zeroed_count_state) = state

      clipping_norm = self._clipping_norm_process.report(clipping_norm_state)
      # The zeroing norm is derived from the clipping norm at the server.
      zeroing_norm = intrinsics.federated_map(self._zeroing_norm_fn,
                                              clipping_norm)

      (zeroed_and_clipped, global_norm,
       was_clipped, was_zeroed) = intrinsics.federated_map(
           clip_and_zero, (value, intrinsics.federated_broadcast(clipping_norm),
                           intrinsics.federated_broadcast(zeroing_norm)))

      # Update the norm estimator with the observed client norms, then run the
      # inner aggregation and both counters on the processed values.
      new_clipping_norm_state = self._clipping_norm_process.next(
          clipping_norm_state, global_norm)
      agg_output = inner_agg_process.next(agg_state, zeroed_and_clipped, weight)
      clipped_count_output = clipped_count_agg_process.next(
          clipped_count_state, was_clipped)
      zeroed_count_output = zeroed_count_agg_process.next(
          zeroed_count_state, was_zeroed)

      new_state = collections.OrderedDict(
          clipping_norm=new_clipping_norm_state,
          inner_agg=agg_output.state,
          clipped_count_agg=clipped_count_output.state,
          zeroed_count_agg=zeroed_count_output.state)
      measurements = collections.OrderedDict(
          agg_process=agg_output.measurements,
          clipping_norm=clipping_norm,
          zeroing_norm=zeroing_norm,
          clipped_count=clipped_count_output.result,
          zeroed_count=zeroed_count_output.result)

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(new_state),
          result=agg_output.result,
          measurements=intrinsics.federated_zip(measurements))

    return aggregation_process.AggregationProcess(init_fn, next_fn)
def _global_inf_norm(l):
  """Returns the global L-infinity norm over a (possibly nested) structure of tensors."""
  per_tensor_maxes = [tf.reduce_max(tf.abs(t)) for t in tf.nest.flatten(l)]
  return tf.reduce_max(tf.stack(per_tensor_maxes))
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ClippingFactory, ZeroingFactory and ZeroingClippingFactory."""
"""Tests for ClippingFactory and ZeroingFactory."""
import collections
import itertools
......@@ -48,12 +48,6 @@ def _zero_cons(clip=2.0, norm_order=2.0):
norm_order)
def _zero_clip_cons(clip=2.0):
  # Test helper: ZeroingClippingFactory whose zeroing norm is clip norm + 3.
  zeroing_norm_fn = computations.tf_computation(lambda x: x + 3, tf.float32)
  return clipping_factory.ZeroingClippingFactory(clip, zeroing_norm_fn,
                                                 mean_factory.MeanFactory())
_float_at_server = computation_types.at_server(tf.float32)
_float_at_clients = computation_types.at_clients(tf.float32)
......@@ -162,57 +156,11 @@ class ClippingFactoryComputationTest(test_case.TestCase,
self.assertTrue(
process.next.type_signature.is_equivalent_to(expected_next_type))
@parameterized.named_parameters(
    ('float', tf.float32),
    ('struct', _test_struct_type),
)
def test_zero_clip_type_properties(self, value_type):
  # Checks the TFF type signatures of a ZeroingClippingFactory-built process.
  factory = _zero_clip_cons()
  value_type = computation_types.to_type(value_type)
  process = factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Server state: constant clipping-norm process state is empty, and the
  # inner mean keeps value/weight sum process state.
  mean_state_type = collections.OrderedDict(
      value_sum_process=(), weight_sum_process=())
  server_state_type = computation_types.at_server(
      collections.OrderedDict(
          clipping_norm=(),
          inner_agg=mean_state_type,
          clipped_count_agg=(),
          zeroed_count_agg=()))
  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=server_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_measurements_type = computation_types.at_server(
      collections.OrderedDict(
          agg_process=collections.OrderedDict(
              value_sum_process=(), weight_sum_process=()),
          clipping_norm=clipping_factory.NORM_TF_TYPE,
          zeroing_norm=clipping_factory.NORM_TF_TYPE,
          clipped_count=clipping_factory.COUNT_TF_TYPE,
          zeroed_count=clipping_factory.COUNT_TF_TYPE))
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_state_type,
          value=computation_types.at_clients(value_type),
          weight=computation_types.at_clients(tf.float32)),
      result=measured_process.MeasuredProcessOutput(
          state=server_state_type,
          result=computation_types.at_server(value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))
@parameterized.named_parameters(
('clip_float_on_clients', 1.0, placements.CLIENTS, _clip_cons),
('clip_string_on_server', 'bad', placements.SERVER, _clip_cons),
('zero_float_on_clients', 1.0, placements.CLIENTS, _zero_cons),
('zero_string_on_server', 'bad', placements.SERVER, _zero_cons),
('zero_clip_float_on_clients', 1.0, placements.CLIENTS, _zero_clip_cons),
('zero_clip_string_on_server', 'bad', placements.SERVER, _zero_clip_cons),
)
def test_raises_on_bad_norm_process_result(self, value, placement,
factory_cons):
......@@ -228,7 +176,6 @@ class ClippingFactoryComputationTest(test_case.TestCase,
@parameterized.named_parameters(
('clip', _clip_cons),
('zero', _zero_cons),