提交 9db82091 编辑于 作者: Wennan Zhu's avatar Wennan Zhu 提交者: tensorflow-copybara
浏览文件

Automated rollback of commit e873b07b

PiperOrigin-RevId: 392078655
上级 7b3d2b1c
......@@ -70,6 +70,6 @@ RUN ${PIP} install --no-cache-dir --upgrade \
retrying~=1.3.3 \
semantic-version~=2.8.5 \
tensorflow-model-optimization~=0.5.0 \
tensorflow-privacy~=0.6.2 \
tensorflow-privacy~=0.7.2 \
tf-nightly
RUN pip freeze
......@@ -9,8 +9,8 @@ py_library(
srcs = ["hierarchical_histogram_factory.py"],
srcs_version = "PY3",
deps = [
":clipping_factory",
"//tensorflow_federated/python/aggregators:differential_privacy",
"//tensorflow_federated/python/aggregators:secure",
"//tensorflow_federated/python/aggregators:sum_factory",
],
)
......@@ -23,7 +23,7 @@ py_test(
deps = [
":build_tree_from_leaf",
":hierarchical_histogram_factory",
"//tensorflow_federated/python/aggregators:differential_privacy",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/test:execution_contexts",
"//tensorflow_federated/python/core/impl/types:computation_types",
......@@ -38,8 +38,8 @@ py_library(
srcs = ["hierarchical_histogram.py"],
srcs_version = "PY3",
deps = [
":clipping_factory",
":hierarchical_histogram_factory",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
......
......@@ -15,8 +15,8 @@
import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory
from tensorflow_federated.python.analytics.hierarchical_histogram import clipping_factory
from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory as hihi_factory
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
......@@ -86,88 +86,97 @@ def _discretized_histogram_counts(client_data: tf.data.Dataset,
histogram = client_data.reduce(
tf.zeros([num_bins], dtype=tf.int32), insert_record)
return tf.cast(histogram, tf.float32)
return histogram
def _build_hierarchical_histogram_computation(
lower_bound: float, upper_bound: float, num_bins: int,
aggregation_factory: factory.UnweightedAggregationFactory):
"""Utility function creating tff computation given the parameters and factory.
def build_hierarchical_histogram_computation(
lower_bound: float,
upper_bound: float,
num_bins: int,
arity: int = 2,
clip_mechanism: str = 'sub-sampling',
max_records_per_user: int = 10,
dp_mechanism: str = 'no-noise',
noise_multiplier: float = 0.0):
"""Creates the TFF computation for hierarchical histogram aggregation.
Args:
lower_bound: A `float` specifying the lower bound of the data range.
upper_bound: A `float` specifying the upper bound of the data range.
num_bins: The integer number of bins to compute.
aggregation_factory: The aggregation factory used to construct the federated
computation.
arity: The branching factor of the tree. Defaults to 2.
clip_mechanism: A `str` representing the clipping mechanism. Currently
supported mechanisms are
- 'sub-sampling': (Default) Uniformly sample up to `max_records_per_user`
records without replacement from the client dataset.
- 'distinct': Uniquify client dataset and uniformly sample up to
`max_records_per_user` records without replacement from it.
max_records_per_user: An `int` representing the maximum of records each user
can include in their local histogram. Defaults to 10.
dp_mechanism: A `str` representing the differentially private mechanism to
use. Currently supported mechanisms are
- 'no-noise': (Default) Tree aggregation mechanism without noise.
- 'central-gaussian': Tree aggregation with central Gaussian mechanism.
noise_multiplier: A `float` specifying the noise multiplier (central noise
stddev / L2 clip norm) for model updates. Defaults to 0.0.
Returns:
A tff federated computation function.
A federated computation that performs hierarchical histogram aggregation.
"""
_check_greater_than_equal(upper_bound, lower_bound, 'upper_bound',
'lower_bound')
_check_positive(num_bins, 'num_bins')
_check_greater_than_equal_thres(arity, 2, 'arity')
_check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
'clip_mechanism')
_check_greater_than_equal_thres(max_records_per_user, 1,
'max_records_per_user')
_check_membership(dp_mechanism, hihi_factory.DP_MECHANISMS, 'dp_mechanism')
_check_greater_than_equal_thres(noise_multiplier, 0., noise_multiplier)
@computations.tf_computation(computation_types.SequenceType(tf.float32))
def client_work(client_data):
return _discretized_histogram_counts(client_data, lower_bound, upper_bound,
num_bins)
aggregator = aggregation_factory.create(client_work.type_signature.result)
agg_factory = hihi_factory.create_hierarchical_histogram_aggregation_factory(
num_bins, arity, clip_mechanism, max_records_per_user, dp_mechanism,
noise_multiplier)
process = agg_factory.create(client_work.type_signature.result)
@computations.federated_computation(
computation_types.at_clients(client_work.type_signature.parameter))
def hierarchical_histogram_computation(federated_client_data):
# Work done at clients.
client_histogram = intrinsics.federated_map(client_work,
federated_client_data)
# Aggregation to server.
return aggregator.next(aggregator.initialize(), client_histogram).result
return process.next(process.initialize(), client_histogram).result
return hierarchical_histogram_computation
def build_central_hierarchical_histogram_computation(
lower_bound: float,
upper_bound: float,
num_bins: int,
arity: int = 2,
max_records_per_user: int = 1,
noise_multiplier: float = 0.0,
secure_sum: bool = False):
"""Create the tff federated computation for central hierarchical histogram aggregation.
Args:
lower_bound: A `float` specifying the lower bound of the data range.
upper_bound: A `float` specifying the upper bound of the data range.
num_bins: The integer number of bins to compute.
arity: The branching factor of the tree. Defaults to 2.
max_records_per_user: The maximum number of records each user is allowed to
contribute. Defaults to 1.
noise_multiplier: A `float` specifying the noise multiplier (central noise
stddev / L2 clip norm) for model updates. Defaults to 0.0.
secure_sum: A boolean deciding whether to use secure aggregation. Defaults
to `False`.
def _check_greater_than_equal(lvalue, rvalue, llabel, rlabel):
if lvalue < rvalue:
raise ValueError(f'`{llabel}` should be no smaller than '
f'`{rlabel}`. Found {lvalue} and '
f'{rvalue}.')
Returns:
A tff.federated_computation function to perform central tree aggregation.
"""
if upper_bound < lower_bound:
raise ValueError(f'upper_bound: {upper_bound} is smaller than '
f'lower_bound: {lower_bound}.')
def _check_greater_than_equal_thres(value, threshold, label):
if value < threshold:
raise ValueError(f'`{label}` must be at least {threshold}. Found {value}.')
if num_bins <= 0:
raise ValueError(f'num_bins: {num_bins} smaller or equal to zero.')
if arity < 2:
raise ValueError(f'Arity should be at least 2.' f'arity={arity} is given.')
def _check_positive(value, label):
if value <= 0:
raise ValueError(f'{label} must be positive. Found {value}.')
if max_records_per_user < 1:
raise ValueError(f'Maximum records per user should be at least 1. '
f'max_records_per_user={max_records_per_user} is given.')
stddev = max_records_per_user * noise_multiplier
def _check_non_negative(value, label):
if value < 0:
raise ValueError(f'{label} must be non-negative. Found {value}.')
central_tree_aggregation_factory = hierarchical_histogram_factory.create_central_hierarchical_histogram_factory(
stddev, arity, max_records_per_user, secure_sum=secure_sum)
return _build_hierarchical_histogram_computation(
lower_bound, upper_bound, num_bins, central_tree_aggregation_factory)
def _check_membership(value, valid_set, label):
if value not in valid_set:
raise ValueError(f'`{label}` must be one of {valid_set}. '
f'Found {value}.')
......@@ -11,62 +11,160 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to create differentially private tree aggregation factory.
The factory is created by wrapping proper tfp.DPQuery in
`DifferentiallyPrivateFactory`.
"""
"""Differentially private tree aggregation factory."""
import math
import tensorflow_privacy as tfp
from tensorflow_federated.python.aggregators import differential_privacy
from tensorflow_federated.python.aggregators import secure
from tensorflow_federated.python.aggregators import sum_factory
from tensorflow_federated.python.analytics.hierarchical_histogram import clipping_factory
# Supported no-noise mechanisms.
NO_NOISE_MECHANISMS = ['no-noise']
# Supported central DP mechanisms.
CENTRAL_DP_MECHANISMS = [
'central-gaussian', # Central Gaussian mechanism.
]
DP_MECHANISMS = CENTRAL_DP_MECHANISMS + NO_NOISE_MECHANISMS
def create_central_hierarchical_histogram_factory(
stddev: float = 0.0,
def create_hierarchical_histogram_aggregation_factory(
num_bins: int,
arity: int = 2,
clip_mechanism: str = 'sub-sampling',
max_records_per_user: int = 10,
secure_sum: bool = False):
"""Creates aggregator for hierarchical histograms with differential privacy.
dp_mechanism: str = 'central-gaussian',
noise_multiplier: float = 0.0):
"""Creates hierarchical histogram aggregation factory.
Hierarchical histogram factory is constructed by composing 3 aggregation
factories.
(1) The inner-most factory is `SumFactory`.
(2) The middle factory is `DifferentiallyPrivateFactory` whose inner query is
`TreeRangeSumQuery`. This factory 1) takes in a clipped histogram,
constructs the hierarchical histogram and checks the norm bound of the
hierarchical histogram at clients, 2) adds noise either at clients or at
server according to `dp_mechanism`.
(3) The outer-most factory is `HistogramClippingSumFactory` which clips the
input histogram to bound each user's contribution.
Args:
stddev: The standard deviation of noise added to each node of the central
tree.
arity: The branching factor of the tree.
max_records_per_user: The maximum of records each user can upload in their
local histogram.
secure_sum: A boolean deciding whether to use secure aggregation. Defaults
to `False`.
num_bins: An `int` representing the input histogram size.
arity: An `int` representing the branching factor of the tree. Defaults to
2.
clip_mechanism: A `str` representing the clipping mechanism. Currently
supported mechanisms are
- 'sub-sampling': (Default) Uniformly sample up to `max_records_per_user`
records without replacement from the client dataset.
- 'distinct': Uniquify client dataset and uniformly sample up to
`max_records_per_user` records without replacement from it.
max_records_per_user: An `int` representing the maximum of records each user
can include in their local histogram. Defaults to 10.
dp_mechanism: A `str` representing the differentially private mechanism to
use. Currently supported mechanisms are
- 'central-gaussian': (Default) Tree aggregation with central Gaussian
mechanism.
- 'no-noise': Tree aggregation mechanism without noise.
noise_multiplier: A `float` specifying the noise multiplier (central noise
stddev / L2 clip norm) for model updates. Only needed when `dp_mechanism`
is not 'no-noise'. Defaults to 0.0.
Returns:
`tff.aggregators.UnWeightedAggregationFactory`.
`tff.aggregators.UnweightedAggregationFactory`.
Raises:
`ValueError`: If 'stddev < 0', `arity < 2`, `max_records_per_user < 1` or
`inner_agg_factory` is illegal.
TypeError: If arguments have the wrong type(s).
ValueError: If arguments have invalid value(s).
"""
if stddev < 0:
raise ValueError(f"Standard deviation should be greater than zero."
f"stddev={stddev} is given.")
_check_positive(num_bins, 'num_bins')
_check_greater_equal(arity, 2, 'arity')
_check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
'clip_mechanism')
_check_positive(max_records_per_user, 'max_records_per_user')
_check_membership(dp_mechanism, DP_MECHANISMS, 'dp_mechanism')
_check_non_negative(noise_multiplier, 'noise_multiplier')
if arity < 2:
raise ValueError(f"Arity should be at least 2." f"arity={arity} is given.")
# Build nested aggregtion factory from innermost to outermost.
# 1. Sum factory. The most inner factory that sums the preprocessed records.
nested_factory = sum_factory.SumFactory()
if max_records_per_user < 1:
raise ValueError(f"Maximum records per user should be at least 1."
f"max_records_per_user={max_records_per_user} is given.")
# 2. DP operations.
# (1) Converts `max_records_per_user` to the corresponding norm bound
# according to the chosen `clip_mechanism` and `dp_mechanism`.
if dp_mechanism in ['central-gaussian']:
if clip_mechanism == 'sub-sampling':
l2_norm_bound = max_records_per_user * math.sqrt(
_tree_depth(num_bins, arity))
elif clip_mechanism == 'distinct':
# The following code block converts `max_records_per_user` to L2 norm
# bound of the hierarchical histogram layer by layer. For the bottom
# layer with only 0s and at most `max_records_per_user` 1s, the L2 norm
# bound is `sqrt(max_records_per_user)`. For the second layer from bottom,
# the worst case is only 0s and `max_records_per_user/2` 2s. And so on
# until the root node. Another natural L2 norm bound on each layer is
# `max_records_per_user` so we take the minimum between the two bounds.
square_l2_norm_bound = 0.
square_layer_l2_norm_bound = max_records_per_user
for _ in range(_tree_depth(num_bins, arity)):
square_l2_norm_bound += min(max_records_per_user**2,
square_layer_l2_norm_bound)
square_layer_l2_norm_bound *= arity
l2_norm_bound = math.sqrt(square_l2_norm_bound)
central_tree_agg_query = tfp.privacy.dp_query.tree_aggregation_query.CentralTreeSumQuery(
stddev=stddev, arity=arity, l1_bound=max_records_per_user)
if secure_sum:
inner_agg_factory = secure.SecureSumFactory(
upper_bound_threshold=float(max_records_per_user),
lower_bound_threshold=0.)
# (2) Constructs `DifferentiallyPrivateFactory` according to the chosen
# `dp_mechanism`.
if dp_mechanism == 'central-gaussian':
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
# If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`, then
# the record is casted to `tf.float32` before feeding to the DP factory.
cast_to_float = True
elif dp_mechanism == 'no-noise':
inner_query = tfp.privacy.dp_query.no_privacy_query.NoPrivacySumQuery()
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery(
arity=arity, inner_query=inner_query)
# If the inner `DifferentiallyPrivateFactory` uses `NoPrivacyQuery`, then
# the record is kept as `tf.int32` before feeding to the DP factory.
cast_to_float = False
else:
inner_agg_factory = sum_factory.SumFactory()
raise ValueError('Unexpected dp_mechanism.')
nested_factory = differential_privacy.DifferentiallyPrivateFactory(
query, nested_factory)
# 3. Clip as specified by `clip_mechanism`.
nested_factory = clipping_factory.HistogramClippingSumFactory(
clip_mechanism=clip_mechanism,
max_records_per_user=max_records_per_user,
inner_agg_factory=nested_factory,
cast_to_float=cast_to_float)
return nested_factory
def _check_greater_equal(value, threshold, label):
if value < threshold:
raise ValueError(f'`{label}` must be at least {threshold}, got {value}.')
def _check_positive(value, label):
if value <= 0:
raise ValueError(f'{label} must be positive. Found {value}.')
def _check_non_negative(value, label):
if value < 0:
raise ValueError(f'{label} must be non-negative. Found {value}.')
def _check_membership(value, valid_set, label):
if value not in valid_set:
raise ValueError(f'`{label}` must be one of {valid_set}. '
f'Found {value}.')
return differential_privacy.DifferentiallyPrivateFactory(
central_tree_agg_query, inner_agg_factory)
def _tree_depth(num_leaves: int, arity: int):
"""Returns the depth of the tree given the number of leaf nodes and arity."""
return math.ceil(math.log(num_leaves) / math.log(arity)) + 1
......@@ -20,7 +20,7 @@ import numpy as np
import tensorflow as tf
import tensorflow_privacy as tfp
from tensorflow_federated.python.aggregators import differential_privacy
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.analytics.hierarchical_histogram import build_tree_from_leaf
from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory as hihi_factory
from tensorflow_federated.python.core.api import test_case
......@@ -30,47 +30,36 @@ from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
_test_central_dp_query = tfp.privacy.dp_query.tree_aggregation_query.CentralTreeSumQuery(
stddev=0.0)
class TreeAggregationFactoryComputationTest(test_case.TestCase,
parameterized.TestCase):
@parameterized.named_parameters([
('0', tf.float32, 3, 2),
('1', tf.float32, 3, 3),
('2', tf.float32, 4, 2),
('3', tf.float32, 4, 3),
('4', tf.float32, 7, 2),
('5', tf.float32, 7, 3),
('6', tf.float32, 8, 2),
('7', tf.float32, 8, 3),
('8', tf.int32, 3, 2),
('9', tf.int32, 3, 3),
('10', tf.int32, 4, 2),
('11', tf.int32, 4, 3),
('12', tf.int32, 7, 2),
('13', tf.int32, 7, 3),
('14', tf.int32, 8, 2),
('15', tf.int32, 8, 3),
])
def test_central_aggregation_with_sum(self, value_type, value_shape, arity):
value_type = computation_types.to_type((value_type, (value_shape,)))
factory_ = hihi_factory.create_central_hierarchical_histogram_factory(
arity=arity)
self.assertIsInstance(factory_,
differential_privacy.DifferentiallyPrivateFactory)
process = factory_.create(value_type)
@parameterized.named_parameters(
('test_1_2_sub_sampling', 1, 2, 'sub-sampling'),
('test_5_3_sub_sampling', 5, 3, 'sub-sampling'),
('test_3_2_distinct', 3, 2, 'distinct'),
('test_2_3_distinct', 2, 3, 'distinct'),
)
def test_no_noise_tree_aggregation(self, value_shape, arity, clip_mechanism):
agg_factory = hihi_factory.create_hierarchical_histogram_aggregation_factory(
num_bins=value_shape,
arity=arity,
clip_mechanism=clip_mechanism,
dp_mechanism='no-noise',
)
self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory)
value_type = computation_types.to_type((tf.int32, (value_shape,)))
process = agg_factory.create(value_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
query_state = _test_central_dp_query.initial_global_state()
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery(
arity=arity,
inner_query=tfp.privacy.dp_query.no_privacy_query.NoPrivacySumQuery())
query_state = query.initial_global_state()
query_state_type = type_conversions.type_from_tensors(query_state)
query_metrics_type = type_conversions.type_from_tensors(
_test_central_dp_query.derive_metrics(query_state))
query.derive_metrics(query_state))
server_state_type = computation_types.at_server((query_state_type, ()))
expected_initialize_type = computation_types.FunctionType(
......@@ -81,12 +70,16 @@ class TreeAggregationFactoryComputationTest(test_case.TestCase,
expected_measurements_type = computation_types.at_server(
collections.OrderedDict(dp_query_metrics=query_metrics_type, dp=()))
tree_depth = hihi_factory._tree_depth(value_shape, arity)
flat_tree_shape = (arity**tree_depth - 1) // (arity - 1)
result_value_type = computation_types.to_type(
collections.OrderedDict([
('flat_values',
computation_types.TensorType(tf.float32, tf.TensorShape(None))),
('nested_row_splits', [(tf.int64, (None,))])
computation_types.to_type((tf.int32, (flat_tree_shape,)))),
('nested_row_splits', [(tf.int64, (tree_depth + 1,))])
]))
value_type = computation_types.to_type((tf.int32, (value_shape,)))
expected_next_type = computation_types.FunctionType(
parameter=collections.OrderedDict(
state=server_state_type,
......@@ -95,33 +88,35 @@ class TreeAggregationFactoryComputationTest(test_case.TestCase,
state=server_state_type,
result=computation_types.at_server(result_value_type),
measurements=expected_measurements_type))
self.assertTrue(
process.next.type_signature.is_equivalent_to(expected_next_type))
@parameterized.named_parameters([
('0', 4, 2, 10),
('1', 4, 3, 10),
('2', 7, 2, 10),
('3', 7, 3, 10),
])
def test_central_aggregation_with_secure_sum(self, value_shape, arity,
l1_bound):
value_type = computation_types.to_type((tf.float32, (value_shape,)))
factory_ = hihi_factory.create_central_hierarchical_histogram_factory(
arity=arity, secure_sum=True)
self.assertIsInstance(factory_,
differential_privacy.DifferentiallyPrivateFactory)
@parameterized.named_parameters(
('test_1_2_sub_sampling', 1, 2, 'sub-sampling'),
('test_5_3_sub_sampling', 5, 3, 'sub-sampling'),
('test_3_2_distinct', 3, 2, 'distinct'),
('test_2_3_distinct', 2, 3, 'distinct'),
)
def test_central_gaussian_tree_aggregation(self, value_shape, arity,
clip_mechanism):
process = factory_.create(value_type)
agg_factory = hihi_factory.create_hierarchical_histogram_aggregation_factory(
num_bins=value_shape,
arity=arity,
clip_mechanism=clip_mechanism,
dp_mechanism='central-gaussian',
)
self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory)
value_type = computation_types.to_type((tf.int32, (value_shape,)))
process = agg_factory.create(value_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
query_state = _test_central_dp_query.initial_global_state()
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
l2_norm_clip=1.0, stddev=0.0)
query_state = query.initial_global_state()
query_state_type = type_conversions.type_from_tensors(query_state)
query_metrics_type = type_conversions.type_from_tensors(
_test_central_dp_query.derive_metrics(query_state))
query.derive_metrics(query_state))
server_state_type = computation_types.at_server((query_state_type, ()))