提交 edda7038 编辑于 作者: Jakub Konecny's avatar Jakub Konecny 提交者: tensorflow-copybara
浏览文件

Splits `AggregationProcessFactory` to weighted and non-weighted variants.

New classes `WeightedAggregationFactory` and `NonWeightedAggregationFactory` replace the previous `AggregationProcessFactory` in the API

This change makes explicit difference between ways to create `tff.templates.AggreagationProcess` objects with input type signatures of its `next` function `<state, value>` and `<state, value, weight>`.

Also updates `ClippingFactory` and `ZeroingFactory` to implement both newly created classes.

`MeanFactory` now only implements `WeightedAggregationFactory` class.

PiperOrigin-RevId: 341988708
上级 ab0918c0
......@@ -64,11 +64,11 @@ py_library(
":factory",
":sum_factory",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
"//tensorflow_federated/python/core/api:placements",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/core/templates:estimation_process",
"//tensorflow_federated/python/core/templates:measured_process",
......@@ -83,6 +83,7 @@ py_test(
deps = [
":clipping_factory",
":mean_factory",
":sum_factory",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
......
......@@ -16,7 +16,8 @@
from tensorflow_federated.python.aggregators.clipping_factory import ClippingFactory
from tensorflow_federated.python.aggregators.clipping_factory import ZeroingFactory
from tensorflow_federated.python.aggregators.dp_factory import DifferentiallyPrivateFactory
from tensorflow_federated.python.aggregators.factory import AggregationProcessFactory
from tensorflow_federated.python.aggregators.factory import UnweightedAggregationFactory
from tensorflow_federated.python.aggregators.factory import WeightedAggregationFactory
from tensorflow_federated.python.aggregators.mean_factory import MeanFactory
from tensorflow_federated.python.aggregators.quantile_estimation import PrivateQuantileEstimationProcess
from tensorflow_federated.python.aggregators.sum_factory import SumFactory
......@@ -51,7 +51,7 @@ def adaptive_zeroing_mean(
increment: float,
learning_rate: float,
norm_order: bool,
no_nan_mean: bool = False) -> factory.AggregationProcessFactory:
no_nan_mean: bool = False) -> factory.WeightedAggregationFactory:
"""Creates a factory for mean with adaptive zeroing.
Estimates value at quantile `Z` of value norm distribution and zeroes out
......@@ -96,7 +96,7 @@ def adaptive_zeroing_clipping_mean(
initial_clipping_quantile_estimate: float,
target_clipping_quantile: float,
clipping_learning_rate: float,
no_nan_mean: bool = False) -> factory.AggregationProcessFactory:
no_nan_mean: bool = False) -> factory.WeightedAggregationFactory:
"""Makes a factory for mean with adaptive zeroing and clipping.
Estimates value at quantile `Z` of value norm distribution and zeroes out
......
......@@ -34,9 +34,11 @@ class ClippingCompositionsTest(test_case.TestCase):
increment=1.0,
learning_rate=np.log(4.0),
norm_order=np.inf)
self.assertIsInstance(factory_, factory.AggregationProcessFactory)
self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
process = factory_.create(computation_types.to_type(tf.float32))
process = factory_.create_weighted(
value_type=computation_types.to_type(tf.float32),
weight_type=computation_types.to_type(tf.float32))
state = process.initialize()
......@@ -67,9 +69,11 @@ class ClippingCompositionsTest(test_case.TestCase):
initial_clipping_quantile_estimate=2.0,
target_clipping_quantile=0.0,
clipping_learning_rate=np.log(4.0))
self.assertIsInstance(factory_, factory.AggregationProcessFactory)
self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
process = factory_.create(computation_types.to_type(tf.float32))
process = factory_.create_weighted(
value_type=computation_types.to_type(tf.float32),
weight_type=computation_types.to_type(tf.float32))
state = process.initialize()
......
......@@ -22,11 +22,11 @@ import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import sum_factory
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
from tensorflow_federated.python.core.api import placements
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import estimation_process
from tensorflow_federated.python.core.templates import measured_process
......@@ -34,9 +34,12 @@ from tensorflow_federated.python.core.templates import measured_process
NORM_TF_TYPE = tf.float32
COUNT_TF_TYPE = tf.int32
_InnerFactoryType = Union[factory.UnweightedAggregationFactory,
factory.WeightedAggregationFactory]
def _constant_process(value):
"""Creates an `EstimationProcess` that returns a constant value."""
"""Creates an `EstimationProcess` that reports a constant value."""
init_fn = computations.federated_computation(
lambda: intrinsics.federated_value((), placements.SERVER))
next_fn = computations.federated_computation(
......@@ -48,6 +51,10 @@ def _constant_process(value):
return estimation_process.EstimationProcess(init_fn, next_fn, report_fn)
def _contains_non_float_dtype(type_spec):
return type_spec.is_tensor() and not type_spec.dtype.is_floating
def _check_norm_process(norm_process: estimation_process.EstimationProcess,
name: str):
"""Checks type properties for norm_process.
......@@ -70,8 +77,9 @@ def _check_norm_process(norm_process: estimation_process.EstimationProcess,
norm_type_at_clients = computation_types.at_clients(NORM_TF_TYPE)
if not next_parameter_type[1].is_assignable_from(norm_type_at_clients):
raise TypeError(f'Second argument of `{name}.next` must be assignable from '
f'NORM_TF_TYPE@CLIENTS but found {next_parameter_type[1]}')
raise TypeError(
f'Second argument of `{name}.next` must be assignable from '
f'{norm_type_at_clients} but found {next_parameter_type[1]}')
next_result_type = norm_process.next.type_signature.result
if not norm_process.state_type.is_assignable_from(next_result_type):
......@@ -83,11 +91,12 @@ def _check_norm_process(norm_process: estimation_process.EstimationProcess,
norm_type_at_server = computation_types.at_server(NORM_TF_TYPE)
if not norm_type_at_server.is_assignable_from(result_type):
raise TypeError(f'Result type of `{name}.report` must be assignable to '
f'NORM_TF_TYPE@SERVER but found {result_type}.')
f'{norm_type_at_server} but found {result_type}.')
class ClippingFactory(factory.AggregationProcessFactory):
"""`AggregationProcessFactory` for clipping large values.
class ClippingFactory(factory.UnweightedAggregationFactory,
factory.WeightedAggregationFactory):
"""`AggregationProcess` factory for clipping large values.
The created `tff.templates.AggregationProcess` projects the values onto an
L2 ball (also referred to as "clipping") with norm determined by the provided
......@@ -107,7 +116,7 @@ class ClippingFactory(factory.AggregationProcessFactory):
def __init__(self, clipping_norm: Union[float,
estimation_process.EstimationProcess],
inner_agg_factory: factory.AggregationProcessFactory):
inner_agg_factory: _InnerFactoryType):
"""Initializes `ClippingFactory`.
Args:
......@@ -117,8 +126,7 @@ class ClippingFactory(factory.AggregationProcessFactory):
inner_agg_factory: A factory specifying the type of aggregation to be done
after clipping.
"""
py_typecheck.check_type(inner_agg_factory,
factory.AggregationProcessFactory)
py_typecheck.check_type(inner_agg_factory, _InnerFactoryType.__args__)
self._inner_agg_factory = inner_agg_factory
py_typecheck.check_type(clipping_norm,
......@@ -133,54 +141,91 @@ class ClippingFactory(factory.AggregationProcessFactory):
# we will make this customizable to allow DP measurements.
self._clipped_count_agg_factory = sum_factory.SumFactory()
def create(
def create_unweighted(
self,
value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
py_typecheck.check_type(value_type, factory.ValueType.__args__)
if type_analysis.contains(value_type, predicate=_contains_non_float_dtype):
raise TypeError(f'All values in provided value_type must be of floating '
f'dtype. Provided value_type: {value_type}')
if not all([t.dtype.is_floating for t in structure.flatten(value_type)]):
inner_agg_process = self._inner_agg_factory.create_unweighted(value_type)
clipped_count_agg_process = (
self._clipped_count_agg_factory.create_unweighted(
computation_types.to_type(COUNT_TF_TYPE)))
init_fn = self._create_init_fn(inner_agg_process.initialize,
clipped_count_agg_process.initialize)
next_fn = self._create_next_fn(inner_agg_process.next,
clipped_count_agg_process.next,
init_fn.type_signature.result)
return aggregation_process.AggregationProcess(init_fn, next_fn)
def create_weighted(
self, value_type: factory.ValueType,
weight_type: factory.ValueType) -> aggregation_process.AggregationProcess:
py_typecheck.check_type(value_type, factory.ValueType.__args__)
py_typecheck.check_type(weight_type, factory.ValueType.__args__)
if type_analysis.contains(value_type, predicate=_contains_non_float_dtype):
raise TypeError(f'All values in provided value_type must be of floating '
f'dtype. Provided value_type: {value_type}')
inner_agg_process = self._inner_agg_factory.create(value_type)
inner_agg_process = self._inner_agg_factory.create_weighted(
value_type, weight_type)
clipped_count_agg_process = (
self._clipped_count_agg_factory.create_unweighted(
computation_types.to_type(COUNT_TF_TYPE)))
count_type = computation_types.to_type(COUNT_TF_TYPE)
clipped_count_agg_process = self._clipped_count_agg_factory.create(
count_type)
init_fn = self._create_init_fn(inner_agg_process.initialize,
clipped_count_agg_process.initialize)
next_fn = self._create_next_fn(inner_agg_process.next,
clipped_count_agg_process.next,
init_fn.type_signature.result)
return aggregation_process.AggregationProcess(init_fn, next_fn)
def _create_init_fn(self, inner_agg_initialize, clipped_count_agg_initialize):
@computations.federated_computation()
def init_fn():
return intrinsics.federated_zip(
collections.OrderedDict(
clipping_norm=self._clipping_norm_process.initialize(),
inner_agg=inner_agg_process.initialize(),
clipped_count_agg=clipped_count_agg_process.initialize()))
inner_agg=inner_agg_initialize(),
clipped_count_agg=clipped_count_agg_initialize()))
return init_fn
@computations.tf_computation(value_type, NORM_TF_TYPE)
def clip(value, clipping_norm):
def _create_next_fn(self, inner_agg_next, clipped_count_agg_next, state_type):
@computations.tf_computation(
inner_agg_next.type_signature.parameter[1].member, NORM_TF_TYPE)
def clip_fn(value, clipping_norm):
clipped_value_as_list, global_norm = tf.clip_by_global_norm(
tf.nest.flatten(value), clipping_norm)
clipped_value = tf.nest.pack_sequence_as(value, clipped_value_as_list)
was_clipped = tf.cast((global_norm > clipping_norm), COUNT_TF_TYPE)
return clipped_value, global_norm, was_clipped
@computations.federated_computation(
init_fn.type_signature.result, computation_types.at_clients(value_type),
computation_types.at_clients(tf.float32))
def next_fn(state, value, weight):
def next_fn_impl(state, value, weight=None):
clipping_norm_state, agg_state, clipped_count_state = state
clipping_norm = self._clipping_norm_process.report(clipping_norm_state)
clipped_value, global_norm, was_clipped = intrinsics.federated_map(
clip, (value, intrinsics.federated_broadcast(clipping_norm)))
clip_fn, (value, intrinsics.federated_broadcast(clipping_norm)))
new_clipping_norm_state = self._clipping_norm_process.next(
clipping_norm_state, global_norm)
agg_output = inner_agg_process.next(agg_state, clipped_value, weight)
clipped_count_output = clipped_count_agg_process.next(
clipped_count_state, was_clipped)
if weight is None:
agg_output = inner_agg_next(agg_state, clipped_value)
else:
agg_output = inner_agg_next(agg_state, clipped_value, weight)
clipped_count_output = clipped_count_agg_next(clipped_count_state,
was_clipped)
new_state = collections.OrderedDict(
clipping_norm=new_clipping_norm_state,
......@@ -196,11 +241,27 @@ class ClippingFactory(factory.AggregationProcessFactory):
result=agg_output.result,
measurements=intrinsics.federated_zip(measurements))
return aggregation_process.AggregationProcess(init_fn, next_fn)
if len(inner_agg_next.type_signature.parameter) == 2:
@computations.federated_computation(
state_type, inner_agg_next.type_signature.parameter[1])
def next_fn(state, value):
return next_fn_impl(state, value)
else:
assert len(inner_agg_next.type_signature.parameter) == 3
@computations.federated_computation(
state_type, inner_agg_next.type_signature.parameter[1],
inner_agg_next.type_signature.parameter[2])
def next_fn(state, value, weight):
return next_fn_impl(state, value, weight)
return next_fn
class ZeroingFactory(factory.AggregationProcessFactory):
"""`AggregationProcessFactory` for zeroing large values.
class ZeroingFactory(factory.UnweightedAggregationFactory,
factory.WeightedAggregationFactory):
"""`AggregationProcess` factory for zeroing large values.
The created `tff.templates.AggregationProcess` zeroes out any values whose
norm is greater than that determined by the provided `zeroing_norm`, before
......@@ -219,7 +280,7 @@ class ZeroingFactory(factory.AggregationProcessFactory):
def __init__(self,
zeroing_norm: Union[float, estimation_process.EstimationProcess],
inner_agg_factory: factory.AggregationProcessFactory,
inner_agg_factory: _InnerFactoryType,
norm_order: float = 2.0):
"""Initializes a ZeroingFactory.
......@@ -233,8 +294,7 @@ class ZeroingFactory(factory.AggregationProcessFactory):
after zeroing.
norm_order: A float for the order of the norm. Must be 1, 2, or np.inf.
"""
py_typecheck.check_type(inner_agg_factory,
factory.AggregationProcessFactory)
py_typecheck.check_type(inner_agg_factory, _InnerFactoryType.__args__)
self._inner_agg_factory = inner_agg_factory
py_typecheck.check_type(zeroing_norm,
......@@ -254,32 +314,69 @@ class ZeroingFactory(factory.AggregationProcessFactory):
# we will make this customizable to allow DP measurements.
self._zeroed_count_agg_factory = sum_factory.SumFactory()
def create(
def create_unweighted(
self,
value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
py_typecheck.check_type(value_type, factory.ValueType.__args__)
# This could perhaps be relaxed if we want to zero out ints for example.
if type_analysis.contains(value_type, predicate=_contains_non_float_dtype):
raise TypeError(f'All values in provided value_type must be of floating '
f'dtype. Provided value_type: {value_type}')
inner_agg_process = self._inner_agg_factory.create_unweighted(value_type)
zeroed_count_agg_process = (
self._zeroed_count_agg_factory.create_unweighted(
computation_types.to_type(COUNT_TF_TYPE)))
init_fn = self._create_init_fn(inner_agg_process.initialize,
zeroed_count_agg_process.initialize)
next_fn = self._create_next_fn(inner_agg_process.next,
zeroed_count_agg_process.next,
init_fn.type_signature.result)
return aggregation_process.AggregationProcess(init_fn, next_fn)
def create_weighted(
self, value_type: factory.ValueType,
weight_type: factory.ValueType) -> aggregation_process.AggregationProcess:
py_typecheck.check_type(value_type, factory.ValueType.__args__)
py_typecheck.check_type(weight_type, factory.ValueType.__args__)
# This could perhaps be relaxed if we want to zero out ints for example.
if not all([t.dtype.is_floating for t in structure.flatten(value_type)]):
if type_analysis.contains(value_type, predicate=_contains_non_float_dtype):
raise TypeError(f'All values in provided value_type must be of floating '
f'dtype. Provided value_type: {value_type}')
inner_agg_process = self._inner_agg_factory.create(value_type)
inner_agg_process = self._inner_agg_factory.create_weighted(
value_type, weight_type)
zeroed_count_agg_process = (
self._zeroed_count_agg_factory.create_unweighted(
computation_types.to_type(COUNT_TF_TYPE)))
init_fn = self._create_init_fn(inner_agg_process.initialize,
zeroed_count_agg_process.initialize)
next_fn = self._create_next_fn(inner_agg_process.next,
zeroed_count_agg_process.next,
init_fn.type_signature.result)
return aggregation_process.AggregationProcess(init_fn, next_fn)
count_type = computation_types.to_type(COUNT_TF_TYPE)
zeroed_count_agg_process = self._zeroed_count_agg_factory.create(count_type)
def _create_init_fn(self, inner_agg_initialize, zeroed_count_agg_initialize):
@computations.federated_computation()
def init_fn():
return intrinsics.federated_zip(
collections.OrderedDict(
zeroing_norm=self._zeroing_norm_process.initialize(),
inner_agg=inner_agg_process.initialize(),
zeroed_count_agg=zeroed_count_agg_process.initialize()))
inner_agg=inner_agg_initialize(),
zeroed_count_agg=zeroed_count_agg_initialize()))
return init_fn
def _create_next_fn(self, inner_agg_next, zeroed_count_agg_next, state_type):
@computations.tf_computation(value_type, NORM_TF_TYPE)
def zero(value, zeroing_norm):
@computations.tf_computation(
inner_agg_next.type_signature.parameter[1].member, NORM_TF_TYPE)
def zero_fn(value, zeroing_norm):
if self._norm_order == 1.0:
norm = _global_l1_norm(value)
elif self._norm_order == 2.0:
......@@ -294,22 +391,24 @@ class ZeroingFactory(factory.AggregationProcessFactory):
was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE)
return zeroed_value, norm, was_zeroed
@computations.federated_computation(
init_fn.type_signature.result, computation_types.at_clients(value_type),
computation_types.at_clients(tf.float32))
def next_fn(state, value, weight):
def next_fn_impl(state, value, weight=None):
zeroing_norm_state, agg_state, zeroed_count_state = state
zeroing_norm = self._zeroing_norm_process.report(zeroing_norm_state)
zeroed_value, norm, was_zeroed = intrinsics.federated_map(
zero, (value, intrinsics.federated_broadcast(zeroing_norm)))
zero_fn, (value, intrinsics.federated_broadcast(zeroing_norm)))
new_zeroing_norm_state = self._zeroing_norm_process.next(
zeroing_norm_state, norm)
agg_output = inner_agg_process.next(agg_state, zeroed_value, weight)
zeroed_count_output = zeroed_count_agg_process.next(
zeroed_count_state, was_zeroed)
if weight is None:
agg_output = inner_agg_next(agg_state, zeroed_value)
else:
agg_output = inner_agg_next(agg_state, zeroed_value, weight)
zeroed_count_output = zeroed_count_agg_next(zeroed_count_state,
was_zeroed)
new_state = collections.OrderedDict(
zeroing_norm=new_zeroing_norm_state,
......@@ -325,7 +424,23 @@ class ZeroingFactory(factory.AggregationProcessFactory):
result=agg_output.result,
measurements=intrinsics.federated_zip(measurements))
return aggregation_process.AggregationProcess(init_fn, next_fn)
if len(inner_agg_next.type_signature.parameter) == 2:
@computations.federated_computation(
state_type, inner_agg_next.type_signature.parameter[1])
def next_fn(state, value):
return next_fn_impl(state, value)
else:
assert len(inner_agg_next.type_signature.parameter) == 3
@computations.federated_computation(
state_type, inner_agg_next.type_signature.parameter[1],
inner_agg_next.type_signature.parameter[2])
def next_fn(state, value, weight):
return next_fn_impl(state, value, weight)
return next_fn
def _global_inf_norm(l):
......
......@@ -28,8 +28,8 @@ from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
class DifferentiallyPrivateFactory(factory.AggregationProcessFactory):
"""`AggregationProcessFactory` for tensorflow_privacy DPQueries.
class DifferentiallyPrivateFactory(factory.UnweightedAggregationFactory):
"""`UnweightedAggregationFactory` for tensorflow_privacy DPQueries.
The created `tff.templates.AggregationProcess` aggregates values placed at
`CLIENTS` according to the provided DPQuery, and outputs the result placed at
......@@ -56,22 +56,23 @@ class DifferentiallyPrivateFactory(factory.AggregationProcessFactory):
def __init__(self,
query: tfp.DPQuery,
record_aggregation_factory: Optional[
factory.AggregationProcessFactory] = None):
factory.UnweightedAggregationFactory] = None):
"""Initializes `DifferentiallyPrivateFactory`.
Args:
query: A `tfp.SumAggregationDPQuery` to perform private estimation.
record_aggregation_factory: A `tff.aggregators.AggregationProcessFactory`
to aggregate values after preprocessing by the `query`. If `None`,
defaults to `tff.aggregators.SumFactory`. The provided factory is
assumed to implement a sum, and to have the property that it does not
increase the sensitivity of the query-- typically this means that it
should not increase the l2 norm of the records when aggregating.
record_aggregation_factory: A
`tff.aggregators.UnweightedAggregationFactory` to aggregate values
after preprocessing by the `query`. If `None`, defaults to
`tff.aggregators.SumFactory`. The provided factory is assumed to
implement a sum, and to have the property that it does not increase
the sensitivity of the query - typically this means that it should not
increase the l2 norm of the records when aggregating.
Raises:
TypeError: If `query` is not an instance of `tfp.SumAggregationDPQuery` or
`record_aggregation_factory` is not an instance of
`tff.aggregators.AggregationProcessFactory`.
`tff.aggregators.UnweightedAggregationFactory`.
"""
py_typecheck.check_type(query, tfp.SumAggregationDPQuery)
self._query = query
......@@ -80,29 +81,12 @@ class DifferentiallyPrivateFactory(factory.AggregationProcessFactory):
record_aggregation_factory = sum_factory.SumFactory()
py_typecheck.check_type(record_aggregation_factory,
factory.AggregationProcessFactory)
factory.UnweightedAggregationFactory)
self._record_aggregation_factory = record_aggregation_factory
def create(
def create_unweighted(
self,
value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
"""Creates a `tff.aggregators.AggregationProcess` aggregating `value_type`.
The provided `value_type` is a non-federated `tff.Type` object, that is,
`value_type.is_federated()` should return `False`. Provided `value_type`
must be a `tff.TensorType` or a `tff.StructType`.
The returned `tff.aggregators.AggregationProcess` will be created for
aggregating values matching `value_type`. That is, its `next` method will
expect type `<S@SERVER, {value_type}@CLIENTS>`, where `S` is the unplaced
return type of its `initialize` method.
Args:
value_type: A `tff.Type` without placement.
Returns:
A `tff.templates.AggregationProcess`.
"""
py_typecheck.check_type(value_type, factory.ValueType.__args__)
query_initial_state_fn = computations.tf_computation(
......@@ -120,7 +104,7 @@ class DifferentiallyPrivateFactory(factory.AggregationProcessFactory):
derive_metrics = computations.tf_computation(self._query.derive_metrics,
query_state_type)
record_agg_process = self._record_aggregation_factory.create(
record_agg_process = self._record_aggregation_factory.create_unweighted(
query_record_type)
@computations.federated_computation()
......
......@@ -46,9 +46,9 @@ class DPFactoryComputationTest(test_case.TestCase, parameterized.TestCase):
def test_type_properties(self, value_type, inner_agg_factory):
agg_factory = dp_factory.DifferentiallyPrivateFactory(
_test_dp_query, inner_agg_factory)
self.assertIsInstance(agg_factory, factory.AggregationProcessFactory)
self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory)
value_type = computation_types.to_type(value_type)
process = agg_factory.create(value_type)
process = agg_factory.create_unweighted(value_type)
self.assertIsInstance(process, aggregation_process.AggregationProcess)
query_state = _test_dp_query.initial_global_state()
......@@ -100,7 +100,7 @@ class DPFactoryComputationTest(test_case.TestCase, parameterized.TestCase):
def test_incorrect_value_type_raises(self, bad_value_type):
agg_factory = dp_factory.DifferentiallyPrivateFactory(_test_dp_query)
with self.assertRaises(TypeError):
agg_factory.create(bad_value_type)
agg_factory.create_unweighted(bad_value_type)
class DPFactoryExecutionTest(test_case.TestCase):
......@@ -108,7 +108,7 @@ class DPFactoryExecutionTest(test_case.TestCase):
def test_simple_sum(self):
agg_factory = dp_factory.DifferentiallyPrivateFactory(_test_dp_query)
value_type = computation_types.to_type(tf.float32)
process = agg_factory.create(value_type)
process = agg_factory.create_unweighted(value_type)
# The test query has clip 1.0 and no noise, so this computes clipped sum.
......@@ -121,7 +121,7 @@ class DPFactoryExecutionTest(test_case.TestCase):
def test_structure_sum(self):
agg_factory = dp_factory.DifferentiallyPrivateFactory(_test_dp_query)
value_type = computation_types.to_type([tf.float32, tf.float32])
process = agg_factory.create(value_type)
process = agg_factory.create