提交 31d16fa2 编辑于 作者: A. Unique TensorFlower's avatar A. Unique TensorFlower 提交者: tensorflow-copybara
浏览文件

Refactors the aggregator and tff computation for hierarchical histogram.

Changes:
(1) Replaces the old dp_query `CentralTreeSumQuery` with the updated one `TreeRangeSumQuery`. The updated query has several advantages:
    1) It easily supports different DP mechanisms (e.g. GaussianSumQuery,
    DistributedDiscreteGaussianSumQuery) by taking them as inner queries;
    2) `TreeRangeSumQuery` does not hard-code any clipping inside so it is
    composable with various clipping factories. To avoid DP error, it does norm
   checking inside to make sure appropriate clipping happens outside.
(2) Removes the type conversion from int32 to float32 before return in `discretized_histogram_counts`. Now the function outputs int32 tensors. This change is due to the type requirement of `HistogramClippingSumFactory`.
(3) Updates tff's dependency on tensorflow privacy to 0.7.0.
PiperOrigin-RevId: 390277449
上级 39d795ba
......@@ -70,6 +70,6 @@ RUN ${PIP} install --no-cache-dir --upgrade \
retrying~=1.3.3 \
semantic-version~=2.8.5 \
tensorflow-model-optimization~=0.5.0 \
tensorflow-privacy~=0.6.2 \
tensorflow-privacy~=0.7.1 \
tf-nightly
RUN pip freeze
......@@ -9,8 +9,8 @@ py_library(
srcs = ["hierarchical_histogram_factory.py"],
srcs_version = "PY3",
deps = [
":clipping_factory",
"//tensorflow_federated/python/aggregators:differential_privacy",
"//tensorflow_federated/python/aggregators:secure",
"//tensorflow_federated/python/aggregators:sum_factory",
],
)
......@@ -23,7 +23,7 @@ py_test(
deps = [
":build_tree_from_leaf",
":hierarchical_histogram_factory",
"//tensorflow_federated/python/aggregators:differential_privacy",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/test:execution_contexts",
"//tensorflow_federated/python/core/impl/types:computation_types",
......@@ -38,8 +38,8 @@ py_library(
srcs = ["hierarchical_histogram.py"],
srcs_version = "PY3",
deps = [
":clipping_factory",
":hierarchical_histogram_factory",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
......
......@@ -15,8 +15,8 @@
import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory
from tensorflow_federated.python.analytics.hierarchical_histogram import clipping_factory
from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory as hihi_factory
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
......@@ -86,88 +86,97 @@ def _discretized_histogram_counts(client_data: tf.data.Dataset,
histogram = client_data.reduce(
tf.zeros([num_bins], dtype=tf.int32), insert_record)
return tf.cast(histogram, tf.float32)
return histogram
def _build_hierarchical_histogram_computation(
lower_bound: float, upper_bound: float, num_bins: int,
aggregation_factory: factory.UnweightedAggregationFactory):
"""Utility function creating tff computation given the parameters and factory.
def build_hierarchical_histogram_computation(
lower_bound: float,
upper_bound: float,
num_bins: int,
arity: int = 2,
clip_mechanism: str = 'sub-sampling',
max_records_per_user: int = 10,
dp_mechanism: str = 'no-noise',
noise_multiplier: float = 0.0):
"""Creates the TFF computation for hierarchical histogram aggregation.
Args:
lower_bound: A `float` specifying the lower bound of the data range.
upper_bound: A `float` specifying the upper bound of the data range.
num_bins: The integer number of bins to compute.
aggregation_factory: The aggregation factory used to construct the federated
computation.
arity: The branching factor of the tree. Defaults to 2.
clip_mechanism: A `str` representing the clipping mechanism. Currently
supported mechanisms are
- 'sub-sampling': (Default) Uniformly sample up to `max_records_per_user`
records without replacement from the client dataset.
- 'distinct': Uniquify client dataset and uniformly sample up to
`max_records_per_user` records without replacement from it.
max_records_per_user: An `int` representing the maximum of records each user
can include in their local histogram. Defaults to 10.
dp_mechanism: A `str` representing the differentially private mechanism to
use. Currently supported mechanisms are
- 'no-noise': (Default) Tree aggregation mechanism without noise.
- 'central-gaussian': Tree aggregation with central Gaussian mechanism.
noise_multiplier: A `float` specifying the noise multiplier (central noise
stddev / L2 clip norm) for model updates. Defaults to 0.0.
Returns:
A tff federated computation function.
A federated computation that performs hierarchical histogram aggregation.
"""
_check_greater_than_equal(upper_bound, lower_bound, 'upper_bound',
'lower_bound')
_check_positive(num_bins, 'num_bins')
_check_greater_than_equal_thres(arity, 2, 'arity')
_check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
'clip_mechanism')
_check_greater_than_equal_thres(max_records_per_user, 1,
'max_records_per_user')
_check_membership(dp_mechanism, hihi_factory.DP_MECHANISMS, 'dp_mechanism')
_check_greater_than_equal_thres(noise_multiplier, 0., noise_multiplier)
@computations.tf_computation(computation_types.SequenceType(tf.float32))
def client_work(client_data):
return _discretized_histogram_counts(client_data, lower_bound, upper_bound,
num_bins)
aggregator = aggregation_factory.create(client_work.type_signature.result)
agg_factory = hihi_factory.create_hierarchical_histogram_aggregation_factory(
num_bins, arity, clip_mechanism, max_records_per_user, dp_mechanism,
noise_multiplier)
process = agg_factory.create(client_work.type_signature.result)
@computations.federated_computation(
computation_types.at_clients(client_work.type_signature.parameter))
def hierarchical_histogram_computation(federated_client_data):
# Work done at clients.
client_histogram = intrinsics.federated_map(client_work,
federated_client_data)
# Aggregation to server.
return aggregator.next(aggregator.initialize(), client_histogram).result
return process.next(process.initialize(), client_histogram).result
return hierarchical_histogram_computation
def build_central_hierarchical_histogram_computation(
lower_bound: float,
upper_bound: float,
num_bins: int,
arity: int = 2,
max_records_per_user: int = 1,
noise_multiplier: float = 0.0,
secure_sum: bool = False):
"""Create the tff federated computation for central hierarchical histogram aggregation.
Args:
lower_bound: A `float` specifying the lower bound of the data range.
upper_bound: A `float` specifying the upper bound of the data range.
num_bins: The integer number of bins to compute.
arity: The branching factor of the tree. Defaults to 2.
max_records_per_user: The maximum number of records each user is allowed to
contribute. Defaults to 1.
noise_multiplier: A `float` specifying the noise multiplier (central noise
stddev / L2 clip norm) for model updates. Defaults to 0.0.
secure_sum: A boolean deciding whether to use secure aggregation. Defaults
to `False`.
def _check_greater_than_equal(lvalue, rvalue, llabel, rlabel):
if lvalue < rvalue:
raise ValueError(f'`{llabel}` should be no smaller than '
f'`{rlabel}`. Found {lvalue} and '
f'{rvalue}.')
Returns:
A tff.federated_computation function to perform central tree aggregation.
"""
if upper_bound < lower_bound:
raise ValueError(f'upper_bound: {upper_bound} is smaller than '
f'lower_bound: {lower_bound}.')
def _check_greater_than_equal_thres(value, threshold, label):
if value < threshold:
raise ValueError(f'`{label}` must be at least {threshold}. Found {value}.')
if num_bins <= 0:
raise ValueError(f'num_bins: {num_bins} smaller or equal to zero.')
if arity < 2:
raise ValueError(f'Arity should be at least 2.' f'arity={arity} is given.')
def _check_positive(value, label):
if value <= 0:
raise ValueError(f'{label} must be positive. Found {value}.')
if max_records_per_user < 1:
raise ValueError(f'Maximum records per user should be at least 1. '
f'max_records_per_user={max_records_per_user} is given.')
stddev = max_records_per_user * noise_multiplier
def _check_non_negative(value, label):
if value < 0:
raise ValueError(f'{label} must be non-negative. Found {value}.')
central_tree_aggregation_factory = hierarchical_histogram_factory.create_central_hierarchical_histogram_factory(
stddev, arity, max_records_per_user, secure_sum=secure_sum)
return _build_hierarchical_histogram_computation(
lower_bound, upper_bound, num_bins, central_tree_aggregation_factory)
def _check_membership(value, valid_set, label):
if value not in valid_set:
raise ValueError(f'`{label}` must be one of {valid_set}. '
f'Found {value}.')
......@@ -11,62 +11,160 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to create differentially private tree aggregation factory.
The factory is created by wrapping proper tfp.DPQuery in
`DifferentiallyPrivateFactory`.
"""
"""Differentially private tree aggregation factory."""
import math
import tensorflow_privacy as tfp
from tensorflow_federated.python.aggregators import differential_privacy
from tensorflow_federated.python.aggregators import secure
from tensorflow_federated.python.aggregators import sum_factory
from tensorflow_federated.python.analytics.hierarchical_histogram import clipping_factory
# Supported no-noise mechanisms.
NO_NOISE_MECHANISMS = ['no-noise']
# Supported central DP mechanisms.
CENTRAL_DP_MECHANISMS = [
'central-gaussian', # Central Gaussian mechanism.
]
DP_MECHANISMS = CENTRAL_DP_MECHANISMS + NO_NOISE_MECHANISMS
def create_central_hierarchical_histogram_factory(
stddev: float = 0.0,
def create_hierarchical_histogram_aggregation_factory(
num_bins: int,
arity: int = 2,
clip_mechanism: str = 'sub-sampling',
max_records_per_user: int = 10,
secure_sum: bool = False):
"""Creates aggregator for hierarchical histograms with differential privacy.
dp_mechanism: str = 'central-gaussian',
noise_multiplier: float = 0.0):
"""Creates hierarchical histogram aggregation factory.
Hierarchical histogram factory is constructed by composing 3 aggregation
factories.
(1) The inner-most factory is `SumFactory`.
(2) The middle factory is `DifferentiallyPrivateFactory` whose inner query is
`TreeRangeSumQuery`. This factory 1) takes in a clipped histogram,
constructs the hierarchical histogram and checks the norm bound of the
hierarchical histogram at clients, 2) adds noise either at clients or at
server according to `dp_mechanism`.
(3) The outer-most factory is `HistogramClippingSumFactory` which clips the
input histogram to bound each user's contribution.
Args:
stddev: The standard deviation of noise added to each node of the central
tree.
arity: The branching factor of the tree.
max_records_per_user: The maximum of records each user can upload in their
local histogram.
secure_sum: A boolean deciding whether to use secure aggregation. Defaults
to `False`.
num_bins: An `int` representing the input histogram size.
arity: An `int` representing the branching factor of the tree. Defaults to
2.
clip_mechanism: A `str` representing the clipping mechanism. Currently
supported mechanisms are
- 'sub-sampling': (Default) Uniformly sample up to `max_records_per_user`
records without replacement from the client dataset.
- 'distinct': Uniquify client dataset and uniformly sample up to
`max_records_per_user` records without replacement from it.
max_records_per_user: An `int` representing the maximum of records each user
can include in their local histogram. Defaults to 10.
dp_mechanism: A `str` representing the differentially private mechanism to
use. Currently supported mechanisms are
- 'central-gaussian': (Default) Tree aggregation with central Gaussian
mechanism.
- 'no-noise': Tree aggregation mechanism without noise.
noise_multiplier: A `float` specifying the noise multiplier (central noise
stddev / L2 clip norm) for model updates. Only needed when `dp_mechanism`
is not 'no-noise'. Defaults to 0.0.
Returns:
`tff.aggregators.UnWeightedAggregationFactory`.
`tff.aggregators.UnweightedAggregationFactory`.
Raises:
`ValueError`: If 'stddev < 0', `arity < 2`, `max_records_per_user < 1` or
`inner_agg_factory` is illegal.
TypeError: If arguments have the wrong type(s).
ValueError: If arguments have invalid value(s).
"""
if stddev < 0:
raise ValueError(f"Standard deviation should be greater than zero."
f"stddev={stddev} is given.")
_check_positive(num_bins, 'num_bins')
_check_greater_equal(arity, 2, 'arity')
_check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
'clip_mechanism')
_check_positive(max_records_per_user, 'max_records_per_user')
_check_membership(dp_mechanism, DP_MECHANISMS, 'dp_mechanism')
_check_non_negative(noise_multiplier, 'noise_multiplier')
if arity < 2:
raise ValueError(f"Arity should be at least 2." f"arity={arity} is given.")
# Build nested aggregation factory from innermost to outermost.
# 1. Sum factory. The most inner factory that sums the preprocessed records.
nested_factory = sum_factory.SumFactory()
if max_records_per_user < 1:
raise ValueError(f"Maximum records per user should be at least 1."
f"max_records_per_user={max_records_per_user} is given.")
# 2. DP operations.
# (1) Converts `max_records_per_user` to the corresponding norm bound
# according to the chosen `clip_mechanism` and `dp_mechanism`.
if dp_mechanism in ['central-gaussian']:
if clip_mechanism == 'sub-sampling':
l2_norm_bound = max_records_per_user * math.sqrt(
_tree_depth(num_bins, arity))
elif clip_mechanism == 'distinct':
# The following code block converts `max_records_per_user` to L2 norm
# bound of the hierarchical histogram layer by layer. For the bottom
# layer with only 0s and at most `max_records_per_user` 1s, the L2 norm
# bound is `sqrt(max_records_per_user)`. For the second layer from bottom,
# the worst case is only 0s and `max_records_per_user/2` 2s. And so on
# until the root node. Another natural L2 norm bound on each layer is
# `max_records_per_user` so we take the minimum between the two bounds.
square_l2_norm_bound = 0.
square_layer_l2_norm_bound = max_records_per_user
for _ in range(_tree_depth(num_bins, arity)):
square_l2_norm_bound += min(max_records_per_user**2,
square_layer_l2_norm_bound)
square_layer_l2_norm_bound *= arity
l2_norm_bound = math.sqrt(square_l2_norm_bound)
central_tree_agg_query = tfp.privacy.dp_query.tree_aggregation_query.CentralTreeSumQuery(
stddev=stddev, arity=arity, l1_bound=max_records_per_user)
if secure_sum:
inner_agg_factory = secure.SecureSumFactory(
upper_bound_threshold=float(max_records_per_user),
lower_bound_threshold=0.)
# (2) Constructs `DifferentiallyPrivateFactory` according to the chosen
# `dp_mechanism`.
if dp_mechanism == 'central-gaussian':
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
# If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`, then
# the record is casted to `tf.float32` before feeding to the DP factory.
cast_to_float = True
elif dp_mechanism == 'no-noise':
inner_query = tfp.privacy.dp_query.no_privacy_query.NoPrivacySumQuery()
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery(
arity=arity, inner_query=inner_query)
# If the inner `DifferentiallyPrivateFactory` uses `NoPrivacyQuery`, then
# the record is kept as `tf.int32` before feeding to the DP factory.
cast_to_float = False
else:
inner_agg_factory = sum_factory.SumFactory()
raise ValueError('Unexpected dp_mechanism.')
nested_factory = differential_privacy.DifferentiallyPrivateFactory(
query, nested_factory)
# 3. Clip as specified by `clip_mechanism`.
nested_factory = clipping_factory.HistogramClippingSumFactory(
clip_mechanism=clip_mechanism,
max_records_per_user=max_records_per_user,
inner_agg_factory=nested_factory,
cast_to_float=cast_to_float)
return nested_factory
def _check_greater_equal(value, threshold, label):
if value < threshold:
raise ValueError(f'`{label}` must be at least {threshold}, got {value}.')
def _check_positive(value, label):
if value <= 0:
raise ValueError(f'{label} must be positive. Found {value}.')
def _check_non_negative(value, label):
if value < 0:
raise ValueError(f'{label} must be non-negative. Found {value}.')
def _check_membership(value, valid_set, label):
if value not in valid_set:
raise ValueError(f'`{label}` must be one of {valid_set}. '
f'Found {value}.')
return differential_privacy.DifferentiallyPrivateFactory(
central_tree_agg_query, inner_agg_factory)
def _tree_depth(num_leaves: int, arity: int):
"""Returns the depth of the tree given the number of leaf nodes and arity."""
return math.ceil(math.log(num_leaves) / math.log(arity)) + 1
......@@ -20,7 +20,7 @@ import numpy as np
import tensorflow as tf
import tensorflow_privacy as tfp
from tensorflow_federated.python.aggregators import differential_privacy
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.analytics.hierarchical_histogram import build_tree_from_leaf
from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory as hihi_factory
from tensorflow_federated.python.core.api import test_case
......@@ -30,47 +30,36 @@ from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
_test_central_dp_query = tfp.privacy.dp_query.tree_aggregation_query.CentralTreeSumQuery(
stddev=0.0)
class TreeAggregationFactoryComputationTest(test_case.TestCase,
parameterized.TestCase):
@parameterized.named_parameters([
('0', tf.float32, 3, 2),
('1', tf.float32, 3, 3),
('2', tf.float32, 4, 2),
('3', tf.float32, 4, 3),
('4', tf.float32, 7, 2),
('5', tf.float32, 7, 3),
('6', tf.float32, 8, 2),
('7', tf.float32, 8, 3),
('8', tf.int32, 3, 2),
('9', tf.int32, 3, 3),
('10', tf.int32, 4, 2),
('11', tf.int32, 4, 3),
('12', tf.int32, 7, 2),
('13', tf.int32, 7, 3),
('14', tf.int32, 8, 2),
('15', tf.int32, 8, 3),
])
def test_central_aggregation_with_sum(self, value_type, value_shape, arity):
value_type = computation_types.to_type((value_type, (value_shape,)))
factory_ = hihi_factory.create_central_hierarchical_histogram_factory(
arity=arity)
self.assertIsInstance(factory_,
differential_privacy.DifferentiallyPrivateFactory)
process = factory_.create(value_type)
@parameterized.named_parameters(
('test_1_2_sub_sampling', 1, 2, 'sub-sampling'),
('test_5_3_sub_sampling', 5, 3, 'sub-sampling'),
('test_3_2_distinct', 3, 2, 'distinct'),
('test_2_3_distinct', 2, 3, 'distinct'),
)
def test_no_noise_tree_aggregation(self, value_shape, arity, clip_mechanism):
agg_factory = hihi_factory.create_hierarchical_histogram_aggregation_factory(
num_bins=value_shape,
arity=arity,
clip_mechanism=clip_mechanism,
dp_mechanism='no-noise',
)
self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory)
value_type = computation_types.to_type((tf.int32, (value_shape,)))
process = agg_factory.create(value_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
query_state = _test_central_dp_query.initial_global_state()
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery(
arity=arity,
inner_query=tfp.privacy.dp_query.no_privacy_query.NoPrivacySumQuery())
query_state = query.initial_global_state()
query_state_type = type_conversions.type_from_tensors(query_state)
query_metrics_type = type_conversions.type_from_tensors(
_test_central_dp_query.derive_metrics(query_state))
query.derive_metrics(query_state))
server_state_type = computation_types.at_server((query_state_type, ()))
expected_initialize_type = computation_types.FunctionType(
......@@ -81,12 +70,16 @@ class TreeAggregationFactoryComputationTest(test_case.TestCase,
expected_measurements_type = computation_types.at_server(
collections.OrderedDict(dp_query_metrics=query_metrics_type, dp=()))
tree_depth = hihi_factory._tree_depth(value_shape, arity)
flat_tree_shape = (arity**tree_depth - 1) // (arity - 1)
result_value_type = computation_types.to_type(
collections.OrderedDict([
('flat_values',
computation_types.TensorType(tf.float32, tf.TensorShape(None))),
('nested_row_splits', [(tf.int64, (None,))])
computation_types.to_type((tf.int32, (flat_tree_shape,)))),
('nested_row_splits', [(tf.int64, (tree_depth + 1,))])
]))
value_type = computation_types.to_type((tf.int32, (value_shape,)))
expected_next_type = computation_types.FunctionType(
parameter=collections.OrderedDict(
state=server_state_type,
......@@ -95,33 +88,35 @@ class TreeAggregationFactoryComputationTest(test_case.TestCase,
state=server_state_type,
result=computation_types.at_server(result_value_type),
measurements=expected_measurements_type))
self.assertTrue(
process.next.type_signature.is_equivalent_to(expected_next_type))
@parameterized.named_parameters([
('0', 4, 2, 10),
('1', 4, 3, 10),
('2', 7, 2, 10),
('3', 7, 3, 10),
])
def test_central_aggregation_with_secure_sum(self, value_shape, arity,
l1_bound):
value_type = computation_types.to_type((tf.float32, (value_shape,)))
factory_ = hihi_factory.create_central_hierarchical_histogram_factory(
arity=arity, secure_sum=True)
self.assertIsInstance(factory_,
differential_privacy.DifferentiallyPrivateFactory)
@parameterized.named_parameters(
('test_1_2_sub_sampling', 1, 2, 'sub-sampling'),
('test_5_3_sub_sampling', 5, 3, 'sub-sampling'),
('test_3_2_distinct', 3, 2, 'distinct'),
('test_2_3_distinct', 2, 3, 'distinct'),
)
def test_central_gaussian_tree_aggregation(self, value_shape, arity,
clip_mechanism):
process = factory_.create(value_type)
agg_factory = hihi_factory.create_hierarchical_histogram_aggregation_factory(
num_bins=value_shape,
arity=arity,
clip_mechanism=clip_mechanism,
dp_mechanism='central-gaussian',
)
self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory)
value_type = computation_types.to_type((tf.int32, (value_shape,)))
process = agg_factory.create(value_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
query_state = _test_central_dp_query.initial_global_state()
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
l2_norm_clip=1.0, stddev=0.0)
query_state = query.initial_global_state()
query_state_type = type_conversions.type_from_tensors(query_state)
query_metrics_type = type_conversions.type_from_tensors(
_test_central_dp_query.derive_metrics(query_state))
query.derive_metrics(query_state))