提交 b19d3800 编辑于 作者: Shanshan Wu's avatar Shanshan Wu 提交者: tensorflow-copybara
浏览文件

Add `measurements` to `UnweightedReservoirSamplingFactory`.

The returned `measurements` has the same structure as the client value, and every leaf node is a `tf.int64` scalar tensor counting the number of clients having non-finite value in that leaf.

PiperOrigin-RevId: 414013617
上级 d87edcc6
......@@ -37,6 +37,19 @@ from tensorflow_federated.python.core.templates import measured_process
SEED_SENTINEL = -1
def _is_tensor_or_structure_of_tensors(
value_type: computation_types.Type) -> bool:
"""Return True if `value_type` is a TensorType or structure of TensorTypes."""
# TODO(b/181365504): relax this to allow `StructType` once a `Struct` can be
# returned from `tf.function` decorated methods.
def is_tensor_or_struct_with_py_type(t: computation_types.Type) -> bool:
return t.is_tensor() or t.is_struct_with_python()
return type_analysis.contains_only(value_type,
is_tensor_or_struct_with_py_type)
def _build_reservoir_type(
sample_value_type: computation_types.Type) -> computation_types.Type:
"""Create the TFF type for the reservoir's state.
......@@ -67,14 +80,7 @@ def _build_reservoir_type(
which has an unknown size. This will be used to concatenate samples and
store them in the reservoir.
"""
# TODO(b/181365504): relax this to allow `StructType` once a `Struct` can be
# returned from `tf.function` decorated methods.
def is_tensor_or_struct_with_py_type(t: computation_types.Type) -> bool:
return t.is_tensor() or t.is_struct_with_python()
if not type_analysis.contains_only(sample_value_type,
is_tensor_or_struct_with_py_type):
if not _is_tensor_or_structure_of_tensors(sample_value_type):
raise TypeError('Cannot create a reservoir for type structure. Sample type '
'must only contain `TensorType` or `StructWithPythonType`, '
f'got a {sample_value_type!r}.')
......@@ -299,15 +305,83 @@ def _build_finalize_sample_computation(
return finalize_samples
def _build_check_non_finite_leaves_computation(
value_type: computation_types.Type) -> computation_base.Computation:
"""Builds the computation for checking non-finite leaves in the client value.
Args:
value_type: The `tff.typs.Type` of the client value. Must only contain
`tff.types.TensorType`s or `tff.types.StructWithPythonType`s.
Returns:
A TFF computation (constructed by the `tff.tf_computation` decoration) that
takes in a client-side value as input, and returns a value of the same
structure as the client value, with all the leaves being a `tf.int64` 0/1
scalar tensor indicating whether the corresponding leaf tensor in the input
client value has any non-finite (`NaN` or `Inf`) value.
Raises:
TypeError: if `value_type` contains types other than `tff.types.TensorType`
or `tff.types.StructWithPythonType`.
"""
if not _is_tensor_or_structure_of_tensors(value_type):
raise TypeError(
'Cannot check non-finite leaves for the client value. Expected the '
'client value type to only contain `TensorType`s or '
f'`StructWithPythonType`s, got a {value_type!r}.')
@computations.tf_computation(value_type)
@tf.function
def check_non_finite_leaves(client_value):
def is_non_finite(leaf_tensor: tf.Tensor) -> tf.Tensor:
"""Returns True if `leaf_tensor` has at least one non-finite value."""
# `tf.math.is_finite` only works for tensors of float dtype. This is
# because the type of `np.nan` or `np.inf` is float, so it only exists in
# tensors of float dtype.
if leaf_tensor.dtype.is_floating:
# TODO(b/201213657): replaces `tf.math.is_finite` by a memory-efficient
# way of checking finite tensors.
return tf.math.logical_not(
tf.reduce_all(tf.math.is_finite(leaf_tensor)))
return tf.constant(False)
if isinstance(client_value, tf.Tensor):
return tf.cast(is_non_finite(client_value), tf.int64)
else:
# The returned structure is the same as `client_value`, but with all the
# leaves being an integer 0/1 scalar tensor indicating whether that leaf
# tensor has any non-finite value.
return tf.nest.map_structure(
lambda leaf_tensor: tf.cast(is_non_finite(leaf_tensor), tf.int64),
client_value)
return check_non_finite_leaves
class UnweightedReservoirSamplingFactory(factory.UnweightedAggregationFactory):
"""An `UnweightedAggregationFactory` for reservoir sampling values.
The created `tff.templates.AggregationProcess` samples values placed at
`CLIENTS`, and outputs the sample placed at `SERVER`.
The process has empty `state`. The `measurements` of this factory include
the number of non-finite (`NaN` or `Inf` values) for each leaf in the value
structure.
The process has empty `state`. The `measurements` of this factory counts the
number of non-finite (`NaN` or `Inf` values) leaves in the client values
*before* sampling. Specifically, the returned `measurements` has the same
structure as the client value, and every leaf node is a `tf.int64` scalar
tensor counting the number of clients having non-finite value in that leaf.
For example, suppose we are aggregating from three clients:
```
client_value_1 = collections.OrderedDict(a=[1.0, 2.0], b=[1.0, np.nan])
client_value_2 = collections.OrderedDict(a=[np.nan, np.inf], b=[1.0, 2.0])
client_value_3 = collections.OrderedDict(a=[1.0, 2.0], b=[np.inf, np.nan])
```
Then `measurements` will be:
```
collections.OrderedDict(a=tf.constant(1, dtype=int64),
b=tf.constant(2, dtype=int64)
```
For more about reservoir sampling see
https://en.wikipedia.org/wiki/Reservoir_sampling.
......@@ -334,6 +408,9 @@ class UnweightedReservoirSamplingFactory(factory.UnweightedAggregationFactory):
def next_fn(unused_state, value):
# Empty tuple is the `None` of TFF.
empty_tuple = intrinsics.federated_value((), placements.SERVER)
non_finite_leaves_counts = intrinsics.federated_sum(
intrinsics.federated_map(
_build_check_non_finite_leaves_computation(value_type), value))
initial_reservoir = _build_initial_sample_reservoir(value_type)
sample_value = _build_sample_value_computation(value_type,
self._sample_size)
......@@ -347,6 +424,8 @@ class UnweightedReservoirSamplingFactory(factory.UnweightedAggregationFactory):
merge=merge_samples,
report=finalize_sample)
return measured_process.MeasuredProcessOutput(
state=empty_tuple, result=samples, measurements=empty_tuple)
state=empty_tuple,
result=samples,
measurements=non_finite_leaves_counts)
return aggregation_process.AggregationProcess(init_fn, next_fn)
......@@ -15,6 +15,7 @@
import collections
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.aggregators import sampling
......@@ -450,6 +451,56 @@ class BuildFinalizeSampleTest(test_case.TestCase):
self.assertAllEqual(finalize_computation(reservoir), test_samples)
class BuildCheckNonFiniteLeavesComputationTest(test_case.TestCase,
parameterized.TestCase):
@parameterized.named_parameters(('float32_nan', tf.float32, np.nan, True),
('bfloat16_inf', tf.bfloat16, np.inf, True),
('half_nan', tf.half, np.nan, True),
('float64_inf', tf.float64, np.inf, True),
('int32_finite', tf.int32, 1, False),
('bool_finite', tf.bool, False, False))
def test_scalar(self, dtype, value, is_non_finite):
computation = sampling._build_check_non_finite_leaves_computation(
TensorType(dtype))
result = computation(value)
expected_result = tf.constant(is_non_finite, dtype=tf.int64)
self.assertEqual(result, expected_result)
def test_structure(self):
value_type = computation_types.to_type(
collections.OrderedDict(
a=TensorType(tf.int32),
b=[TensorType(tf.float32, [3]),
TensorType(tf.bool)],
c=collections.OrderedDict(d=TensorType(tf.float64, [2, 2]))))
computation = sampling._build_check_non_finite_leaves_computation(
value_type)
value = collections.OrderedDict(
a=1,
b=[[1.0, np.nan, np.inf], True],
c=collections.OrderedDict(d=[[np.inf, 2.0], [3.0, 4.0]]))
result = computation(value)
expected_result = collections.OrderedDict(
a=tf.constant(0, dtype=tf.int64),
b=[tf.constant(1, dtype=tf.int64),
tf.constant(0, dtype=tf.int64)],
c=collections.OrderedDict(d=tf.constant(1, dtype=tf.int64)))
self.assertEqual(result, expected_result)
def test_fails_with_non_tensor_type(self):
with self.assertRaisesRegex(TypeError, 'only contain `TensorType`s'):
sampling._build_check_non_finite_leaves_computation(
SequenceType(TensorType(tf.int32)))
with self.assertRaisesRegex(TypeError, 'only contain `TensorType`s'):
sampling._build_check_non_finite_leaves_computation(
computation_types.to_type(
collections.OrderedDict(
a=TensorType(tf.float32, [3]),
b=[SequenceType(TensorType(tf.int32)),
TensorType(tf.bool)])))
class UnweightedReservoirSamplingFactoryTest(test_case.TestCase,
parameterized.TestCase):
......@@ -503,6 +554,38 @@ class UnweightedReservoirSamplingFactoryTest(test_case.TestCase,
with self.assertRaises(TypeError):
sampling.UnweightedReservoirSamplingFactory(sample_size='5')
def test_measurements_scalar_value(self):
process = sampling.UnweightedReservoirSamplingFactory(sample_size=1).create(
computation_types.to_type(tf.float32))
state = process.initialize()
output = process.next(state, [1.0, np.nan, np.inf, 2.0, 3.0])
# Two clients' values are non-infinte.
self.assertEqual(output.measurements, tf.constant(2, dtype=tf.int64))
def test_measurements_structure_value(self):
process = sampling.UnweightedReservoirSamplingFactory(sample_size=1).create(
computation_types.to_type(
collections.OrderedDict(
a=TensorType(tf.float32),
b=[TensorType(tf.float32, [2, 2]),
TensorType(tf.bool)])))
state = process.initialize()
output = process.next(state, [
collections.OrderedDict(
a=1.0, b=[[[1.0, np.nan], [np.inf, 4.0]], True]),
collections.OrderedDict(a=2.0, b=[[[1.0, 2.0], [3.0, 4.0]], False]),
collections.OrderedDict(
a=np.inf, b=[[[np.nan, 2.0], [3.0, 4.0]], True])
])
self.assertEqual(
output.measurements,
collections.OrderedDict(
# One client has non-infinte tensors for this leaf node.
a=tf.constant(1, dtype=tf.int64),
# Two clients have non-infinte tensors for this leaf node.
b=[tf.constant(2, dtype=tf.int64),
tf.constant(0, dtype=tf.int64)]))
if __name__ == '__main__':
execution_contexts.set_local_python_execution_context()
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册