# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for interop with tensorflow_privacy."""

import math
import numbers
import warnings

import numpy as np
import tensorflow as tf
import tensorflow_privacy

from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
from tensorflow_federated.python.core.api import placements
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.templates import measured_process

# TODO(b/140236959): Make the nomenclature consistent (between 'record' and
# 'value') in this library.

# Note if the functions here change, the documentation at
# https://github.com/tensorflow/federated/blob/master/docs/tff_for_research.md
# should be updated.


def _distribute_clip(clip, vectors):
  """Distributes a global L2 norm clip across a structure of vectors.

  Each vector is assigned a share of the clip proportional to the square root
  of its fraction of the total number of elements, so that the per-vector
  clips compose back to the global clip in the L2 sense.
  """

  def dim(v):
    # The number of elements in `v`, computed as the product of its dimensions.
    return math.exp(sum([math.log(d.value) for d in v.shape.dims]))

  dims = tf.nest.map_structure(dim, vectors)
  total_dim = sum(tf.nest.flatten(dims))
  return tf.nest.map_structure(lambda d: clip * np.sqrt(d / total_dim), dims)
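
# For example, with clip=1.0 and a structure of two vectors holding 4 and 12
# elements, _distribute_clip assigns per-vector clips of
# 1.0 * sqrt(4 / 16) = 0.5 and 1.0 * sqrt(12 / 16) ~= 0.866, and
# 0.5**2 + 0.866**2 ~= 1.0 = clip**2.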


def build_dp_query(clip,
                   noise_multiplier,
                   expected_total_weight,
                   adaptive_clip_learning_rate=0,
                   target_unclipped_quantile=None,
                   clipped_count_budget_allocation=None,
                   expected_clients_per_round=None,
                   geometric_clip_update=True):
  """Makes a `DPQuery` to estimate vector averages with differential privacy.

  Supports fixed-clip Gaussian averaging as well as quantile-based adaptive
  clipping as described in https://arxiv.org/abs/1905.03871. Both are built on
  the `DPQuery` framework described in https://arxiv.org/pdf/1812.06210.pdf.

  Args:
    clip: The query's L2 norm bound, or the initial clip if adaptive clipping
      is used.
    noise_multiplier: The ratio of the (effective) noise stddev to the clip.
    expected_total_weight: The expected total weight of all clients, used as the
      denominator for the average computation.
    adaptive_clip_learning_rate: Learning rate for quantile-based adaptive
      clipping. If 0, fixed clipping is used.
    target_unclipped_quantile: Target unclipped quantile for adaptive clipping.
    clipped_count_budget_allocation: The fraction of privacy budget to use for
      estimating clipped counts.
    expected_clients_per_round: The expected number of clients for estimating
      clipped fractions.
    geometric_clip_update: If True, use geometric updating of the clip.

  Returns:
    A `DPQuery` suitable for use in a call to `build_dp_aggregate_process` to
    perform Federated Averaging with differential privacy.
  """
  py_typecheck.check_type(clip, numbers.Number, 'clip')
  py_typecheck.check_type(noise_multiplier, numbers.Number, 'noise_multiplier')
  py_typecheck.check_type(expected_total_weight, numbers.Number,
                          'expected_total_weight')

  if adaptive_clip_learning_rate:
    py_typecheck.check_type(adaptive_clip_learning_rate, numbers.Number,
                            'adaptive_clip_learning_rate')
    py_typecheck.check_type(target_unclipped_quantile, numbers.Number,
                            'target_unclipped_quantile')
    py_typecheck.check_type(clipped_count_budget_allocation, numbers.Number,
                            'clipped_count_budget_allocation')
    py_typecheck.check_type(expected_clients_per_round, numbers.Number,
                            'expected_clients_per_round')
    p = clipped_count_budget_allocation
    nm = noise_multiplier
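    # Splitting the noise budget: with these multipliers,
    # 1 / vectors_noise_multiplier**2 + 1 / clipped_count_noise_multiplier**2
    # equals 1 / nm**2, so the pair of queries has the same effective noise
    # multiplier as a single query with multiplier nm, with fraction p of the
    # budget spent on the clipped counts.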
    vectors_noise_multiplier = nm * (1 - p)**(-0.5)
    clipped_count_noise_multiplier = nm * p**(-0.5)

    # Clipped count sensitivity is 0.5.
    clipped_count_stddev = 0.5 * clipped_count_noise_multiplier

    return tensorflow_privacy.QuantileAdaptiveClipAverageQuery(
        initial_l2_norm_clip=clip,
        noise_multiplier=vectors_noise_multiplier,
        target_unclipped_quantile=target_unclipped_quantile,
        learning_rate=adaptive_clip_learning_rate,
        clipped_count_stddev=clipped_count_stddev,
        expected_num_records=expected_clients_per_round,
        geometric_update=geometric_clip_update,
        denominator=expected_total_weight)
  else:
    if target_unclipped_quantile is not None:
      warnings.warn(
          'target_unclipped_quantile is specified but '
          'adaptive_clip_learning_rate is zero. No adaptive clipping will be '
          'performed. Use adaptive_clip_learning_rate > 0 if you want '
          'adaptive clipping.')
    if clipped_count_budget_allocation is not None:
      warnings.warn(
          'clipped_count_budget_allocation is specified but '
          'adaptive_clip_learning_rate is zero. No adaptive clipping will be '
          'performed. Use adaptive_clip_learning_rate > 0 if you want '
          'adaptive clipping.')
    return tensorflow_privacy.GaussianAverageQuery(
        l2_norm_clip=clip,
        sum_stddev=clip * noise_multiplier,
        denominator=expected_total_weight)
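
# A minimal usage sketch (the parameter values below are illustrative
# assumptions, not tuned recommendations):
#
#   fixed_clip_query = build_dp_query(
#       clip=1.0, noise_multiplier=1.1, expected_total_weight=100)
#
#   adaptive_clip_query = build_dp_query(
#       clip=0.1,
#       noise_multiplier=1.1,
#       expected_total_weight=100,
#       adaptive_clip_learning_rate=0.2,
#       target_unclipped_quantile=0.5,
#       clipped_count_budget_allocation=0.1,
#       expected_clients_per_round=100)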


def build_dp_aggregate_process(value_type, query):
  """Builds a `MeasuredProcess` for tensorflow_privacy DPQueries.

  The returned `MeasuredProcess` processes values of type `value_type`, which
  can be any nested structure of tensors. Note that client weighting is not
  supported for differential privacy, so the `weight` argument to the
  resulting `MeasuredProcess` will be ignored.

  Args:
    value_type: The type of values to be aggregated by the `MeasuredProcess`.
      Can be a `tff.TensorType` or a nested structure of `tff.StructType`
      that bottoms out in `tff.TensorType`.
    query: A DPQuery to aggregate. For compatibility with tensorflow_federated,
      the global_state and sample_state of the query must be structures
      supported by tf.nest.

  Returns:
    A `MeasuredProcess` implementing differentially private aggregation using
    the supplied DPQuery. Note that client weighting is not supported for
    differential privacy, so the `weight` argument to the resulting
    `MeasuredProcess` will be ignored.
  """
  py_typecheck.check_type(
      value_type, (computation_types.TensorType, computation_types.StructType))

  @computations.tf_computation
  def initial_state_fn():
    return query.initial_global_state()

  @computations.federated_computation()
  def initial_state_comp():
    return intrinsics.federated_eval(initial_state_fn, placements.SERVER)

  #######################################
  # Define local tf_computations

  global_state_type = initial_state_fn.type_signature.result

  @computations.tf_computation(global_state_type)
  def derive_sample_params(global_state):
    return query.derive_sample_params(global_state)

  @computations.tf_computation(derive_sample_params.type_signature.result,
                               value_type)
  def preprocess_record(params, record):
    return query.preprocess_record(params, record)

  tensor_specs = type_conversions.type_to_tf_tensor_specs(value_type)

  @computations.tf_computation
  def zero():
    return query.initial_sample_state(tensor_specs)

  sample_state_type = zero.type_signature.result

  @computations.tf_computation(sample_state_type,
                               preprocess_record.type_signature.result)
  def accumulate(sample_state, preprocessed_record):
    return query.accumulate_preprocessed_record(sample_state,
                                                preprocessed_record)

  @computations.tf_computation(sample_state_type, sample_state_type)
  def merge(sample_state_1, sample_state_2):
    return query.merge_sample_states(sample_state_1, sample_state_2)

  @computations.tf_computation(merge.type_signature.result)
  def report(sample_state):
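    # `report` is the identity: noising via `query.get_noised_result` also
    # needs the query's global state, which is only available on the server,
    # so it is applied in `post_process` below instead.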
    return sample_state

  @computations.tf_computation(sample_state_type, global_state_type)
  def post_process(sample_state, global_state):
    result, new_global_state = query.get_noised_result(sample_state,
                                                       global_state)
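    # `get_noised_result` returns `(result, new_global_state)`; the order is
    # flipped here so the output of the federated_map in `next_fn` unpacks as
    # `(state, result)`.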
    return new_global_state, result

  @computations.tf_computation(global_state_type)
  def derive_metrics(global_state):
    return query.derive_metrics(global_state)

  @computations.federated_computation(
      initial_state_comp.type_signature.result,
      computation_types.FederatedType(value_type, placements.CLIENTS),
      computation_types.FederatedType(tf.float32, placements.CLIENTS))
  def next_fn(global_state, value, weight):
    """Defines next_fn for MeasuredProcess."""
    # Weighted aggregation is not supported.
    # TODO(b/140236959): Add an assertion that weight is None here, so the
    # contract of this method is better established. Will likely cause some
    # downstream breaks.
    del weight

    sample_params = intrinsics.federated_map(derive_sample_params, global_state)
    client_sample_params = intrinsics.federated_broadcast(sample_params)
    preprocessed_record = intrinsics.federated_map(
        preprocess_record, (client_sample_params, value))
    agg_result = intrinsics.federated_aggregate(preprocessed_record, zero(),
                                                accumulate, merge, report)

    updated_state, result = intrinsics.federated_map(post_process,
                                                     (agg_result, global_state))

    metrics = intrinsics.federated_map(derive_metrics, updated_state)

    return measured_process.MeasuredProcessOutput(
        state=updated_state, result=result, measurements=metrics)

  return measured_process.MeasuredProcess(
      initialize_fn=initial_state_comp, next_fn=next_fn)
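

# End-to-end sketch (illustrative; the `value_type` below is an assumed
# example and should match the structure of the client values being
# aggregated):
#
#   query = build_dp_query(
#       clip=1.0, noise_multiplier=1.1, expected_total_weight=100)
#   value_type = computation_types.TensorType(tf.float32, [10])
#   dp_aggregate_process = build_dp_aggregate_process(value_type, query)
#
# The resulting `MeasuredProcess` can then be used wherever TFF expects a
# differentially private aggregation of client values, e.g. as the
# aggregation step of a federated averaging computation.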