提交 4daa2df3 编辑于 作者: Zheng Xu's avatar Zheng Xu 提交者: tensorflow-copybara
浏览文件

Cleanup how TFP queries are called in tff.analytics.

PiperOrigin-RevId: 392762446
上级 91cab511
......@@ -202,6 +202,8 @@ def create_hierarchical_histogram_aggregation_factory(
# Constructs `DifferentiallyPrivateFactory` according to the chosen
# `dp_mechanism`.
if dp_mechanism == 'central-gaussian':
# TODO(b/197596864): change to `tfp.TreeRangeSumQuery` after next
# TFP release.
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery.build_central_gaussian_query(
l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
# If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`, then
......@@ -216,7 +218,7 @@ def create_hierarchical_histogram_aggregation_factory(
# before feeding to the DP factory.
cast_to_float = False
elif dp_mechanism == 'no-noise':
inner_query = tfp.privacy.dp_query.no_privacy_query.NoPrivacySumQuery()
inner_query = tfp.NoPrivacySumQuery()
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery(
arity=arity, inner_query=inner_query)
# If the inner `DifferentiallyPrivateFactory` uses `NoPrivacyQuery`, then
......
......@@ -54,8 +54,7 @@ class TreeAggregationFactoryComputationTest(test_case.TestCase,
self.assertIsInstance(process, aggregation_process.AggregationProcess)
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery(
arity=arity,
inner_query=tfp.privacy.dp_query.no_privacy_query.NoPrivacySumQuery())
arity=arity, inner_query=tfp.NoPrivacySumQuery())
query_state = query.initial_global_state()
query_state_type = type_conversions.type_from_tensors(query_state)
query_metrics_type = type_conversions.type_from_tensors(
......@@ -169,8 +168,7 @@ class TreeAggregationFactoryComputationTest(test_case.TestCase,
query = tfp.privacy.dp_query.tree_aggregation_query.TreeRangeSumQuery(
arity=arity,
inner_query=tfp.privacy.dp_query.distributed_discrete_gaussian_query
.DistributedDiscreteGaussianSumQuery(
inner_query=tfp.DistributedDiscreteGaussianSumQuery(
l2_norm_bound=1.0, local_stddev=1.0))
query_state = query.initial_global_state()
query_state_type = type_conversions.type_from_tensors(query_state)
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册