# differential_privacy_test.py
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

17
from absl.testing import parameterized
18
import tensorflow as tf
19
import tensorflow_privacy
20
from tensorflow_federated.python.core.api import computation_types
21
from tensorflow_federated.python.core.api import placements
22
from tensorflow_federated.python.core.api import test_case
23
from tensorflow_federated.python.core.backends.native import execution_contexts
24
from tensorflow_federated.python.core.impl.types import type_conversions
25
from tensorflow_federated.python.core.templates import measured_process
26
27
28
from tensorflow_federated.python.core.utils import differential_privacy


29
class BuildDpQueryTest(test_case.TestCase):
  """Tests for `differential_privacy.build_dp_query`."""

  def test_build_dp_query_basic(self):
    """With three positional arguments a Gaussian average query is returned.

    The test inspects the private fields of the returned query: the numerator
    carries the clip (1.0) and stddev (2.0), and the denominator is 3.0.
    """
    query = differential_privacy.build_dp_query(1.0, 2.0, 3.0)
    self.assertIsInstance(query, tensorflow_privacy.GaussianAverageQuery)

    self.assertEqual(query._numerator._l2_norm_clip, 1.0)
    self.assertEqual(query._numerator._stddev, 2.0)
    self.assertEqual(query._denominator, 3.0)

  def test_build_dp_query_adaptive(self):
    """With the adaptive-clipping keywords an adaptive query is returned.

    The numerator's noise multiplier is expected to equal
    `2.0 * (1 - clipped_count_budget_allocation) ** -0.5`.
    """
    ccba = 0.1

    query = differential_privacy.build_dp_query(
        1.0,
        2.0,
        3.0,
        adaptive_clip_learning_rate=0.05,
        target_unclipped_quantile=0.5,
        clipped_count_budget_allocation=ccba,
        expected_clients_per_round=10)

    self.assertIsInstance(query,
                          tensorflow_privacy.QuantileAdaptiveClipAverageQuery)
    self.assertIsInstance(query._numerator,
                          tensorflow_privacy.QuantileAdaptiveClipSumQuery)

    expected_sum_query_noise_multiplier = 2.0 * (1.0 - ccba)**(-0.5)
    # Float arithmetic: compare approximately rather than exactly.
    self.assertAlmostEqual(query._numerator._noise_multiplier,
                           expected_sum_query_noise_multiplier)
    self.assertEqual(query._denominator, 3.0)

60

