Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
ec2a5e40
提交
ec2a5e40
编辑于
11月 25, 2020
作者:
Michael Reneer
提交者:
tensorflow-copybara
11月 25, 2020
浏览文件
Move API from the structure module into the public API.
PiperOrigin-RevId: 344310585
上级
ed5c3a9a
变更
5
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/BUILD
浏览文件 @
ec2a5e40
...
...
@@ -26,6 +26,7 @@ py_library(
deps
=
[
":version"
,
"//tensorflow_federated/python/aggregators"
,
"//tensorflow_federated/python/common_libs:structure"
,
"//tensorflow_federated/python/common_libs:tracing"
,
"//tensorflow_federated/python/core/api:computation_base"
,
"//tensorflow_federated/python/core/api:computation_types"
,
...
...
tensorflow_federated/__init__.py
浏览文件 @
ec2a5e40
...
...
@@ -20,6 +20,7 @@ from tensorflow_federated.version import __version__ # pylint: disable=g-bad-im
from
tensorflow_federated.python
import
aggregators
from
tensorflow_federated.python
import
learning
from
tensorflow_federated.python
import
simulation
from
tensorflow_federated.python.common_libs
import
structure
from
tensorflow_federated.python.common_libs
import
tracing
as
profiler
from
tensorflow_federated.python.core
import
backends
from
tensorflow_federated.python.core
import
framework
...
...
tensorflow_federated/python/common_libs/BUILD
浏览文件 @
ec2a5e40
...
...
@@ -92,11 +92,9 @@ py_library(
name
=
"structure"
,
srcs
=
[
"structure.py"
],
srcs_version
=
"PY3"
,
# TODO(b/163856734): Fix the visibility of the `common_libs`.
visibility
=
[
"//
intelligence/federated/aggregators
:__pkg__"
,
"//
tensorflow_federated
:__pkg__"
,
"//tensorflow_federated:internal"
,
"//tensorflow_federated/python/tests:__pkg__"
,
],
deps
=
[
":py_typecheck"
],
)
...
...
tensorflow_federated/python/tests/BUILD
浏览文件 @
ec2a5e40
...
...
@@ -31,10 +31,7 @@ py_test(
srcs
=
[
"canonical_form_test.py"
],
python_version
=
"PY3"
,
srcs_version
=
"PY3"
,
deps
=
[
"//tensorflow_federated"
,
"//tensorflow_federated/python/common_libs:structure"
,
],
deps
=
[
"//tensorflow_federated"
],
)
py_test
(
...
...
tensorflow_federated/python/tests/canonical_form_test.py
浏览文件 @
ec2a5e40
...
...
@@ -18,8 +18,6 @@ import numpy as np
import
tensorflow
as
tf
import
tensorflow_federated
as
tff
from
tensorflow_federated.python.common_libs
import
structure
def
construct_example_training_comp
():
"""Constructs a `tff.templates.IterativeProcess` via the FL API."""
...
...
@@ -124,21 +122,24 @@ class CanonicalFormTest(tff.test.TestCase):
client_data
=
[
sample_batch
]
state_1
=
ip_1
.
initialize
()
server_state_1
,
server_output_1
=
ip_1
.
next
(
state_1
,
[
client_data
])
server_state_1
=
structure
.
from_container
(
server_state_1
,
recursive
=
True
)
server_output_1
=
structure
.
from_container
(
server_output_1
,
recursive
=
True
)
server_state_1_arrays
=
structure
.
flatten
(
server_state_1
)
server_output_1_arrays
=
structure
.
flatten
(
server_output_1
)
server_state_1
=
tff
.
structure
.
from_container
(
server_state_1
,
recursive
=
True
)
server_output_1
=
tff
.
structure
.
from_container
(
server_output_1
,
recursive
=
True
)
server_state_1_arrays
=
tff
.
structure
.
flatten
(
server_state_1
)
server_output_1_arrays
=
tff
.
structure
.
flatten
(
server_output_1
)
state_2
=
ip_2
.
initialize
()
server_state_2
,
server_output_2
=
ip_2
.
next
(
state_2
,
[
client_data
])
server_state_2_arrays
=
structure
.
flatten
(
server_state_2
)
server_output_2_arrays
=
structure
.
flatten
(
server_output_2
)
server_state_2_arrays
=
tff
.
structure
.
flatten
(
server_state_2
)
server_output_2_arrays
=
tff
.
structure
.
flatten
(
server_output_2
)
self
.
assertEmpty
(
server_state_1
.
model_broadcast_state
)
# Note that we cannot simply use assertEqual because the values may differ
# due to floating point issues.
self
.
assertTrue
(
structure
.
is_same_structure
(
server_state_1
,
server_state_2
))
self
.
assertTrue
(
structure
.
is_same_structure
(
server_output_1
,
server_output_2
))
tff
.
structure
.
is_same_structure
(
server_state_1
,
server_state_2
))
self
.
assertTrue
(
tff
.
structure
.
is_same_structure
(
server_output_1
,
server_output_2
))
self
.
assertAllClose
(
server_state_1_arrays
,
server_state_2_arrays
)
self
.
assertAllClose
(
server_output_1_arrays
[:
2
],
server_output_2_arrays
[:
2
])
...
...
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录