Commit f49e5267 authored by Galen Andrew, committed by tensorflow-copybara


Rewrite model_update_aggregator to use attrs for config classes, and to make factory creation closer to canonical example of how to compose aggregators.

PiperOrigin-RevId: 346679949
Parent 1b73ae9c
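For orientation, a minimal usage sketch of the reworked API (the config and function names are taken from the new code in this diff; the import path follows the test file below):

from tensorflow_federated.python.learning import model_update_aggregator as mua

# Default: mean with adaptive zeroing for robustness.
agg = mua.model_update_aggregator()

# Mean with adaptive zeroing plus differentially private noising.
dp_agg = mua.model_update_aggregator(
    clipping_and_noise=mua.DifferentialPrivacyConfig(
        noise_multiplier=1.0, clients_per_round=100.0))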
@@ -229,7 +229,6 @@ py_library(
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/aggregators:mean_factory",
"//tensorflow_federated/python/aggregators:quantile_estimation",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
@@ -243,12 +242,12 @@ py_test(
srcs_version = "PY3",
deps = [
":model_update_aggregator",
"//tensorflow_federated/python/aggregators:clipping_factory",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/aggregators:mean_factory",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/core/templates:estimation_process",
],
)
@@ -16,6 +16,7 @@
import math
from typing import Optional, Union
import attr
import tensorflow as tf
import tensorflow_privacy as tfp
@@ -24,20 +25,29 @@ from tensorflow_federated.python.aggregators import dp_factory
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import mean_factory
from tensorflow_federated.python.aggregators import quantile_estimation
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
def _check_positive(value, label):
AggregationFactory = Union[factory.WeightedAggregationFactory,
factory.UnweightedAggregationFactory]
def _check_positive(instance, attribute, value):
if value <= 0:
raise ValueError(f'{label} must be positive. Found {value}.')
raise ValueError(f'{attribute.name} must be positive. Found {value}.')
def _check_nonnegative(value, label):
def _check_nonnegative(instance, attribute, value):
if value < 0:
raise ValueError(f'{label} must be nonnegative. Found {value}.')
raise ValueError(f'{attribute.name} must be nonnegative. Found {value}.')
def _check_probability(instance, attribute, value):
if not 0 <= value <= 1:
raise ValueError(f'{attribute.name} must be between 0 and 1 (inclusive). '
f'Found {value}.')
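The three-argument signatures above follow the attrs validator protocol: attrs calls each validator as validator(instance, attribute, value) after any converter has run, and attribute.name supplies the field name for the error message. A minimal illustration (the _Example class is hypothetical, not part of this commit):

import attr

@attr.s(frozen=True)
class _Example:
  rate = attr.ib(converter=float, validator=_check_positive)

_Example(0.5)   # OK.
_Example(-1.0)  # ValueError: rate must be positive. Found -1.0.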
def _affine_transform(multiplier, increment):
@@ -48,213 +58,230 @@ def _affine_transform(multiplier, increment):
computation_types.at_server(tf.float32))
class QuantileEstimationConfig:
"""A config class for quantile estimation."""
def __init__(self, initial_estimate: float, target_quantile: float,
learning_rate: float):
"""Initializes a QuantileEstimationConfig.
Args:
initial_estimate: A float representing the initial quantile estimate.
target_quantile: A float in [0, 1] representing the quantile to match.
learning_rate: A float determining the learning rate for the process.
"""
py_typecheck.check_type(initial_estimate, float, 'initial_estimate')
_check_positive(initial_estimate, 'initial_estimate')
self._initial_estimate = initial_estimate
py_typecheck.check_type(target_quantile, float, 'target_quantile')
if not 0 <= target_quantile <= 1:
raise ValueError('target_quantile must be in the range [0, 1].')
self._target_quantile = target_quantile
py_typecheck.check_type(learning_rate, float, 'learning_rate')
_check_positive(learning_rate, 'learning_rate')
self._learning_rate = learning_rate
@property
def initial_estimate(self) -> float:
return self._initial_estimate
@property
def target_quantile(self) -> float:
return self._target_quantile
@property
def learning_rate(self) -> float:
return self._learning_rate
def to_quantile_estimation_process(
self) -> quantile_estimation.PrivateQuantileEstimationProcess:
return quantile_estimation.PrivateQuantileEstimationProcess(
tfp.NoPrivacyQuantileEstimatorQuery(
initial_estimate=self._initial_estimate,
target_quantile=self._target_quantile,
learning_rate=self._learning_rate,
geometric_update=True))
class ZeroingConfig:
"""Config for adaptive zeroing based on a quantile estimate."""
def __init__(self,
quantile: Optional[QuantileEstimationConfig] = None,
multiplier: float = 2.0,
increment: float = 1.0):
"""Initializes a ZeroingConfig.
Estimates value at quantile `Z` of value norm distribution and zeroes out
values whose norm is greater than `rZ + i` for multiplier `r` and increment
`i`. The quantile `Z` is estimated using the geometric method described in
Thakkar et al. 2019, "Differentially Private Learning with Adaptive
Clipping" (https://arxiv.org/abs/1905.03871) without noise added (so not
differentially private).
Args:
quantile: A `QuantileEstimationConfig` specifying the quantile estimation
process for zeroing. If None, defaults to a fast-adapting process that
zeroes only very high values.
multiplier: A float for factor `r` in zeroing norm formula `rZ + i`.
increment: A float for increment `i` in zeroing norm formula `rZ + i`.
"""
if quantile is None:
quantile = QuantileEstimationConfig(10.0, 0.98, math.log(10))
else:
py_typecheck.check_type(quantile, QuantileEstimationConfig, 'quantile')
self._quantile = quantile
py_typecheck.check_type(multiplier, float, 'multiplier')
_check_positive(multiplier, 'multiplier')
self._multiplier = multiplier
py_typecheck.check_type(increment, float, 'increment')
_check_nonnegative(increment, 'increment')
self._increment = increment
def to_factory(self, inner_factory) -> clipping_factory.ZeroingFactory:
zeroing_quantile = self._quantile.to_quantile_estimation_process()
zeroing_norm = zeroing_quantile.map(
_affine_transform(self._multiplier, self._increment))
return clipping_factory.ZeroingFactory(
zeroing_norm, inner_factory, norm_order=float('inf'))
class ClippingConfig:
"""Config for fixed or adaptive clipping with recommended defaults."""
def __init__(self,
clip: Optional[Union[float, QuantileEstimationConfig]] = None):
"""Initializes a ClippingConfig.
Args:
clip: Either a float representing the fixed clip norm, or a
QuantileEstimationConfig specifying the quantile estimation process for
adaptive clipping. If None, defaults to a quantile estimation process
that adapts reasonably fast and clips to a moderately high norm.
"""
if clip is None:
clip = QuantileEstimationConfig(1.0, 0.8, 0.2)
elif isinstance(clip, float):
_check_positive(clip, 'clip')
else:
py_typecheck.check_type(clip, QuantileEstimationConfig, 'clip')
self._clip = clip
@property
def clip(self) -> Union[float, QuantileEstimationConfig]:
return self._clip
@property
def is_fixed(self) -> bool:
return isinstance(self._clip, float)
def to_factory(self, inner_factory) -> clipping_factory.ClippingFactory:
if self.is_fixed:
return clipping_factory.ClippingFactory(self._clip, inner_factory)
else:
return clipping_factory.ClippingFactory(
self._clip.to_quantile_estimation_process(), inner_factory)
class DPConfig:
"""A config class for differential privacy with recommended defaults."""
def __init__(self,
noise_multiplier: float,
clients_per_round: float,
clipping: Optional[ClippingConfig] = None,
clipped_count_stddev: Optional[float] = None):
"""Initializes a DPConfig.
Args:
noise_multiplier: A float specifying the noise multiplier for the Gaussian
mechanism for model updates.
clients_per_round: A float specifying the expected number of clients per
round.
clipping: A ClippingConfig specifying the clipping to use. If None,
adaptive clipping with default parameters will be used.
clipped_count_stddev: A float specifying the stddev for clipped counts. If
None, defaults to 0.05 times `clients_per_round`.
"""
py_typecheck.check_type(noise_multiplier, float, 'noise_multiplier')
_check_nonnegative(noise_multiplier, 'noise_multiplier')
self._noise_multiplier = noise_multiplier
py_typecheck.check_type(clients_per_round, float, 'clients_per_round')
_check_positive(clients_per_round, 'clients_per_round')
self._clients_per_round = clients_per_round
if clipping is None:
clipping = ClippingConfig(QuantileEstimationConfig(1e-1, 0.5, 0.2))
else:
py_typecheck.check_type(clipping, ClippingConfig, 'clipping')
self._clipping = clipping
if clipped_count_stddev is None:
# Default to 0.05 * clients_per_round. This way the noised fraction
# of unclipped updates will be within 0.1 of the true fraction with
# 95.4% probability, and will be within 0.15 of the true fraction with
# 99.7% probability. Even in this unlikely case, the error on the update
# would be a factor of exp(0.15) = 1.16, not a huge deviation. So this
# default gives maximal privacy for acceptable probability of deviation.
clipped_count_stddev = 0.05 * clients_per_round
py_typecheck.check_type(clipped_count_stddev, float, 'clipped_count_stddev')
_check_nonnegative(clipped_count_stddev, 'clipped_count_stddev')
self._clipped_count_stddev = clipped_count_stddev
def to_factory(self) -> dp_factory.DifferentiallyPrivateFactory:
"""Creates factory based on config settings."""
if self._clipping.is_fixed:
stddev = self._clipping.clip * self._noise_multiplier
query = tfp.GaussianAverageQuery(
l2_norm_clip=self._clipping.clip,
sum_stddev=stddev,
denominator=self._clients_per_round)
else:
query = tfp.QuantileAdaptiveClipAverageQuery(
initial_l2_norm_clip=self._clipping.clip.initial_estimate,
noise_multiplier=self._noise_multiplier,
denominator=self._clients_per_round,
target_unclipped_quantile=self._clipping.clip.target_quantile,
learning_rate=self._clipping.clip.learning_rate,
clipped_count_stddev=self._clipped_count_stddev,
expected_num_records=self._clients_per_round,
geometric_update=True)
return dp_factory.DifferentiallyPrivateFactory(query)
@attr.s(frozen=True, kw_only=True)
class AdaptiveZeroingConfig:
"""Config for adaptive zeroing based on a quantile estimate.
Estimates value at quantile `Z` of value norm distribution and zeroes out
values whose norm is greater than `rZ + i` for multiplier `r` and increment
`i`. The quantile `Z` is estimated using the geometric method described in
Thakkar et al. 2019, "Differentially Private Learning with Adaptive
Clipping" (https://arxiv.org/abs/1905.03871) without noise added (so not
differentially private).
Default values are recommended for adaptive zeroing to mitigate data
corruption.
Attributes:
initial_quantile_estimate: The initial estimate of `Z`.
target_quantile: The quantile to which `Z` will be adapted.
learning_rate: The learning rate for the adaptive algorithm.
multiplier: The multiplier `r` to determine the zeroing norm.
increment: The increment `i` to determine the zeroing norm.
"""
initial_quantile_estimate: float = attr.ib(
default=10.0,
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
target_quantile: float = attr.ib(
default=0.98,
validator=[attr.validators.instance_of(float), _check_probability],
converter=float)
learning_rate: float = attr.ib(
default=math.log(10),
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
multiplier: float = attr.ib(
default=2.0,
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
increment: float = attr.ib(
default=1.0,
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
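With the defaults above, the initial zeroing norm is rZ + i = 2.0 * 10.0 + 1.0 = 21.0, and Z then tracks the 0.98 quantile of the update-norm distribution. A construction sketch:

# Override a single field; the rest keep the recommended defaults.
zeroing = AdaptiveZeroingConfig(multiplier=3.0)
assert zeroing.increment == 1.0
# frozen=True makes instances immutable: assigning zeroing.multiplier = 4.0
# would raise attr.exceptions.FrozenInstanceError.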
def _build_quantile_estimation_process(initial_estimate, target_quantile,
learning_rate):
return quantile_estimation.PrivateQuantileEstimationProcess(
tfp.NoPrivacyQuantileEstimatorQuery(
initial_estimate=initial_estimate,
target_quantile=target_quantile,
learning_rate=learning_rate,
geometric_update=True))
def _apply_zeroing(config: AdaptiveZeroingConfig,
inner_factory: AggregationFactory) -> AggregationFactory:
"""Applies zeroing to `inner_factory` according to `config`."""
zeroing_quantile = _build_quantile_estimation_process(
config.initial_quantile_estimate, config.target_quantile,
config.learning_rate)
zeroing_norm = zeroing_quantile.map(
_affine_transform(config.multiplier, config.increment))
return clipping_factory.ZeroingFactory(
zeroing_norm, inner_factory, norm_order=float('inf'))
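A usage sketch of the helper above; norm_order=float('inf') means zeroing decisions use the max (infinity) norm of each update:

robust_mean = _apply_zeroing(AdaptiveZeroingConfig(),
                             mean_factory.MeanFactory())
# robust_mean is a clipping_factory.ZeroingFactory whose zeroing norm is the
# quantile-estimation process mapped through rZ + i.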
@attr.s(frozen=True)
class FixedClippingConfig:
"""Config for clipping to a fixed value.
Attributes:
clip: The fixed clipping norm.
"""
clip: float = attr.ib(
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
@attr.s(frozen=True, kw_only=True)
class AdaptiveClippingConfig:
"""Config for adaptive clipping based on a quantile estimate.
Estimates value at quantile `C` of value norm distribution and clips
values whose norm is greater than `C`. The quantile is estimated using the
geometric method described in Thakkar et al. 2019, "Differentially Private
Learning with Adaptive Clipping" (https://arxiv.org/abs/1905.03871) without
noise added (so not differentially private).
Default values are recommended for robust adaptive clipping.
Attributes:
initial_clip: The initial estimate of `C`.
target_quantile: The quantile to which `C` will be adapted.
learning_rate: The learning rate for the adaptive algorithm.
"""
initial_clip: float = attr.ib(
default=1.0,
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
target_quantile: float = attr.ib(
default=0.8,
validator=[attr.validators.instance_of(float), _check_probability],
converter=float)
learning_rate: float = attr.ib(
default=0.2,
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
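The geometric update referenced in these docstrings adapts the estimate multiplicatively. A rough sketch of one step, following Thakkar et al. 2019 (hedged; see the paper for the exact noised form):

# eta = learning_rate, gamma = target_quantile,
# b = fraction of recent values at or below the current estimate:
#   estimate <- estimate * exp(-eta * (b - gamma))
# If too few values fall below the estimate (b < gamma), the factor exceeds 1
# and the estimate grows; otherwise it shrinks.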
ClippingConfig = Union[FixedClippingConfig, AdaptiveClippingConfig]
def _apply_clipping(config: ClippingConfig,
inner_factory: AggregationFactory) -> AggregationFactory:
"""Applies clipping to `inner_factory` according to `config`."""
if isinstance(config, FixedClippingConfig):
return clipping_factory.ClippingFactory(config.clip, inner_factory)
elif isinstance(config, AdaptiveClippingConfig):
clipping_quantile = _build_quantile_estimation_process(
config.initial_clip, config.target_quantile, config.learning_rate)
return clipping_factory.ClippingFactory(clipping_quantile, inner_factory)
else:
raise TypeError(f'config is not a supported type of ClippingConfig. Found '
f'type {type(config)}.')
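A quick sketch of the dispatch above, with MeanFactory as the inner factory (as in model_update_aggregator below):

inner = mean_factory.MeanFactory()
fixed = _apply_clipping(FixedClippingConfig(3.0), inner)
adaptive = _apply_clipping(AdaptiveClippingConfig(), inner)
# Both return a clipping_factory.ClippingFactory wrapping `inner`; only the
# adaptive variant carries a quantile-estimation process for the clip norm.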
@attr.s(frozen=True, kw_only=True)
class DifferentialPrivacyConfig:
"""A config class for differential privacy with recommended defaults.
Attributes:
noise_multiplier: The ratio of the noise standard deviation to the clip
norm.
clients_per_round: The number of clients per round.
clipping: A FixedClippingConfig or AdaptiveClippingConfig specifying the
type of clipping. Defaults to an adaptive clip process that starts small
and adapts moderately quickly to the median.
clipped_count_stddev: The standard deviation of the clipped count estimate,
for private adaptation of the clipping norm. If unspecified, defaults to a
value that gives maximal privacy without disrupting the adaptive clipping
norm process too greatly.
"""
noise_multiplier: float = attr.ib(
validator=[attr.validators.instance_of(float), _check_nonnegative],
converter=float)
clients_per_round: float = attr.ib(
validator=[attr.validators.instance_of(float), _check_positive],
converter=float)
clipping: ClippingConfig = attr.ib(
default=AdaptiveClippingConfig(
initial_clip=1e-1, target_quantile=0.5, learning_rate=0.2),
validator=attr.validators.instance_of(
(FixedClippingConfig, AdaptiveClippingConfig)))
clipped_count_stddev: float = attr.ib(
validator=[attr.validators.instance_of(float), _check_nonnegative],
converter=float)
@clipped_count_stddev.default
def _set_default_clipped_count_stddev(self):
# Default to 0.05 * clients_per_round. This way the noised fraction
# of unclipped updates will be within 0.1 of the true fraction with
# 95.4% probability, and will be within 0.15 of the true fraction with
# 99.7% probability. Even in this unlikely case, the error on the update
# would be a factor of exp(0.15) = 1.16, not a huge deviation. So this
# default gives maximal privacy for acceptable probability of deviation.
return 0.05 * self.clients_per_round
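The probabilities in the comment above are the standard two- and three-sigma Gaussian bounds, since the noised count has stddev 0.05 * clients_per_round:

# stddev(noised count)    = 0.05 * clients_per_round
# stddev(noised fraction) = 0.05
# 2 * 0.05 = 0.10  -> within 0.10 of the true fraction w.p. ~95.4%
# 3 * 0.05 = 0.15  -> within 0.15 of the true fraction w.p. ~99.7%
# exp(0.15) ~= 1.16, the worst-case multiplicative error on the clip update.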
def _dp_factory(
config: DifferentialPrivacyConfig
) -> dp_factory.DifferentiallyPrivateFactory:
"""Creates DifferentiallyPrivateFactory based on config settings."""
if isinstance(config.clipping, FixedClippingConfig):
stddev = config.clipping.clip * config.noise_multiplier
query = tfp.GaussianAverageQuery(
l2_norm_clip=config.clipping.clip,
sum_stddev=stddev,
denominator=config.clients_per_round)
elif isinstance(config.clipping, AdaptiveClippingConfig):
query = tfp.QuantileAdaptiveClipAverageQuery(
initial_l2_norm_clip=config.clipping.initial_clip,
noise_multiplier=config.noise_multiplier,
denominator=config.clients_per_round,
target_unclipped_quantile=config.clipping.target_quantile,
learning_rate=config.clipping.learning_rate,
clipped_count_stddev=config.clipped_count_stddev,
expected_num_records=config.clients_per_round,
geometric_update=True)
else:
raise TypeError(
f'config.clipping is not a supported type of ClippingConfig. Found '
f'type {type(config.clipping)}.')
return dp_factory.DifferentiallyPrivateFactory(query)
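A hedged construction sketch exercising both branches above:

# Fixed clipping: Gaussian noise with stddev = clip * noise_multiplier.
dp_fixed = _dp_factory(DifferentialPrivacyConfig(
    noise_multiplier=1.0,
    clients_per_round=100.0,
    clipping=FixedClippingConfig(1.0)))

# Adaptive clipping (the default); clipped_count_stddev defaults to
# 0.05 * 100.0 = 5.0 via the attrs default method above.
dp_adaptive = _dp_factory(DifferentialPrivacyConfig(
    noise_multiplier=1.0, clients_per_round=100.0))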
def model_update_aggregator(
zeroing: Optional[ZeroingConfig] = ZeroingConfig(),
clipping_and_noise: Optional[Union[ClippingConfig, DPConfig]] = None
) -> Union[factory.WeightedAggregationFactory,
factory.UnweightedAggregationFactory]:
"""Builds model update aggregator.
zeroing: Optional[AdaptiveZeroingConfig] = AdaptiveZeroingConfig(),
clipping_and_noise: Optional[Union[ClippingConfig,
DifferentialPrivacyConfig]] = None
) -> AggregationFactory:
"""Builds aggregator for model updates in FL according to configs.
The default aggregator (produced if no arguments are overridden) performs a
mean with adaptive zeroing for robustness. To turn off adaptive zeroing, set
`zeroing=None`. (Adaptive) clipping and/or differential privacy can
optionally be enabled by setting `clipping_and_noise`.
Args:
zeroing: A ZeroingConfig. If None, no zeroing will be performed.
clipping_and_noise: An optional ClippingConfig or DPConfig. If unspecified,
no clipping or noising will be performed.
clipping_and_noise: An optional ClippingConfig or DifferentialPrivacyConfig.
If unspecified, no clipping or noising will be performed.
Returns:
A `factory.WeightedAggregationFactory` intended for model update aggregation
@@ -263,10 +290,12 @@ def model_update_aggregator(
if not clipping_and_noise:
factory_ = mean_factory.MeanFactory()
elif isinstance(clipping_and_noise, ClippingConfig):
factory_ = clipping_and_noise.to_factory(mean_factory.MeanFactory())
factory_ = _apply_clipping(clipping_and_noise, mean_factory.MeanFactory())
elif isinstance(clipping_and_noise, DifferentialPrivacyConfig):
factory_ = _dp_factory(clipping_and_noise)
else:
py_typecheck.check_type(clipping_and_noise, DPConfig, 'clipping_and_noise')
factory_ = clipping_and_noise.to_factory()
raise TypeError(f'clipping_and_noise must be a supported type of clipping '
f'or noise config. Found type {type(clipping_and_noise)}.')
if zeroing:
factory_ = zeroing.to_factory(factory_)
factory_ = _apply_zeroing(zeroing, factory_)
return factory_
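End to end, composition reads inner to outer: the mean, optionally wrapped in clipping or DP noising, is wrapped by zeroing last, so corrupted updates are zeroed before they reach the inner aggregators. A hedged sketch (note that isinstance against a typing.Union alias such as ClippingConfig raises TypeError through Python 3.9, so dispatch in practice must use the two concrete config classes):

agg = model_update_aggregator(
    zeroing=AdaptiveZeroingConfig(),
    clipping_and_noise=AdaptiveClippingConfig())
# Roughly equivalent to composing by hand:
#   _apply_zeroing(AdaptiveZeroingConfig(),
#                  _apply_clipping(AdaptiveClippingConfig(),
#                                  mean_factory.MeanFactory()))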
@@ -13,124 +13,64 @@
# limitations under the License.
"""Tests for model_update_aggregator."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_federated.python.aggregators import clipping_factory
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import mean_factory
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import estimation_process
from tensorflow_federated.python.learning import model_update_aggregator
_test_qe_config = model_update_aggregator.QuantileEstimationConfig(
initial_estimate=1.0, target_quantile=0.5, learning_rate=1.0)
_test_type = computation_types.TensorType(tf.float32)
class ModelUpdateAggregatorTest(test_case.TestCase, parameterized.TestCase):
def _check_value(self, fn, args, key=None, value=None, error=None):
print(f'key: {key} value: {value}')
print(f'args: {args}')
if key:
args[key] = value
if error is None:
fn(**args)
else:
with self.assertRaises(error):
fn(**args)
@parameterized.named_parameters(
('good', None, None, None),
('initial_estimate_type', 'initial_estimate', 'bad', TypeError),
('initial_estimate_value', 'initial_estimate', 0.0, ValueError),
('target_quantile_type', 'target_quantile', 'bad', TypeError),
('target_quantile_value', 'target_quantile', -1.0, ValueError),
('learning_rate_type', 'learning_rate', 'bad', TypeError),
('learning_rate_value', 'learning_rate', 0.0, ValueError),
)
def test_quantile_estimation_config_args(self, key, value, error):
good_args = dict(
initial_estimate=1.0, target_quantile=0.5, learning_rate=1.0)
self._check_value(model_update_aggregator.QuantileEstimationConfig,
good_args, key, value, error)
class ModelUpdateAggregatorTest(test_case.TestCase):
def test_quantile_estimation_config_to_process(self):
process = _test_qe_config.to_quantile_estimation_process()
self.assertIsInstance(process, estimation_process.EstimationProcess)
@parameterized.named_parameters(
('good', None, None, None),