提交 8b03066c 编辑于 作者: Jakub Konecny's avatar Jakub Konecny 提交者: tensorflow-copybara
浏览文件

Fixes issue with `float('inf')` not being `np.inf`.

`float('inf') is np.inf` evaluates as `False`, but `float('inf') == np.inf` and `float('inf') in [np.inf]` evaluates as `True`, which leads to empty AssertionError if `float('inf')` is provided as we surprisingly do not catch it earlier.

Changing this to use `np.isinf()` to support different ways to express infinity.

Also extends the tests to build `AggregationProces`, some of which would have failed.

PiperOrigin-RevId: 344475293
上级 ec2a5e40
......@@ -292,7 +292,8 @@ class ZeroingFactory(factory.UnweightedAggregationFactory,
norm.
inner_agg_factory: A factory specifying the type of aggregation to be done
after zeroing.
norm_order: A float for the order of the norm. Must be 1, 2, or np.inf.
norm_order: A float for the order of the norm. Must be 1., 2., or
infinity (e.g. `float('inf')`, `np.inf`).
"""
py_typecheck.check_type(inner_agg_factory, _InnerFactoryType.__args__)
self._inner_agg_factory = inner_agg_factory
......@@ -305,8 +306,9 @@ class ZeroingFactory(factory.UnweightedAggregationFactory,
self._zeroing_norm_process = zeroing_norm
py_typecheck.check_type(norm_order, float)
if norm_order not in [1.0, 2.0, np.inf]:
raise ValueError('norm_order must be 1.0, 2.0 or np.inf.')
if not (norm_order in [1.0, 2.0] or np.isinf(norm_order)):
raise ValueError('norm_order must be 1.0, 2.0 or infinity (e.g. '
'float(\'inf\'), np.inf)')
self._norm_order = norm_order
# The aggregation factory that will be used to count the number of zeroed
......@@ -382,7 +384,7 @@ class ZeroingFactory(factory.UnweightedAggregationFactory,
elif self._norm_order == 2.0:
norm = tf.linalg.global_norm(tf.nest.flatten(value))
else:
assert self._norm_order is np.inf
assert np.isinf(self._norm_order)
norm = _global_inf_norm(value)
should_zero = (norm > zeroing_norm)
zeroed_value = tf.cond(
......
......@@ -245,7 +245,9 @@ py_test(
":model_update_aggregator",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/aggregators:mean_factory",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/core/templates:estimation_process",
],
)
......
......@@ -134,8 +134,8 @@ class ZeroingConfig:
zeroing_quantile = self._quantile.to_quantile_estimation_process()
zeroing_norm = zeroing_quantile.map(
_affine_transform(self._multiplier, self._increment))
return clipping_factory.ZeroingFactory(zeroing_norm, inner_factory,
float('inf'))
return clipping_factory.ZeroingFactory(
zeroing_norm, inner_factory, norm_order=float('inf'))
class ClippingConfig:
......
......@@ -14,10 +14,13 @@
"""Tests for model_update_aggregator."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import mean_factory
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import estimation_process
from tensorflow_federated.python.learning import model_update_aggregator
......@@ -25,6 +28,9 @@ _test_qe_config = model_update_aggregator.QuantileEstimationConfig(
initial_estimate=1.0, target_quantile=0.5, learning_rate=1.0)
_test_type = computation_types.TensorType(tf.float32)
class ModelUpdateAggregatorTest(test_case.TestCase, parameterized.TestCase):
def _check_value(self, fn, args, key=None, value=None, error=None):
......@@ -78,6 +84,8 @@ class ModelUpdateAggregatorTest(test_case.TestCase, parameterized.TestCase):
quantile=qe_config, multiplier=10.0, increment=0.5)
factory_ = zeroing_config.to_factory(mean_factory.MeanFactory())
self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
process = factory_.create_weighted(_test_type, _test_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
@parameterized.named_parameters(
('good', None, None, None),
......@@ -94,6 +102,8 @@ class ModelUpdateAggregatorTest(test_case.TestCase, parameterized.TestCase):
clipping_config = model_update_aggregator.ClippingConfig()
factory_ = clipping_config.to_factory(mean_factory.MeanFactory())
self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
process = factory_.create_weighted(_test_type, _test_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
@parameterized.named_parameters(
('good', None, None, None),
......@@ -115,10 +125,14 @@ class ModelUpdateAggregatorTest(test_case.TestCase, parameterized.TestCase):
noise_multiplier=1.0, clients_per_round=10.0)
factory_ = dp_config.to_factory()
self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
process = factory_.create_unweighted(_test_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
def test_model_update_aggregator(self):
factory_ = model_update_aggregator.model_update_aggregator()
self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
process = factory_.create_weighted(_test_type, _test_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
if __name__ == '__main__':
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册