61
class BuildDpAggregateProcessTest(test_case.TestCase, parameterized.TestCase):
  """Tests for `differential_privacy.build_dp_aggregate_process`.

  Every query in this class is built with a stddev of 0.0, and the tests
  assert exact sums of the (clipped) client values.
  """

  @parameterized.named_parameters(
      ('float', 0.0), ('list', [0.0, 0.0]),
      ('odict', collections.OrderedDict(a=0.0, b=0.0)))
  def test_process_type_signature(self, value_template):
    """Checks the type signatures of the process's `initialize` and `next`.

    Args:
      value_template: A value whose structure defines the client value type.
    """
    query = tensorflow_privacy.GaussianSumQuery(4.0, 0.0)
    value_type = type_conversions.type_from_tensors(value_template)
    dp_aggregate_process = differential_privacy.build_dp_aggregate_process(
        value_type, query)

    # `initialize` takes no argument and returns the server-placed state,
    # whose type is derived from the query's initial global state.
    global_state = query.initial_global_state()
    server_state_type = computation_types.FederatedType(
        type_conversions.type_from_tensors(global_state), placements.SERVER)
    self.assertEqual(
        dp_aggregate_process.initialize.type_signature,
        computation_types.FunctionType(
            parameter=None, result=server_state_type))

    metrics_type = type_conversions.type_from_tensors(
        query.derive_metrics(global_state))

    client_value_type = computation_types.FederatedType(value_type,
                                                        placements.CLIENTS)
    client_value_weight_type = computation_types.FederatedType(
        tf.float32, placements.CLIENTS)
    server_result_type = computation_types.FederatedType(
        value_type, placements.SERVER)
    server_metrics_type = computation_types.FederatedType(
        metrics_type, placements.SERVER)
    # `next` maps (global_state, value, weight) to a MeasuredProcessOutput of
    # (state, result, measurements); compared by type equivalence.
    self.assertTrue(
        dp_aggregate_process.next.type_signature.is_equivalent_to(
            computation_types.FunctionType(
                parameter=collections.OrderedDict(
                    global_state=server_state_type,
                    value=client_value_type,
                    weight=client_value_weight_type),
                result=measured_process.MeasuredProcessOutput(
                    state=server_state_type,
                    result=server_result_type,
                    measurements=server_metrics_type))))

  def test_dp_sum(self):
    """Sums scalar client values with an L2 clip of 4.0."""
    query = tensorflow_privacy.GaussianSumQuery(4.0, 0.0)

    value_type = type_conversions.type_from_tensors(0.0)
    dp_aggregate_process = differential_privacy.build_dp_aggregate_process(
        value_type, query)

    global_state = dp_aggregate_process.initialize()

    output = dp_aggregate_process.next(global_state, [1.0, 3.0, 5.0],
                                       [1.0, 1.0, 1.0])

    self.assertEqual(output.state.l2_norm_clip, 4.0)
    self.assertEqual(output.state.stddev, 0.0)
    # 5.0 is clipped to 4.0, so the sum is 1.0 + 3.0 + 4.0.
    self.assertEqual(output.result, 8.0)

  def test_dp_sum_structure_odict(self):
    """Sums client values structured as an OrderedDict of sequences."""
    query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)

    def datapoint(a, b):
      return collections.OrderedDict(a=(a,), b=[b])

    data = [
        datapoint(1.0, 2.0),
        datapoint(2.0, 3.0),
        datapoint(6.0, 8.0),  # Clipped to 3.0, 4.0
    ]

    value_type = type_conversions.type_from_tensors(data[0])
    dp_aggregate_process = differential_privacy.build_dp_aggregate_process(
        value_type, query)

    global_state = dp_aggregate_process.initialize()

    output = dp_aggregate_process.next(global_state, data, [1.0, 1.0, 1.0])

    self.assertEqual(output.state.l2_norm_clip, 5.0)
    self.assertEqual(output.state.stddev, 0.0)

    self.assertEqual(output.result['a'][0], 6.0)
    self.assertEqual(output.result['b'][0], 9.0)

  def test_dp_sum_structure_nested_odict(self):
    """Sums client values structured as an OrderedDict nested in another."""
    query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)

    def datapoint(a, b, c):
      return collections.OrderedDict(
          a=(a,), bc=collections.OrderedDict(b=[b], c=(c,)))

    data = [
        datapoint(1.0, 2.0, 1.0),
        datapoint(2.0, 3.0, 1.0),
        datapoint(6.0, 8.0, 0.0),  # Clipped to 3.0, 4.0, 0.0
    ]

    value_type = type_conversions.type_from_tensors(data[0])
    dp_aggregate_process = differential_privacy.build_dp_aggregate_process(
        value_type, query)

    global_state = dp_aggregate_process.initialize()

    output = dp_aggregate_process.next(global_state, data, [1.0, 1.0, 1.0])

    self.assertEqual(output.state.l2_norm_clip, 5.0)
    self.assertEqual(output.state.stddev, 0.0)

    self.assertEqual(output.result['a'][0], 6.0)
    self.assertEqual(output.result['bc']['b'][0], 9.0)
    self.assertEqual(output.result['bc']['c'][0], 2.0)

  def test_dp_sum_structure_complex(self):
    """Sums client values mixing an OrderedDict with a tuple of sequences."""
    query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)

    def datapoint(a, b, c):
      return collections.OrderedDict(a=(a,), bc=([b], (c,)))

    data = [
        datapoint(1.0, 2.0, 1.0),
        datapoint(2.0, 3.0, 1.0),
        datapoint(6.0, 8.0, 0.0),  # Clipped to 3.0, 4.0, 0.0
    ]

    value_type = type_conversions.type_from_tensors(data[0])
    dp_aggregate_process = differential_privacy.build_dp_aggregate_process(
        value_type, query)

    global_state = dp_aggregate_process.initialize()

    output = dp_aggregate_process.next(global_state, data, [1.0, 1.0, 1.0])

    self.assertEqual(output.state.l2_norm_clip, 5.0)
    self.assertEqual(output.state.stddev, 0.0)

    self.assertEqual(output.result['a'][0], 6.0)
    self.assertEqual(output.result['bc'][0][0], 9.0)
    self.assertEqual(output.result['bc'][1][0], 2.0)

  def test_dp_sum_structure_list(self):
    """Sums client values structured as a list of named `tf.Variable`s."""
    query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)

    def datapoint(a, b):
      return [tf.Variable(a, name='a'), tf.Variable(b, name='b')]

    data = [
        datapoint(1.0, 2.0),
        datapoint(2.0, 3.0),
        datapoint(6.0, 8.0),  # Clipped to 3.0, 4.0
    ]

    value_type = type_conversions.type_from_tensors(data[0])

    dp_aggregate_process = differential_privacy.build_dp_aggregate_process(
        value_type, query)

    global_state = dp_aggregate_process.initialize()

    output = dp_aggregate_process.next(global_state, data, [1.0, 1.0, 1.0])

    self.assertEqual(output.state.l2_norm_clip, 5.0)
    self.assertEqual(output.state.stddev, 0.0)

    result = list(output.result)
    self.assertEqual(result[0], 6.0)
    self.assertEqual(result[1], 9.0)

  def test_dp_stateful_mean(self):
    """Checks that query state changes carry over between `next` calls.

    Uses a sum query whose clip shrinks by 1 (floored at 0.0) on every round,
    so the same records produce a smaller sum each time.
    """

    class ShrinkingSumQuery(tensorflow_privacy.GaussianSumQuery):
      """Gaussian sum query that lowers its clip by 1 after each round."""

      def get_noised_result(self, sample_state, global_state):
        # Return the sample state unchanged; only the clip is updated.
        global_state = self._GlobalState(
            tf.maximum(global_state.l2_norm_clip - 1, 0.0), global_state.stddev)

        return sample_state, global_state

    query = ShrinkingSumQuery(4.0, 0.0)

    value_type = type_conversions.type_from_tensors(0.0)
    dp_aggregate_process = differential_privacy.build_dp_aggregate_process(
        value_type, query)

    global_state = dp_aggregate_process.initialize()

    records = [1.0, 3.0, 5.0]

    def run_and_check(global_state, expected_l2_norm_clip, expected_result):
      """Runs one round and returns the state to feed into the next one."""
      output = dp_aggregate_process.next(global_state, records, [1.0, 1.0, 1.0])
      self.assertEqual(output.state.l2_norm_clip, expected_l2_norm_clip)
      self.assertEqual(output.result, expected_result)
      return output.state

    self.assertEqual(global_state.l2_norm_clip, 4.0)
    # Each round clips with the incoming clip, then reports the lowered one:
    # clip 4 -> 1+3+4, clip 3 -> 1+3+3, clip 2 -> 1+2+2, clip 1 -> 1+1+1,
    # clip 0 -> 0.
    global_state = run_and_check(global_state, 3.0, 8.0)
    global_state = run_and_check(global_state, 2.0, 7.0)
    global_state = run_and_check(global_state, 1.0, 5.0)
    global_state = run_and_check(global_state, 0.0, 3.0)
    global_state = run_and_check(global_state, 0.0, 0.0)


262
if __name__ == '__main__':
  # Set the local execution context before handing control to the test runner.
  execution_contexts.set_local_execution_context()
  test_case.main()