提交 ec2a5e40 编辑于 作者: Michael Reneer's avatar Michael Reneer 提交者: tensorflow-copybara
浏览文件

Move API from the structure module into the public API.

PiperOrigin-RevId: 344310585
上级 ed5c3a9a
......@@ -26,6 +26,7 @@ py_library(
deps = [
":version",
"//tensorflow_federated/python/aggregators",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/common_libs:tracing",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computation_types",
......
......@@ -20,6 +20,7 @@ from tensorflow_federated.version import __version__ # pylint: disable=g-bad-im
from tensorflow_federated.python import aggregators
from tensorflow_federated.python import learning
from tensorflow_federated.python import simulation
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.common_libs import tracing as profiler
from tensorflow_federated.python.core import backends
from tensorflow_federated.python.core import framework
......
......@@ -92,11 +92,9 @@ py_library(
name = "structure",
srcs = ["structure.py"],
srcs_version = "PY3",
# TODO(b/163856734): Fix the visibility of the `common_libs`.
visibility = [
"//intelligence/federated/aggregators:__pkg__",
"//tensorflow_federated:__pkg__",
"//tensorflow_federated:internal",
"//tensorflow_federated/python/tests:__pkg__",
],
deps = [":py_typecheck"],
)
......
......@@ -31,10 +31,7 @@ py_test(
srcs = ["canonical_form_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
"//tensorflow_federated",
"//tensorflow_federated/python/common_libs:structure",
],
deps = ["//tensorflow_federated"],
)
py_test(
......
......@@ -18,8 +18,6 @@ import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.common_libs import structure
def construct_example_training_comp():
"""Constructs a `tff.templates.IterativeProcess` via the FL API."""
......@@ -124,21 +122,24 @@ class CanonicalFormTest(tff.test.TestCase):
client_data = [sample_batch]
state_1 = ip_1.initialize()
server_state_1, server_output_1 = ip_1.next(state_1, [client_data])
server_state_1 = structure.from_container(server_state_1, recursive=True)
server_output_1 = structure.from_container(server_output_1, recursive=True)
server_state_1_arrays = structure.flatten(server_state_1)
server_output_1_arrays = structure.flatten(server_output_1)
server_state_1 = tff.structure.from_container(
server_state_1, recursive=True)
server_output_1 = tff.structure.from_container(
server_output_1, recursive=True)
server_state_1_arrays = tff.structure.flatten(server_state_1)
server_output_1_arrays = tff.structure.flatten(server_output_1)
state_2 = ip_2.initialize()
server_state_2, server_output_2 = ip_2.next(state_2, [client_data])
server_state_2_arrays = structure.flatten(server_state_2)
server_output_2_arrays = structure.flatten(server_output_2)
server_state_2_arrays = tff.structure.flatten(server_state_2)
server_output_2_arrays = tff.structure.flatten(server_output_2)
self.assertEmpty(server_state_1.model_broadcast_state)
# Note that we cannot simply use assertEqual because the values may differ
# due to floating point issues.
self.assertTrue(structure.is_same_structure(server_state_1, server_state_2))
self.assertTrue(
structure.is_same_structure(server_output_1, server_output_2))
tff.structure.is_same_structure(server_state_1, server_state_2))
self.assertTrue(
tff.structure.is_same_structure(server_output_1, server_output_2))
self.assertAllClose(server_state_1_arrays, server_state_2_arrays)
self.assertAllClose(server_output_1_arrays[:2], server_output_2_arrays[:2])
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册