Skip to content
Snippets Groups Projects
Commit 8bb2ffb4 authored by Galen Andrew's avatar Galen Andrew Committed by tensorflow-copybara
Browse files

Remove some dependencies on tensorflow_privacy internals and general cleanup.

PiperOrigin-RevId: 322257637
parent e55c8b74
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,10 @@ def _make_value(x, shapes):
return [tf.constant(x, dtype=tf.float32, shape=shape) for shape in shapes]
test_value_types = [('scalar', [()]), ('vector', [(2,)]), ('matrix', [(3, 4)]),
('complex', [(), (2,), (3, 4)])]
class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
def _check_result(self, output, expected, shapes):
......@@ -43,8 +47,8 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('float', 0.0),
('list', [0.0, 0.0]),
('odict', _Odict([('a', 0.0), ('b', 0.0)])),
('nested', _Odict([('a', _Odict([('b', [0.0])])), ('c', [0.0, (0.0,)])])),
('odict', _Odict(a=0.0, b=0.0)),
('nested', _Odict(a=_Odict(b=[0.0]), c=[0.0, (0.0,)])),
('tensors', [0.0, tf.zeros([1]), tf.zeros([2, 2])]),
)
def test_process_type_signature(self, value_template):
......@@ -91,12 +95,7 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
adaptive_zeroing.build_adaptive_zeroing_mean_process(
value_type, 100.0, 0.99, 2.0, 1.0, np.inf)
@parameterized.named_parameters(
('scalar', [()]),
('vector', [(2,)]),
('matrix', [(3, 4)]),
('complex', [(), (2,), (3, 4)]),
)
@parameterized.named_parameters(test_value_types)
def test_simple_average(self, shapes):
value_type = type_conversions.type_from_tensors(_make_value(0, shapes))
......@@ -135,25 +134,18 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
metrics = output['measurements']
self.assertEqual(metrics.num_zeroed, 2)
@parameterized.named_parameters(
('scalar', [()]),
('vector', [(2,)]),
('matrix', [(3, 4)]),
('complex', [(), (2,), (3, 4)]),
)
@parameterized.named_parameters(test_value_types)
def test_adaptation_down(self, shapes):
value_type = type_conversions.type_from_tensors(_make_value(0, shapes))
mean_process = adaptive_zeroing.build_adaptive_zeroing_mean_process(
value_type, 100.0, 0.0, 1.0, np.log(2.0), np.inf)
global_state = mean_process.initialize()
self.assertEqual(global_state.current_estimate, 100.0)
values = [_make_value(x, shapes) for x in [0, 1, 2]]
output = mean_process.next(global_state, values, [1, 1, 1])
self._check_result(output, 1, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 50.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 50.0)
self.assertEqual(metrics.num_zeroed, 0)
......@@ -161,30 +153,22 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
output = mean_process.next(global_state, values, [1, 1, 1])
self._check_result(output, 1, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 25.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 25.0)
self.assertEqual(metrics.num_zeroed, 0)
@parameterized.named_parameters(
('scalar', [()]),
('vector', [(2,)]),
('matrix', [(3, 4)]),
('complex', [(), (2,), (3, 4)]),
)
@parameterized.named_parameters(test_value_types)
def test_adaptation_up(self, shapes):
value_type = type_conversions.type_from_tensors(_make_value(0, shapes))
mean_process = adaptive_zeroing.build_adaptive_zeroing_mean_process(
value_type, 1.0, 1.0, 1.0, np.log(2.0), np.inf)
global_state = mean_process.initialize()
self.assertEqual(global_state.current_estimate, 1.0)
values = [_make_value(x, shapes) for x in [90, 91, 92]]
output = mean_process.next(global_state, values, [1, 1, 1])
self._check_result(output, 0, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 2.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 2.0)
self.assertEqual(metrics.num_zeroed, 3)
......@@ -192,24 +176,17 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
output = mean_process.next(global_state, values, [1, 1, 1])
self._check_result(output, 0, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 4.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 4.0)
self.assertEqual(metrics.num_zeroed, 3)
@parameterized.named_parameters(
('scalar', [()]),
('vector', [(2,)]),
('matrix', [(3, 4)]),
('complex', [(), (2,), (3, 4)]),
)
@parameterized.named_parameters(test_value_types)
def test_adaptation_achieved(self, shapes):
value_type = type_conversions.type_from_tensors(_make_value(0, shapes))
mean_process = adaptive_zeroing.build_adaptive_zeroing_mean_process(
value_type, 100.0, 0.5, 1.0, np.log(4.0), np.inf)
global_state = mean_process.initialize()
self.assertEqual(global_state.current_estimate, 100.0)
values = [_make_value(x, shapes) for x in [30, 60]]
......@@ -218,7 +195,6 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
output = mean_process.next(global_state, values, [1, 1])
self._check_result(output, 45, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 50.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 50.0)
self.assertEqual(metrics.num_zeroed, 0)
......@@ -228,24 +204,17 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
output = mean_process.next(global_state, values, [1, 1])
self._check_result(output, 30, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 50.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 50.0)
self.assertEqual(metrics.num_zeroed, 1)
@parameterized.named_parameters(
('scalar', [()]),
('vector', [(2,)]),
('matrix', [(3, 4)]),
('complex', [(), (2,), (3, 4)]),
)
@parameterized.named_parameters(test_value_types)
def test_adaptation_achieved_with_multiplier(self, shapes):
value_type = type_conversions.type_from_tensors(_make_value(0, shapes))
mean_process = adaptive_zeroing.build_adaptive_zeroing_mean_process(
value_type, 200.0, 0.5, 2.0, np.log(4.0), np.inf)
global_state = mean_process.initialize()
self.assertEqual(global_state.current_estimate, 100.0)
values = [_make_value(x, shapes) for x in [30, 60]]
......@@ -254,7 +223,6 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
output = mean_process.next(global_state, values, [1, 1])
self._check_result(output, 45, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 50.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 100.0)
self.assertEqual(metrics.num_zeroed, 0)
......@@ -264,7 +232,6 @@ class AdaptiveZeroingTest(test.TestCase, parameterized.TestCase):
output = mean_process.next(global_state, values, [1, 1])
self._check_result(output, 45, shapes)
global_state = output['state']
self.assertAllClose(global_state.current_estimate, 50.0)
metrics = output['measurements']
self.assertAllClose(metrics.current_threshold, 100.0)
self.assertEqual(metrics.num_zeroed, 0)
......
......@@ -93,9 +93,10 @@ class BuildDpQueryTest(test.TestCase):
def make_mock_tensor(*dims):
return MockTensor(mock_shape([mock_dim(dim) for dim in dims]))
vectors = collections.OrderedDict([('a', make_mock_tensor(2)),
('b', make_mock_tensor(2, 3)),
('c', make_mock_tensor(1, 3, 4))])
vectors = collections.OrderedDict(
a=make_mock_tensor(2),
b=make_mock_tensor(2, 3),
c=make_mock_tensor(1, 3, 4))
model = mock_model(mock_weights(vectors))
query = differential_privacy.build_dp_query(
......@@ -238,7 +239,7 @@ class BuildDpAggregateProcessTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('float', 0.0), ('list', [0.0, 0.0]),
('odict', collections.OrderedDict([('a', 0.0), ('b', 0.0)])))
('odict', collections.OrderedDict(a=0.0, b=0.0)))
def test_process_type_signature(self, value_template):
query = tensorflow_privacy.GaussianSumQuery(4.0, 0.0)
value_type = type_conversions.type_from_tensors(value_template)
......@@ -246,7 +247,8 @@ class BuildDpAggregateProcessTest(test.TestCase, parameterized.TestCase):
value_type, query)
server_state_type = computation_types.FederatedType(
query._GlobalState(tf.float32, tf.float32), placements.SERVER)
type_conversions.type_from_tensors(query.initial_global_state()),
placements.SERVER)
self.assertEqual(
dp_aggregate_process.initialize.type_signature,
computation_types.FunctionType(
......@@ -264,10 +266,10 @@ class BuildDpAggregateProcessTest(test.TestCase, parameterized.TestCase):
computation_types.FunctionType(
parameter=(server_state_type, client_value_type,
client_value_weight_type),
result=collections.OrderedDict([('state', server_state_type),
('result', server_result_type),
('measurements',
server_metrics_type)])))
result=collections.OrderedDict(
state=server_state_type,
result=server_result_type,
measurements=server_metrics_type)))
def test_dp_sum(self):
query = tensorflow_privacy.GaussianSumQuery(4.0, 0.0)
......@@ -289,7 +291,7 @@ class BuildDpAggregateProcessTest(test.TestCase, parameterized.TestCase):
query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)
def datapoint(a, b):
return collections.OrderedDict([('a', (a,)), ('b', [b])])
return collections.OrderedDict(a=(a,), b=[b])
data = [
datapoint(1.0, 2.0),
......@@ -315,10 +317,8 @@ class BuildDpAggregateProcessTest(test.TestCase, parameterized.TestCase):
query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)
def datapoint(a, b, c):
return collections.OrderedDict([('a', (a,)),
('bc',
collections.OrderedDict([('b', [b]),
('c', (c,))]))])
return collections.OrderedDict(
a=(a,), bc=collections.OrderedDict(b=[b], c=(c,)))
data = [
datapoint(1.0, 2.0, 1.0),
......@@ -345,7 +345,7 @@ class BuildDpAggregateProcessTest(test.TestCase, parameterized.TestCase):
query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)
def datapoint(a, b, c):
return collections.OrderedDict([('a', (a,)), ('bc', ([b], (c,)))])
return collections.OrderedDict(a=(a,), bc=([b], (c,)))
data = [
datapoint(1.0, 2.0, 1.0),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment