提交 5ec5b2ac 编辑于 作者: Zachary Garrett's avatar Zachary Garrett 提交者: tensorflow-copybara
浏览文件

Remove numpy dependency, only used for `isinf` testing which is overkill.

PiperOrigin-RevId: 344832769
上级 827a7db5
......@@ -14,9 +14,9 @@
"""Factory for clipping/zeroing of large values."""
import collections
import math
from typing import Union
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
......@@ -293,7 +293,7 @@ class ZeroingFactory(factory.UnweightedAggregationFactory,
inner_agg_factory: A factory specifying the type of aggregation to be done
after zeroing.
norm_order: A float for the order of the norm. Must be 1., 2., or
infinity (e.g. `float('inf')`, `np.inf`).
infinity.
"""
py_typecheck.check_type(inner_agg_factory, _InnerFactoryType.__args__)
self._inner_agg_factory = inner_agg_factory
......@@ -306,9 +306,8 @@ class ZeroingFactory(factory.UnweightedAggregationFactory,
self._zeroing_norm_process = zeroing_norm
py_typecheck.check_type(norm_order, float)
if not (norm_order in [1.0, 2.0] or np.isinf(norm_order)):
raise ValueError('norm_order must be 1.0, 2.0 or infinity (e.g. '
'float(\'inf\'), np.inf)')
if not (norm_order in [1.0, 2.0] or math.isinf(norm_order)):
raise ValueError('norm_order must be 1.0, 2.0 or infinity')
self._norm_order = norm_order
# The aggregation factory that will be used to count the number of zeroed
......@@ -384,7 +383,7 @@ class ZeroingFactory(factory.UnweightedAggregationFactory,
elif self._norm_order == 2.0:
norm = tf.linalg.global_norm(tf.nest.flatten(value))
else:
assert np.isinf(self._norm_order)
assert math.isinf(self._norm_order)
norm = _global_inf_norm(value)
should_zero = (norm > zeroing_norm)
zeroed_value = tf.cond(
......
......@@ -17,7 +17,6 @@ import collections
import itertools
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.aggregators import clipping_factory
......@@ -490,7 +489,7 @@ class ClippingFactoryExecutionTest(test_case.TestCase):
self.assertEqual(1, output.measurements['zeroed_count'])
def test_fixed_zero_sum_struct_inf_norm(self):
factory = _zeroed_sum(2.0, np.inf)
factory = _zeroed_sum(2.0, float('inf'))
value_type = computation_types.to_type(_test_struct_type)
process = factory.create_unweighted(value_type)
......@@ -504,7 +503,7 @@ class ClippingFactoryExecutionTest(test_case.TestCase):
self.assertEqual(1, output.measurements['zeroed_count'])
def test_fixed_zero_mean_struct_inf_norm(self):
factory = _zeroed_mean(2.0, np.inf)
factory = _zeroed_mean(2.0, float('inf'))
value_type = computation_types.to_type(_test_struct_type)
weight_type = computation_types.to_type(tf.float32)
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册