diff --git a/tensorflow_federated/python/aggregators/BUILD b/tensorflow_federated/python/aggregators/BUILD index 58bd37dc0c5bcb195ab8486bb23cfefd7cb4dd3a..096e85692908a41a569aadf1870a68236c3f3809 100644 --- a/tensorflow_federated/python/aggregators/BUILD +++ b/tensorflow_federated/python/aggregators/BUILD @@ -33,6 +33,7 @@ py_library( srcs_version = "PY3", deps = [ ":factory", + ":sum_factory", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/api:computation_base", diff --git a/tensorflow_federated/python/aggregators/clipping_factory.py b/tensorflow_federated/python/aggregators/clipping_factory.py index 128eb7eeae0c46d4ca09b718cb29df3d843e54a2..e1611cd640ce6811302024354192f6cad7a03fda 100644 --- a/tensorflow_federated/python/aggregators/clipping_factory.py +++ b/tensorflow_federated/python/aggregators/clipping_factory.py @@ -13,12 +13,14 @@ # limitations under the License. """Factory for clipping/zeroing of large values.""" +import collections from typing import Union import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import factory +from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.api import computation_base @@ -30,7 +32,8 @@ from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process -NORM_TYPE = tf.float32 +NORM_TF_TYPE = tf.float32 +COUNT_TF_TYPE = tf.int32 def _constant_process(value): @@ -39,7 +42,7 @@ def _constant_process(value): lambda: intrinsics.federated_value((), placements.SERVER)) next_fn = computations.federated_computation( lambda state, value: state, init_fn.type_signature.result, - computation_types.at_clients(NORM_TYPE)) + computation_types.at_clients(NORM_TF_TYPE)) report_fn = computations.federated_computation( lambda s: intrinsics.federated_value(value, placements.SERVER), init_fn.type_signature.result) @@ -51,8 +54,8 @@ def _check_norm_process(norm_process: estimation_process.EstimationProcess, """Checks type properties for norm_process. The norm_process must be an `EstimationProcess` with `next` function of type - signature (<state@SERVER, NORM_TYPE@CLIENTS> -> state@SERVER), and `report` - with type signature (state@SERVER -> NORM_TYPE@SERVER). + signature (<state@SERVER, NORM_TF_TYPE@CLIENTS> -> state@SERVER), and `report` + with type signature (state@SERVER -> NORM_TF_TYPE@SERVER). Args: norm_process: A process to check. @@ -66,10 +69,10 @@ def _check_norm_process(norm_process: estimation_process.EstimationProcess, raise TypeError(f'`{name}.next` must take two arguments but found:\n' f'{next_parameter_type}') - norm_type_at_clients = computation_types.at_clients(NORM_TYPE) + norm_type_at_clients = computation_types.at_clients(NORM_TF_TYPE) if not next_parameter_type[1].is_assignable_from(norm_type_at_clients): raise TypeError(f'Second argument of `{name}.next` must be assignable from ' - f'NORM_TYPE@CLIENTS but found {next_parameter_type[1]}') + f'NORM_TF_TYPE@CLIENTS but found {next_parameter_type[1]}') next_result_type = norm_process.next.type_signature.result if not norm_process.state_type.is_assignable_from(next_result_type): @@ -78,10 +81,10 @@ def _check_norm_process(norm_process: estimation_process.EstimationProcess, f'while the state type is:\n{norm_process.state_type}') result_type = norm_process.report.type_signature.result - norm_type_at_server = computation_types.at_server(NORM_TYPE) + norm_type_at_server = computation_types.at_server(NORM_TF_TYPE) if not norm_type_at_server.is_assignable_from(result_type): raise TypeError(f'Result type of `{name}.report` must be assignable to ' - f'NORM_TYPE@SERVER but found {result_type}.') + f'NORM_TF_TYPE@SERVER but found {result_type}.') class ClippingFactory(factory.AggregationProcessFactory): @@ -126,6 +129,11 @@ class ClippingFactory(factory.AggregationProcessFactory): _check_norm_process(clipping_norm, 'clipping_norm') self._clipping_norm_process = clipping_norm + # The aggregation factory that will be used to count the number of clipped + # values at each iteration. For now we are just creating it here, but soon + # we will make this customizable to allow DP measurements. + self._clipped_count_agg_factory = sum_factory.SumFactory() + def create( self, value_type: factory.ValueType) -> aggregation_process.AggregationProcess: @@ -138,38 +146,56 @@ class ClippingFactory(factory.AggregationProcessFactory): inner_agg_process = self._inner_agg_factory.create(value_type) + count_type = computation_types.to_type(COUNT_TF_TYPE) + clipped_count_agg_process = self._clipped_count_agg_factory.create( + count_type) + @computations.federated_computation() def init_fn(): - return intrinsics.federated_zip((self._clipping_norm_process.initialize(), - inner_agg_process.initialize())) + return intrinsics.federated_zip( + collections.OrderedDict( + clipping_norm=self._clipping_norm_process.initialize(), + inner_agg=inner_agg_process.initialize(), + clipped_count_agg=clipped_count_agg_process.initialize())) - @computations.tf_computation(value_type, NORM_TYPE) + @computations.tf_computation(value_type, NORM_TF_TYPE) def clip(value, clipping_norm): clipped_value_as_list, global_norm = tf.clip_by_global_norm( tf.nest.flatten(value), clipping_norm) clipped_value = tf.nest.pack_sequence_as(value, clipped_value_as_list) - return clipped_value, global_norm + was_clipped = tf.cast((global_norm > clipping_norm), COUNT_TF_TYPE) + return clipped_value, global_norm, was_clipped @computations.federated_computation( init_fn.type_signature.result, computation_types.at_clients(value_type), computation_types.at_clients(tf.float32)) def next_fn(state, value, weight): - clipping_norm_state, agg_state = state + clipping_norm_state, agg_state, clipped_count_state = state clipping_norm = self._clipping_norm_process.report(clipping_norm_state) - clipped_value, global_norm = intrinsics.federated_map( + clipped_value, global_norm, was_clipped = intrinsics.federated_map( clip, (value, intrinsics.federated_broadcast(clipping_norm))) - agg_output = inner_agg_process.next(agg_state, clipped_value, weight) new_clipping_norm_state = self._clipping_norm_process.next( clipping_norm_state, global_norm) + agg_output = inner_agg_process.next(agg_state, clipped_value, weight) + clipped_count_output = clipped_count_agg_process.next( + clipped_count_state, was_clipped) + + new_state = collections.OrderedDict( + clipping_norm=new_clipping_norm_state, + inner_agg=agg_output.state, + clipped_count_agg=clipped_count_output.state) + measurements = collections.OrderedDict( + agg_process=agg_output.measurements, + clipping_norm=clipping_norm, + clipped_count=clipped_count_output.result) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip( - (new_clipping_norm_state, agg_output.state)), + state=intrinsics.federated_zip(new_state), result=agg_output.result, - measurements=agg_output.measurements) + measurements=intrinsics.federated_zip(measurements)) return aggregation_process.AggregationProcess(init_fn, next_fn) @@ -195,7 +221,7 @@ class ZeroingFactory(factory.AggregationProcessFactory): def __init__(self, zeroing_norm: Union[float, estimation_process.EstimationProcess], inner_agg_factory: factory.AggregationProcessFactory, - norm_order: float = np.inf): + norm_order: float = 2.0): """Initializes a ZeroingFactory. Args: @@ -206,8 +232,7 @@ class ZeroingFactory(factory.AggregationProcessFactory): norm. inner_agg_factory: A factory specifying the type of aggregation to be done after zeroing. - norm_order: A float for the order of the norm. For example, may be 1, 2, - or np.inf. + norm_order: A float for the order of the norm. Must be 1, 2, or np.inf. """ py_typecheck.check_type(inner_agg_factory, factory.AggregationProcessFactory) @@ -221,8 +246,15 @@ class ZeroingFactory(factory.AggregationProcessFactory): self._zeroing_norm_process = zeroing_norm py_typecheck.check_type(norm_order, float) + if norm_order not in [1.0, 2.0, np.inf]: + raise ValueError('norm_order must be 1.0, 2.0 or np.inf.') self._norm_order = norm_order + # The aggregation factory that will be used to count the number of zeroed + # values at each iteration. For now we are just creating it here, but soon + # we will make this customizable to allow DP measurements. + self._zeroed_count_agg_factory = sum_factory.SumFactory() + def create( self, value_type: factory.ValueType) -> aggregation_process.AggregationProcess: @@ -236,40 +268,63 @@ class ZeroingFactory(factory.AggregationProcessFactory): inner_agg_process = self._inner_agg_factory.create(value_type) + count_type = computation_types.to_type(COUNT_TF_TYPE) + zeroed_count_agg_process = self._zeroed_count_agg_factory.create(count_type) + @computations.federated_computation() def init_fn(): - return intrinsics.federated_zip((self._zeroing_norm_process.initialize(), - inner_agg_process.initialize())) + return intrinsics.federated_zip( + collections.OrderedDict( + zeroing_norm=self._zeroing_norm_process.initialize(), + inner_agg=inner_agg_process.initialize(), + zeroed_count_agg=zeroed_count_agg_process.initialize())) - @computations.tf_computation(value_type, NORM_TYPE) + @computations.tf_computation(value_type, NORM_TF_TYPE) def zero(value, zeroing_norm): - # Concat to take norm will introduce memory overhead. Consider optimizing. - vectors = tf.nest.map_structure(lambda v: tf.reshape(v, [-1]), value) - norm = tf.norm( - tf.concat(tf.nest.flatten(vectors), axis=0), ord=self._norm_order) - zeroed = _zero_over(value, norm, zeroing_norm) - return zeroed, norm + if self._norm_order == 1.0: + norm = _global_l1_norm(value) + elif self._norm_order == 2.0: + norm = tf.linalg.global_norm(tf.nest.flatten(value)) + else: + assert self._norm_order is np.inf + norm = _global_inf_norm(value) + should_zero = (norm > zeroing_norm) + zeroed_value = tf.cond( + should_zero, lambda: tf.nest.map_structure(tf.zeros_like, value), + lambda: value) + was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE) + return zeroed_value, norm, was_zeroed @computations.federated_computation( init_fn.type_signature.result, computation_types.at_clients(value_type), computation_types.at_clients(tf.float32)) def next_fn(state, value, weight): - zeroing_norm_state, agg_state = state + zeroing_norm_state, agg_state, zeroed_count_state = state zeroing_norm = self._zeroing_norm_process.report(zeroing_norm_state) - zeroed, norm = intrinsics.federated_map( + zeroed_value, norm, was_zeroed = intrinsics.federated_map( zero, (value, intrinsics.federated_broadcast(zeroing_norm))) - agg_output = inner_agg_process.next(agg_state, zeroed, weight) new_zeroing_norm_state = self._zeroing_norm_process.next( zeroing_norm_state, norm) + agg_output = inner_agg_process.next(agg_state, zeroed_value, weight) + zeroed_count_output = zeroed_count_agg_process.next( + zeroed_count_state, was_zeroed) + + new_state = collections.OrderedDict( + zeroing_norm=new_zeroing_norm_state, + inner_agg=agg_output.state, + zeroed_count_agg=zeroed_count_output.state) + measurements = collections.OrderedDict( + agg_process=agg_output.measurements, + zeroing_norm=zeroing_norm, + zeroed_count=zeroed_count_output.result) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip( - (new_zeroing_norm_state, agg_output.state)), + state=intrinsics.federated_zip(new_state), result=agg_output.result, - measurements=agg_output.measurements) + measurements=intrinsics.federated_zip(measurements)) return aggregation_process.AggregationProcess(init_fn, next_fn) @@ -339,12 +394,18 @@ class ZeroingClippingFactory(factory.AggregationProcessFactory): f'and the argument of `zeroing_norm_fn` is {zeroing_norm_arg_type}.') zeroing_norm_result_type = zeroing_norm_fn.type_signature.result - float_type = computation_types.to_type(NORM_TYPE) + float_type = computation_types.to_type(NORM_TF_TYPE) if not float_type.is_assignable_from(zeroing_norm_result_type): raise TypeError(f'Result of `zeroing_norm_fn` must be assignable to ' - f'NORM_TYPE but found {zeroing_norm_result_type}.') + f'NORM_TF_TYPE but found {zeroing_norm_result_type}.') self._zeroing_norm_fn = zeroing_norm_fn + # The aggregation factories that will be used to count the number of zeroed + # and clipped values at each iteration. For now we are just creating them + # here, but soon we will make this customizable to allow DP measurements. + self._clipped_count_agg_factory = sum_factory.SumFactory() + self._zeroed_count_agg_factory = sum_factory.SumFactory() + def create( self, value_type: factory.ValueType) -> aggregation_process.AggregationProcess: @@ -357,47 +418,82 @@ class ZeroingClippingFactory(factory.AggregationProcessFactory): inner_agg_process = self._inner_agg_factory.create(value_type) + count_type = computation_types.to_type(COUNT_TF_TYPE) + clipped_count_agg_process = self._clipped_count_agg_factory.create( + count_type) + zeroed_count_agg_process = self._zeroed_count_agg_factory.create(count_type) + @computations.federated_computation() def init_fn(): - return intrinsics.federated_zip((self._clipping_norm_process.initialize(), - inner_agg_process.initialize())) - - @computations.tf_computation(value_type, NORM_TYPE, NORM_TYPE) + return intrinsics.federated_zip( + collections.OrderedDict( + clipping_norm=self._clipping_norm_process.initialize(), + inner_agg=inner_agg_process.initialize(), + clipped_count_agg=clipped_count_agg_process.initialize(), + zeroed_count_agg=zeroed_count_agg_process.initialize())) + + @computations.tf_computation(value_type, NORM_TF_TYPE, NORM_TF_TYPE) def clip_and_zero(value, clipping_norm, zeroing_norm): clipped_value_as_list, global_norm = tf.clip_by_global_norm( tf.nest.flatten(value), clipping_norm) clipped_value = tf.nest.pack_sequence_as(value, clipped_value_as_list) - zeroed_and_clipped = _zero_over(clipped_value, global_norm, zeroing_norm) - return zeroed_and_clipped, global_norm + was_clipped = tf.cast((global_norm > clipping_norm), COUNT_TF_TYPE) + should_zero = (global_norm > zeroing_norm) + zeroed_and_clipped = tf.cond( + should_zero, lambda: tf.nest.map_structure(tf.zeros_like, value), + lambda: clipped_value) + was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE) + return zeroed_and_clipped, global_norm, was_clipped, was_zeroed @computations.federated_computation( init_fn.type_signature.result, computation_types.at_clients(value_type), computation_types.at_clients(tf.float32)) def next_fn(state, value, weight): - clipping_norm_state, agg_state = state + (clipping_norm_state, agg_state, clipped_count_state, + zeroed_count_state) = state clipping_norm = self._clipping_norm_process.report(clipping_norm_state) zeroing_norm = intrinsics.federated_map(self._zeroing_norm_fn, clipping_norm) - zeroed_and_clipped, global_norm = intrinsics.federated_map( - clip_and_zero, (value, intrinsics.federated_broadcast(clipping_norm), - intrinsics.federated_broadcast(zeroing_norm))) + (zeroed_and_clipped, global_norm, + was_clipped, was_zeroed) = intrinsics.federated_map( + clip_and_zero, (value, intrinsics.federated_broadcast(clipping_norm), + intrinsics.federated_broadcast(zeroing_norm))) - agg_output = inner_agg_process.next(agg_state, zeroed_and_clipped, weight) new_clipping_norm_state = self._clipping_norm_process.next( clipping_norm_state, global_norm) + agg_output = inner_agg_process.next(agg_state, zeroed_and_clipped, weight) + clipped_count_output = clipped_count_agg_process.next( + clipped_count_state, was_clipped) + zeroed_count_output = zeroed_count_agg_process.next( + zeroed_count_state, was_zeroed) + + new_state = collections.OrderedDict( + clipping_norm=new_clipping_norm_state, + inner_agg=agg_output.state, + clipped_count_agg=clipped_count_output.state, + zeroed_count_agg=zeroed_count_output.state) + measurements = collections.OrderedDict( + agg_process=agg_output.measurements, + clipping_norm=clipping_norm, + zeroing_norm=zeroing_norm, + clipped_count=clipped_count_output.result, + zeroed_count=zeroed_count_output.result) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip( - (new_clipping_norm_state, agg_output.state)), + state=intrinsics.federated_zip(new_state), result=agg_output.result, - measurements=agg_output.measurements) + measurements=intrinsics.federated_zip(measurements)) return aggregation_process.AggregationProcess(init_fn, next_fn) -def _zero_over(value, norm, zeroing_norm): - return tf.cond((norm > zeroing_norm), - lambda: tf.nest.map_structure(tf.zeros_like, value), - lambda: value) +def _global_inf_norm(l): + norms = [tf.reduce_max(tf.abs(a)) for a in tf.nest.flatten(l)] + return tf.reduce_max(tf.stack(norms)) + + +def _global_l1_norm(l): + norms = [tf.reduce_sum(tf.abs(a)) for a in tf.nest.flatten(l)] + return tf.reduce_sum(tf.stack(norms)) diff --git a/tensorflow_federated/python/aggregators/clipping_factory_test.py b/tensorflow_federated/python/aggregators/clipping_factory_test.py index 2ba8473eeb3dde67d0b0af816dff2e1ed7a5dc1e..d0cc027dd6d2c1a4bea378f1f576856e15de1ae2 100644 --- a/tensorflow_federated/python/aggregators/clipping_factory_test.py +++ b/tensorflow_federated/python/aggregators/clipping_factory_test.py @@ -14,8 +14,10 @@ """Tests for ClippingFactory, ZeroingFactory and ZeroingClippingFactory.""" import collections +import itertools from absl.testing import parameterized +import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import clipping_factory @@ -37,15 +39,16 @@ def _make_test_struct_value(x): return [tf.constant(x, dtype=tf.float32, shape=(3,)), x] -def _clip_cons(clip=2.): +def _clip_cons(clip=2.0): return clipping_factory.ClippingFactory(clip, mean_factory.MeanFactory()) -def _zero_cons(clip=2.): - return clipping_factory.ZeroingFactory(clip, mean_factory.MeanFactory()) +def _zero_cons(clip=2.0, norm_order=2.0): + return clipping_factory.ZeroingFactory(clip, mean_factory.MeanFactory(), + norm_order) -def _zero_clip_cons(clip=2.): +def _zero_clip_cons(clip=2.0): zeroing_norm_fn = computations.tf_computation(lambda x: x + 3, tf.float32) return clipping_factory.ZeroingClippingFactory(clip, zeroing_norm_fn, mean_factory.MeanFactory()) @@ -82,22 +85,101 @@ class ClippingFactoryComputationTest(test_case.TestCase, parameterized.TestCase): @parameterized.named_parameters( - ('clip_float', tf.float32, _clip_cons), - ('clip_struct', _test_struct_type, _clip_cons), - ('zero_float', tf.float32, _zero_cons), - ('zero_struct', _test_struct_type, _zero_cons), - ('zero_clip_float', tf.float32, _zero_clip_cons), - ('zero_clip_struct', _test_struct_type, _zero_clip_cons), + ('float', tf.float32), + ('struct', _test_struct_type), ) - def test_type_properties(self, value_type, factory_cons): - factory = factory_cons() + def test_clip_type_properties(self, value_type): + factory = _clip_cons() + value_type = computation_types.to_type(value_type) + process = factory.create(value_type) + self.assertIsInstance(process, aggregation_process.AggregationProcess) + + mean_state_type = collections.OrderedDict( + value_sum_process=(), weight_sum_process=()) + server_state_type = computation_types.at_server( + collections.OrderedDict( + clipping_norm=(), inner_agg=mean_state_type, clipped_count_agg=())) + expected_initialize_type = computation_types.FunctionType( + parameter=None, result=server_state_type) + self.assertTrue( + process.initialize.type_signature.is_equivalent_to( + expected_initialize_type)) + + expected_measurements_type = computation_types.at_server( + collections.OrderedDict( + agg_process=collections.OrderedDict( + value_sum_process=(), weight_sum_process=()), + clipping_norm=clipping_factory.NORM_TF_TYPE, + clipped_count=clipping_factory.COUNT_TF_TYPE)) + expected_next_type = computation_types.FunctionType( + parameter=collections.OrderedDict( + state=server_state_type, + value=computation_types.at_clients(value_type), + weight=computation_types.at_clients(tf.float32)), + result=measured_process.MeasuredProcessOutput( + state=server_state_type, + result=computation_types.at_server(value_type), + measurements=expected_measurements_type)) + self.assertTrue( + process.next.type_signature.is_equivalent_to(expected_next_type)) + + @parameterized.named_parameters( + ('float', tf.float32), + ('struct', _test_struct_type), + ) + def test_zero_type_properties(self, value_type): + factory = _zero_cons() + value_type = computation_types.to_type(value_type) + process = factory.create(value_type) + self.assertIsInstance(process, aggregation_process.AggregationProcess) + + mean_state_type = collections.OrderedDict( + value_sum_process=(), weight_sum_process=()) + server_state_type = computation_types.at_server( + collections.OrderedDict( + zeroing_norm=(), inner_agg=mean_state_type, zeroed_count_agg=())) + expected_initialize_type = computation_types.FunctionType( + parameter=None, result=server_state_type) + self.assertTrue( + process.initialize.type_signature.is_equivalent_to( + expected_initialize_type)) + + expected_measurements_type = computation_types.at_server( + collections.OrderedDict( + agg_process=collections.OrderedDict( + value_sum_process=(), weight_sum_process=()), + zeroing_norm=clipping_factory.NORM_TF_TYPE, + zeroed_count=clipping_factory.COUNT_TF_TYPE)) + expected_next_type = computation_types.FunctionType( + parameter=collections.OrderedDict( + state=server_state_type, + value=computation_types.at_clients(value_type), + weight=computation_types.at_clients(tf.float32)), + result=measured_process.MeasuredProcessOutput( + state=server_state_type, + result=computation_types.at_server(value_type), + measurements=expected_measurements_type)) + self.assertTrue( + process.next.type_signature.is_equivalent_to(expected_next_type)) + + @parameterized.named_parameters( + ('float', tf.float32), + ('struct', _test_struct_type), + ) + def test_zero_clip_type_properties(self, value_type): + factory = _zero_clip_cons() value_type = computation_types.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) + mean_state_type = collections.OrderedDict( + value_sum_process=(), weight_sum_process=()) server_state_type = computation_types.at_server( - ((), - collections.OrderedDict(value_sum_process=(), weight_sum_process=()))) + collections.OrderedDict( + clipping_norm=(), + inner_agg=mean_state_type, + clipped_count_agg=(), + zeroed_count_agg=())) expected_initialize_type = computation_types.FunctionType( parameter=None, result=server_state_type) self.assertTrue( @@ -105,7 +187,13 @@ class ClippingFactoryComputationTest(test_case.TestCase, expected_initialize_type)) expected_measurements_type = computation_types.at_server( - collections.OrderedDict(value_sum_process=(), weight_sum_process=())) + collections.OrderedDict( + agg_process=collections.OrderedDict( + value_sum_process=(), weight_sum_process=()), + clipping_norm=clipping_factory.NORM_TF_TYPE, + zeroing_norm=clipping_factory.NORM_TF_TYPE, + clipped_count=clipping_factory.COUNT_TF_TYPE, + zeroed_count=clipping_factory.COUNT_TF_TYPE)) expected_next_type = computation_types.FunctionType( parameter=collections.OrderedDict( state=server_state_type, @@ -134,7 +222,7 @@ class ClippingFactoryComputationTest(test_case.TestCase, norm = _test_norm_process(report_fn=report_fn) with self.assertRaisesRegex( - TypeError, r'Result type .* assignable to NORM_TYPE@SERVER'): + TypeError, r'Result type .* assignable to NORM_TF_TYPE@SERVER'): factory_cons(norm) @parameterized.named_parameters( @@ -177,7 +265,7 @@ class ClippingFactoryComputationTest(test_case.TestCase, norm = _test_norm_process(next_fn=next_fn) with self.assertRaisesRegex( - TypeError, 'Second argument .* assignable from NORM_TYPE@CLIENTS'): + TypeError, 'Second argument .* assignable from NORM_TF_TYPE@CLIENTS'): factory_cons(norm) @parameterized.named_parameters( @@ -224,6 +312,8 @@ class ClippingFactoryExecutionTest(test_case.TestCase): client_weight = [1.0, 2.0, 1.0] output = process.next(state, client_data, client_weight) self.assertAllClose(7 / 4, output.result) + self.assertAllClose(2.0, output.measurements['clipping_norm']) + self.assertEqual(2, output.measurements['clipped_count']) def test_fixed_clip_mean_struct(self): factory = _clip_cons(4.0) @@ -238,6 +328,8 @@ class ClippingFactoryExecutionTest(test_case.TestCase): client_weight = [1.0, 2.0, 1.0] output = process.next(state, client_data, client_weight) self._check_result(7 / 4, output.result) + self.assertAllClose(4.0, output.measurements['clipping_norm']) + self.assertEqual(1, output.measurements['clipped_count']) def test_increasing_clip_mean(self): factory = _clip_cons(_test_norm_process()) @@ -250,13 +342,19 @@ class ClippingFactoryExecutionTest(test_case.TestCase): client_data = [1.0, 3.0, 5.0] client_weight = [1.0, 2.0, 1.0] output = process.next(state, client_data, client_weight) - self.assertAllClose(1, output.result) + self.assertAllClose(1.0, output.result) + self.assertAllClose(1.0, output.measurements['clipping_norm']) + self.assertEqual(2, output.measurements['clipped_count']) output = process.next(output.state, client_data, client_weight) self.assertAllClose(7 / 4, output.result) + self.assertAllClose(2.0, output.measurements['clipping_norm']) + self.assertEqual(2, output.measurements['clipped_count']) output = process.next(output.state, client_data, client_weight) self.assertAllClose(10 / 4, output.result) + self.assertAllClose(3.0, output.measurements['clipping_norm']) + self.assertEqual(1, output.measurements['clipped_count']) def test_fixed_zero_mean(self): factory = _zero_cons() @@ -270,9 +368,11 @@ class ClippingFactoryExecutionTest(test_case.TestCase): client_weight = [1.0, 2.0, 2.0] output = process.next(state, client_data, client_weight) self.assertAllClose(5 / 5, output.result) + self.assertAllClose(2.0, output.measurements['zeroing_norm']) + self.assertEqual(1, output.measurements['zeroed_count']) def test_fixed_zero_mean_struct(self): - factory = _zero_cons() + factory = _zero_cons(4.0) value_type = computation_types.to_type(_test_struct_type) process = factory.create(value_type) @@ -283,6 +383,23 @@ class ClippingFactoryExecutionTest(test_case.TestCase): client_weight = [1.0, 2.0, 2.0] output = process.next(state, client_data, client_weight) self._check_result(5 / 5, output.result) + self.assertAllClose(4.0, output.measurements['zeroing_norm']) + self.assertEqual(1, output.measurements['zeroed_count']) + + def test_fixed_zero_mean_struct_inf_norm(self): + factory = _zero_cons(2.0, np.inf) + + value_type = computation_types.to_type(_test_struct_type) + process = factory.create(value_type) + + state = process.initialize() + + client_data = [_make_test_struct_value(v) for v in [1.0, 2.0, 5.0]] + client_weight = [1.0, 2.0, 2.0] + output = process.next(state, client_data, client_weight) + self._check_result(5 / 5, output.result) + self.assertAllClose(2.0, output.measurements['zeroing_norm']) + self.assertEqual(1, output.measurements['zeroed_count']) def test_increasing_zero_mean(self): factory = _zero_cons(_test_norm_process()) @@ -296,12 +413,18 @@ class ClippingFactoryExecutionTest(test_case.TestCase): client_weight = [1.0, 2.0, 1.0] output = process.next(state, client_data, client_weight) self.assertAllClose(0.5 / 4, output.result) + self.assertAllClose(1.0, output.measurements['zeroing_norm']) + self.assertAllClose(2.0, output.measurements['zeroed_count']) output = process.next(output.state, client_data, client_weight) self.assertAllClose(3.5 / 4, output.result) + self.assertAllClose(2.0, output.measurements['zeroing_norm']) + self.assertAllClose(1.0, output.measurements['zeroed_count']) output = process.next(output.state, client_data, client_weight) self.assertAllClose(6 / 4, output.result) + self.assertAllClose(3.0, output.measurements['zeroing_norm']) + self.assertEqual(0, output.measurements['zeroed_count']) def test_fixed_zero_clip_mean(self): factory = _zero_clip_cons() @@ -317,6 +440,10 @@ class ClippingFactoryExecutionTest(test_case.TestCase): # Zeroing norm is 5.0, clipping norm is 2.0 output = process.next(state, client_data, client_weight) self.assertAllClose(5 / 4, output.result) + self.assertAllClose(2.0, output.measurements['clipping_norm']) + self.assertAllClose(5.0, output.measurements['zeroing_norm']) + self.assertAllClose(2, output.measurements['clipped_count']) + self.assertAllClose(1, output.measurements['zeroed_count']) def test_fixed_zero_clip_mean_struct(self): factory = _zero_clip_cons(4.0) @@ -333,6 +460,10 @@ class ClippingFactoryExecutionTest(test_case.TestCase): # Norms are 2.0, 6.0, 8.0 output = process.next(state, client_data, client_weight) self._check_result(5 / 4, output.result) + self.assertAllClose(4.0, output.measurements['clipping_norm']) + self.assertAllClose(7.0, output.measurements['zeroing_norm']) + self.assertAllClose(2, output.measurements['clipped_count']) + self.assertAllClose(1, output.measurements['zeroed_count']) def test_increasing_zero_clip_mean(self): factory = _zero_clip_cons(_test_norm_process()) @@ -348,14 +479,36 @@ class ClippingFactoryExecutionTest(test_case.TestCase): # Zeroing norm is 4.0, clipping norm is 1.0 output = process.next(state, client_data, client_weight) self.assertAllClose(2 / 5, output.result) + self.assertAllClose(1.0, output.measurements['clipping_norm']) + self.assertAllClose(4.0, output.measurements['zeroing_norm']) + self.assertAllClose(3, output.measurements['clipped_count']) + self.assertAllClose(2, output.measurements['zeroed_count']) # Zeroing norm is 5.0, clipping norm is 2.0 output = process.next(output.state, client_data, client_weight) self.assertAllClose(8 / 5, output.result) + self.assertAllClose(2.0, output.measurements['clipping_norm']) + self.assertAllClose(5.0, output.measurements['zeroing_norm']) + self.assertAllClose(2, output.measurements['clipped_count']) + self.assertAllClose(1, output.measurements['zeroed_count']) # Zeroing norm is 6.0, clipping norm is 3.0 output = process.next(output.state, client_data, client_weight) self.assertAllClose(13 / 5, output.result) + self.assertAllClose(3.0, output.measurements['clipping_norm']) + self.assertAllClose(6.0, output.measurements['zeroing_norm']) + self.assertAllClose(2, output.measurements['clipped_count']) + self.assertAllClose(0, output.measurements['zeroed_count']) + + +class NormTest(test_case.TestCase): + + def test_norms(self): + values = [1.0, -2.0, 3.0, -4.0] + for l in itertools.permutations(values): + v = [tf.constant(l[0]), (tf.constant([l[1], l[2]]), tf.constant([l[3]]))] + self.assertAllClose(4.0, clipping_factory._global_inf_norm(v).numpy()) + self.assertAllClose(10.0, clipping_factory._global_l1_norm(v).numpy()) if __name__ == '__main__':