Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
5e1ad040
提交
5e1ad040
编辑于
12月 14, 2020
作者:
Keith Rush
提交者:
tensorflow-copybara
12月 14, 2020
浏览文件
Adds failing test for bad interaction between caching and remote runtime configuration.
PiperOrigin-RevId: 347400725
上级
5721e8f9
变更
2
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/tests/BUILD
浏览文件 @
5e1ad040
...
...
@@ -36,7 +36,7 @@ py_test(
py_test
(
name
=
"perf_regression_test"
,
size
=
"medium
"
,
timeout
=
"moderate
"
,
srcs
=
[
"perf_regression_test.py"
],
python_version
=
"PY3"
,
srcs_version
=
"PY3"
,
...
...
@@ -45,7 +45,7 @@ py_test(
py_test
(
name
=
"remote_runtime_integration_test"
,
size
=
"small
"
,
timeout
=
"moderate
"
,
srcs
=
[
"remote_runtime_integration_test.py"
],
python_version
=
"PY3"
,
srcs_version
=
"PY3"
,
...
...
tensorflow_federated/python/tests/remote_runtime_integration_test.py
浏览文件 @
5e1ad040
...
...
@@ -151,5 +151,57 @@ class StreamingWithIntermediateAggTest(absltest.TestCase):
self
.
assertEqual
(
result
,
3
)
@
parameterized
.
named_parameters
((
'native_remote_request_reply'
,
remote_runtime_test_utils
.
create_localhost_remote_context
(
_WORKER_PORTS
),
remote_runtime_test_utils
.
create_localhost_worker_contexts
(
_WORKER_PORTS
),
),
(
'native_remote_streaming'
,
remote_runtime_test_utils
.
create_localhost_remote_context
(
_WORKER_PORTS
,
rpc_mode
=
'STREAMING'
),
remote_runtime_test_utils
.
create_localhost_worker_contexts
(
_WORKER_PORTS
),
),
(
'native_remote_intermediate_aggregator'
,
remote_runtime_test_utils
.
create_localhost_remote_context
(
_AGGREGATOR_PORTS
),
remote_runtime_test_utils
.
create_localhost_aggregator_contexts
(
_WORKER_PORTS
,
_AGGREGATOR_PORTS
),
))
class
RemoteRuntimeConfigurationChangeTest
(
absltest
.
TestCase
):
def
test_computations_run_with_changing_clients
(
self
,
context
,
server_contexts
):
self
.
skipTest
(
'b/175155128'
)
@
tff
.
tf_computation
(
tf
.
int32
)
@
tf
.
function
def
add_one
(
x
):
return
x
+
1
@
tff
.
federated_computation
(
tff
.
type_at_clients
(
tf
.
int32
))
def
map_add_one
(
federated_arg
):
return
tff
.
federated_map
(
add_one
,
federated_arg
)
context_stack
=
tff
.
framework
.
get_context_stack
()
with
context_stack
.
install
(
context
):
with
contextlib
.
ExitStack
()
as
stack
:
for
server_context
in
server_contexts
:
stack
.
enter_context
(
server_context
)
result_two_clients
=
map_add_one
([
0
,
1
])
self
.
assertEqual
(
result_two_clients
,
[
1
,
2
])
# Moving to three clients should be fine
result_three_clients
=
map_add_one
([
0
,
1
,
2
])
# Running a 0-client function should also be OK
self
.
assertEqual
(
add_one
(
0
),
1
)
self
.
assertEqual
(
result_three_clients
,
[
1
,
2
,
3
])
# Changing back to 2 clients should still succeed.
second_result_two_clients
=
map_add_one
([
0
,
1
])
self
.
assertEqual
(
second_result_two_clients
,
[
1
,
2
])
# Similarly, 3 clients again should be fine.
second_result_three_clients
=
map_add_one
([
0
,
1
,
2
])
self
.
assertEqual
(
second_result_three_clients
,
[
1
,
2
,
3
])
if
__name__
==
'__main__'
:
absltest
.
main
()
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录