Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
147a6c60
提交
147a6c60
编辑于
12月 01, 2020
作者:
Jakub Konecny
提交者:
tensorflow-copybara
12月 01, 2020
浏览文件
Modifies `federated_secure_sum` to support broadcasting of its `bitwidth` argument.
PiperOrigin-RevId: 345150315
上级
79df23b0
变更
6
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/core/api/intrinsics.py
浏览文件 @
147a6c60
...
...
@@ -70,7 +70,7 @@ def federated_aggregate(value, zero, accumulate, merge, report):
using the multi-stage process described above.
Raises:
TypeError:
i
f the arguments are not of the types specified above.
TypeError:
I
f the arguments are not of the types specified above.
"""
factory
=
intrinsic_factory
.
IntrinsicFactory
(
context_stack_impl
.
context_stack
)
return
factory
.
federated_aggregate
(
value
,
zero
,
accumulate
,
merge
,
report
)
...
...
@@ -120,7 +120,7 @@ def federated_mean(value, weight=None):
member constituents contributed by all clients are equally weighted).
Raises:
TypeError:
i
f `value` is not a federated TFF value placed at `tff.CLIENTS`,
TypeError:
I
f `value` is not a federated TFF value placed at `tff.CLIENTS`,
or if `weight` is not a federated integer or a floating-point tensor with
the matching placement.
"""
...
...
@@ -141,7 +141,7 @@ def federated_broadcast(value):
type placed at the `tff.CLIENTS`, all members of which are equal.
Raises:
TypeError:
i
f the argument is not a federated TFF value placed at the
TypeError:
I
f the argument is not a federated TFF value placed at the
`tff.SERVER`.
"""
factory
=
intrinsic_factory
.
IntrinsicFactory
(
context_stack_impl
.
context_stack
)
...
...
@@ -159,7 +159,7 @@ def federated_collect(value):
the `tff.SERVER`.
Raises:
TypeError:
i
f the argument is not a federated TFF value placed at
TypeError:
I
f the argument is not a federated TFF value placed at
`tff.CLIENTS`.
"""
factory
=
intrinsic_factory
.
IntrinsicFactory
(
context_stack_impl
.
context_stack
)
...
...
@@ -240,7 +240,7 @@ def federated_reduce(value, zero, op):
item.
Raises:
TypeError:
i
f the arguments are not of the types specified above.
TypeError:
I
f the arguments are not of the types specified above.
"""
factory
=
intrinsic_factory
.
IntrinsicFactory
(
context_stack_impl
.
context_stack
)
return
factory
.
federated_reduce
(
value
,
zero
,
op
)
...
...
@@ -283,15 +283,16 @@ def federated_secure_sum(value, bitwidth):
Args:
value: An integer value of a TFF federated type placed at the `tff.CLIENTS`,
in the range [0, 2^bitwidth - 1].
bitwidth: An integer or nested structure of integers. For each tensor in
`value`, `bitwidth` must contain exactly one corresponding integer.
bitwidth: An integer or nested structure of integers matching the structure
of `value`. If integer `bitwidth` is used with a nested `value`, the same
integer is used for each tensor in `value`.
Returns:
A representation of the sum of the member constituents of `value` placed
on the `tff.SERVER`.
Raises:
TypeError:
i
f the argument is not a federated TFF value placed at
TypeError:
I
f the argument is not a federated TFF value placed at
`tff.CLIENTS`.
"""
factory
=
intrinsic_factory
.
IntrinsicFactory
(
context_stack_impl
.
context_stack
)
...
...
@@ -312,7 +313,7 @@ def federated_sum(value):
on the `tff.SERVER`.
Raises:
TypeError:
i
f the argument is not a federated TFF value placed at
TypeError:
I
f the argument is not a federated TFF value placed at
`tff.CLIENTS`.
"""
factory
=
intrinsic_factory
.
IntrinsicFactory
(
context_stack_impl
.
context_stack
)
...
...
@@ -352,7 +353,7 @@ def federated_zip(value):
corresponding member components of the elements of `value`.
Raises:
TypeError:
i
f the argument is not a named tuple of federated values with the
TypeError:
I
f the argument is not a named tuple of federated values with the
same placement.
"""
factory
=
intrinsic_factory
.
IntrinsicFactory
(
context_stack_impl
.
context_stack
)
...
...
tensorflow_federated/python/core/impl/BUILD
浏览文件 @
147a6c60
...
...
@@ -151,6 +151,7 @@ py_library(
":value_impl"
,
":value_utils"
,
"//tensorflow_federated/python/common_libs:py_typecheck"
,
"//tensorflow_federated/python/common_libs:structure"
,
"//tensorflow_federated/python/core/api:computation_types"
,
"//tensorflow_federated/python/core/api:value_base"
,
"//tensorflow_federated/python/core/impl/compiler:building_block_factory"
,
...
...
tensorflow_federated/python/core/impl/intrinsic_factory.py
浏览文件 @
147a6c60
...
...
@@ -14,6 +14,7 @@
"""A factory of intrinsics for use in composing federated computations."""
from
tensorflow_federated.python.common_libs
import
py_typecheck
from
tensorflow_federated.python.common_libs
import
structure
from
tensorflow_federated.python.core.api
import
computation_types
from
tensorflow_federated.python.core.api
import
value_base
from
tensorflow_federated.python.core.impl
import
value_impl
...
...
@@ -383,9 +384,9 @@ class IntrinsicFactory(object):
placement_literals
.
CLIENTS
,
'value to be summed'
)
type_analysis
.
check_is_structure_of_integers
(
value
.
type_signature
)
bitwidth
=
value_impl
.
to_value
(
bitwidth
,
None
,
self
.
_context_stack
)
bitwidth
_value
=
value_impl
.
to_value
(
bitwidth
,
None
,
self
.
_context_stack
)
value_member_type
=
value
.
type_signature
.
member
bitwidth_type
=
bitwidth
.
type_signature
bitwidth_type
=
bitwidth
_value
.
type_signature
if
not
type_analysis
.
is_valid_bitwidth_type_for_value_type
(
bitwidth_type
,
value_member_type
):
raise
TypeError
(
...
...
@@ -393,9 +394,14 @@ class IntrinsicFactory(object):
'the structure of `value`, with one integer bitwidth per tensor in '
'`value`. Found `value` of `{}` and `bitwidth` of `{}`.'
.
format
(
value_member_type
,
bitwidth_type
))
if
bitwidth_type
.
is_tensor
()
and
value_member_type
.
is_struct
():
bitwidth_value
=
value_impl
.
to_value
(
structure
.
map_structure
(
lambda
_
:
bitwidth
,
value_member_type
),
None
,
self
.
_context_stack
)
value
=
value_impl
.
ValueImpl
.
get_comp
(
value
)
bitwidth
=
value_impl
.
ValueImpl
.
get_comp
(
bitwidth
)
comp
=
building_block_factory
.
create_federated_secure_sum
(
value
,
bitwidth
)
bitwidth_value
=
value_impl
.
ValueImpl
.
get_comp
(
bitwidth_value
)
comp
=
building_block_factory
.
create_federated_secure_sum
(
value
,
bitwidth_value
)
comp
=
self
.
_bind_comp_as_reference
(
comp
)
return
value_impl
.
ValueImpl
(
comp
,
self
.
_context_stack
)
...
...
tensorflow_federated/python/core/impl/intrinsic_factory_test.py
浏览文件 @
147a6c60
...
...
@@ -47,6 +47,15 @@ class FederatedSecureSumTest(absltest.TestCase):
self
.
assertEqual
(
intrinsic
.
type_signature
.
compact_representation
(),
'<int32,<int32,int32>>@SERVER'
)
def
test_type_signature_with_structure_of_ints_scalar_bitwidth
(
self
):
value
=
intrinsics
.
federated_value
([
1
,
[
1
,
1
]],
placement_literals
.
CLIENTS
)
bitwidth
=
8
intrinsic
=
intrinsics
.
federated_secure_sum
(
value
,
bitwidth
)
self
.
assertEqual
(
intrinsic
.
type_signature
.
compact_representation
(),
'<int32,<int32,int32>>@SERVER'
)
def
test_type_signature_with_one_tensor_and_bitwidth
(
self
):
value
=
intrinsics
.
federated_value
(
np
.
ndarray
(
shape
=
(
5
,
37
),
dtype
=
np
.
int16
),
placement_literals
.
CLIENTS
)
...
...
@@ -84,7 +93,7 @@ class FederatedSecureSumTest(absltest.TestCase):
def
test_raises_type_error_with_different_structures
(
self
):
value
=
intrinsics
.
federated_value
([
1
,
[
1
,
1
]],
placement_literals
.
CLIENTS
)
bitwidth
=
8
bitwidth
=
[
8
,
4
,
2
]
with
self
.
assertRaises
(
TypeError
):
intrinsics
.
federated_secure_sum
(
value
,
bitwidth
)
...
...
tensorflow_federated/python/core/impl/types/type_analysis.py
浏览文件 @
147a6c60
...
...
@@ -347,10 +347,9 @@ def is_valid_bitwidth_type_for_value_type(
py_typecheck
.
check_type
(
bitwidth_type
,
computation_types
.
Type
)
py_typecheck
.
check_type
(
value_type
,
computation_types
.
Type
)
if
value_type
.
is_tensor
()
and
bitwidth_type
.
is_tensor
():
# Here, `value_type` refers to a tensor. Rather than check that
# `bitwidth_type` is exactly the same, we check that it is a single integer,
# since we want a single bitwidth integer per tensor.
if
bitwidth_type
.
is_tensor
():
# This condition applies to both `value_type` being a tensor or structure,
# as the same integer bitwidth can be used for all values in the structure.
return
bitwidth_type
.
dtype
.
is_integer
and
(
bitwidth_type
.
shape
.
num_elements
()
==
1
)
elif
value_type
.
is_struct
()
and
bitwidth_type
.
is_struct
():
...
...
tensorflow_federated/python/core/impl/types/type_analysis_test.py
浏览文件 @
147a6c60
...
...
@@ -305,6 +305,9 @@ class IsValidBitwidthTypeForValueType(parameterized.TestCase):
(
'different kinds of ints'
,
computation_types
.
TensorType
(
tf
.
int32
),
computation_types
.
TensorType
(
tf
.
int8
)),
(
'single int_for_struct'
,
computation_types
.
TensorType
(
tf
.
int32
),
computation_types
.
StructType
([
tf
.
int32
,
tf
.
int32
])),
)
# pyformat: enable
def
test_returns_true
(
self
,
bitwidth_type
,
value_type
):
...
...
@@ -314,9 +317,6 @@ class IsValidBitwidthTypeForValueType(parameterized.TestCase):
# pyformat: disable
@
parameterized
.
named_parameters
(
(
'single int_for_struct'
,
computation_types
.
TensorType
(
tf
.
int32
),
computation_types
.
StructType
([
tf
.
int32
,
tf
.
int32
])),
(
'miscounted struct'
,
computation_types
.
StructType
([
tf
.
int32
,
tf
.
int32
,
tf
.
int32
]),
computation_types
.
StructType
([
tf
.
int32
,
tf
.
int32
])),
...
...
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录