提交 147a6c60 编辑于 作者: Jakub Konecny's avatar Jakub Konecny 提交者: tensorflow-copybara
浏览文件

Modifies `federated_secure_sum` to support broadcasting of its `bitwidth` argument.

PiperOrigin-RevId: 345150315
上级 79df23b0
......@@ -70,7 +70,7 @@ def federated_aggregate(value, zero, accumulate, merge, report):
using the multi-stage process described above.
Raises:
TypeError: if the arguments are not of the types specified above.
TypeError: If the arguments are not of the types specified above.
"""
factory = intrinsic_factory.IntrinsicFactory(context_stack_impl.context_stack)
return factory.federated_aggregate(value, zero, accumulate, merge, report)
......@@ -120,7 +120,7 @@ def federated_mean(value, weight=None):
member constituents contributed by all clients are equally weighted).
Raises:
TypeError: if `value` is not a federated TFF value placed at `tff.CLIENTS`,
TypeError: If `value` is not a federated TFF value placed at `tff.CLIENTS`,
or if `weight` is not a federated integer or a floating-point tensor with
the matching placement.
"""
......@@ -141,7 +141,7 @@ def federated_broadcast(value):
type placed at the `tff.CLIENTS`, all members of which are equal.
Raises:
TypeError: if the argument is not a federated TFF value placed at the
TypeError: If the argument is not a federated TFF value placed at the
`tff.SERVER`.
"""
factory = intrinsic_factory.IntrinsicFactory(context_stack_impl.context_stack)
......@@ -159,7 +159,7 @@ def federated_collect(value):
the `tff.SERVER`.
Raises:
TypeError: if the argument is not a federated TFF value placed at
TypeError: If the argument is not a federated TFF value placed at
`tff.CLIENTS`.
"""
factory = intrinsic_factory.IntrinsicFactory(context_stack_impl.context_stack)
......@@ -240,7 +240,7 @@ def federated_reduce(value, zero, op):
item.
Raises:
TypeError: if the arguments are not of the types specified above.
TypeError: If the arguments are not of the types specified above.
"""
factory = intrinsic_factory.IntrinsicFactory(context_stack_impl.context_stack)
return factory.federated_reduce(value, zero, op)
......@@ -283,15 +283,16 @@ def federated_secure_sum(value, bitwidth):
Args:
value: An integer value of a TFF federated type placed at the `tff.CLIENTS`,
in the range [0, 2^bitwidth - 1].
bitwidth: An integer or nested structure of integers. For each tensor in
`value`, `bitwidth` must contain exactly one corresponding integer.
bitwidth: An integer or nested structure of integers matching the structure
of `value`. If integer `bitwidth` is used with a nested `value`, the same
integer is used for each tensor in `value`.
Returns:
A representation of the sum of the member constituents of `value` placed
on the `tff.SERVER`.
Raises:
TypeError: if the argument is not a federated TFF value placed at
TypeError: If the argument is not a federated TFF value placed at
`tff.CLIENTS`.
"""
factory = intrinsic_factory.IntrinsicFactory(context_stack_impl.context_stack)
......@@ -312,7 +313,7 @@ def federated_sum(value):
on the `tff.SERVER`.
Raises:
TypeError: if the argument is not a federated TFF value placed at
TypeError: If the argument is not a federated TFF value placed at
`tff.CLIENTS`.
"""
factory = intrinsic_factory.IntrinsicFactory(context_stack_impl.context_stack)
......@@ -352,7 +353,7 @@ def federated_zip(value):
corresponding member components of the elements of `value`.
Raises:
TypeError: if the argument is not a named tuple of federated values with the
TypeError: If the argument is not a named tuple of federated values with the
same placement.
"""
factory = intrinsic_factory.IntrinsicFactory(context_stack_impl.context_stack)
......
......@@ -151,6 +151,7 @@ py_library(
":value_impl",
":value_utils",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:value_base",
"//tensorflow_federated/python/core/impl/compiler:building_block_factory",
......
......@@ -14,6 +14,7 @@
"""A factory of intrinsics for use in composing federated computations."""
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import value_base
from tensorflow_federated.python.core.impl import value_impl
......@@ -383,9 +384,9 @@ class IntrinsicFactory(object):
placement_literals.CLIENTS,
'value to be summed')
type_analysis.check_is_structure_of_integers(value.type_signature)
bitwidth = value_impl.to_value(bitwidth, None, self._context_stack)
bitwidth_value = value_impl.to_value(bitwidth, None, self._context_stack)
value_member_type = value.type_signature.member
bitwidth_type = bitwidth.type_signature
bitwidth_type = bitwidth_value.type_signature
if not type_analysis.is_valid_bitwidth_type_for_value_type(
bitwidth_type, value_member_type):
raise TypeError(
......@@ -393,9 +394,14 @@ class IntrinsicFactory(object):
'the structure of `value`, with one integer bitwidth per tensor in '
'`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format(
value_member_type, bitwidth_type))
if bitwidth_type.is_tensor() and value_member_type.is_struct():
bitwidth_value = value_impl.to_value(
structure.map_structure(lambda _: bitwidth, value_member_type), None,
self._context_stack)
value = value_impl.ValueImpl.get_comp(value)
bitwidth = value_impl.ValueImpl.get_comp(bitwidth)
comp = building_block_factory.create_federated_secure_sum(value, bitwidth)
bitwidth_value = value_impl.ValueImpl.get_comp(bitwidth_value)
comp = building_block_factory.create_federated_secure_sum(
value, bitwidth_value)
comp = self._bind_comp_as_reference(comp)
return value_impl.ValueImpl(comp, self._context_stack)
......
......@@ -47,6 +47,15 @@ class FederatedSecureSumTest(absltest.TestCase):
self.assertEqual(intrinsic.type_signature.compact_representation(),
'<int32,<int32,int32>>@SERVER')
def test_type_signature_with_structure_of_ints_scalar_bitwidth(self):
value = intrinsics.federated_value([1, [1, 1]], placement_literals.CLIENTS)
bitwidth = 8
intrinsic = intrinsics.federated_secure_sum(value, bitwidth)
self.assertEqual(intrinsic.type_signature.compact_representation(),
'<int32,<int32,int32>>@SERVER')
def test_type_signature_with_one_tensor_and_bitwidth(self):
value = intrinsics.federated_value(
np.ndarray(shape=(5, 37), dtype=np.int16), placement_literals.CLIENTS)
......@@ -84,7 +93,7 @@ class FederatedSecureSumTest(absltest.TestCase):
def test_raises_type_error_with_different_structures(self):
value = intrinsics.federated_value([1, [1, 1]], placement_literals.CLIENTS)
bitwidth = 8
bitwidth = [8, 4, 2]
with self.assertRaises(TypeError):
intrinsics.federated_secure_sum(value, bitwidth)
......
......@@ -347,10 +347,9 @@ def is_valid_bitwidth_type_for_value_type(
py_typecheck.check_type(bitwidth_type, computation_types.Type)
py_typecheck.check_type(value_type, computation_types.Type)
if value_type.is_tensor() and bitwidth_type.is_tensor():
# Here, `value_type` refers to a tensor. Rather than check that
# `bitwidth_type` is exactly the same, we check that it is a single integer,
# since we want a single bitwidth integer per tensor.
if bitwidth_type.is_tensor():
# This condition applies to both `value_type` being a tensor or structure,
# as the same integer bitwidth can be used for all values in the structure.
return bitwidth_type.dtype.is_integer and (
bitwidth_type.shape.num_elements() == 1)
elif value_type.is_struct() and bitwidth_type.is_struct():
......
......@@ -305,6 +305,9 @@ class IsValidBitwidthTypeForValueType(parameterized.TestCase):
('different kinds of ints',
computation_types.TensorType(tf.int32),
computation_types.TensorType(tf.int8)),
('single int_for_struct',
computation_types.TensorType(tf.int32),
computation_types.StructType([tf.int32, tf.int32])),
)
# pyformat: enable
def test_returns_true(self, bitwidth_type, value_type):
......@@ -314,9 +317,6 @@ class IsValidBitwidthTypeForValueType(parameterized.TestCase):
# pyformat: disable
@parameterized.named_parameters(
('single int_for_struct',
computation_types.TensorType(tf.int32),
computation_types.StructType([tf.int32, tf.int32])),
('miscounted struct',
computation_types.StructType([tf.int32, tf.int32, tf.int32]),
computation_types.StructType([tf.int32, tf.int32])),
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册