Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
73a6e4ff
提交
73a6e4ff
编辑于
11月 19, 2021
作者:
Taylor Cramer
提交者:
tensorflow-copybara
11月 19, 2021
浏览文件
Use OrderedDict or tuple to represent structures with unknown Python containers
PiperOrigin-RevId: 411172320
上级
f9bab81f
变更
13
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/aggregators/primitives_test.py
浏览文件 @
73a6e4ff
...
...
@@ -268,11 +268,10 @@ class FederatedSampleTest(tf.test.TestCase):
x1
=
-
1.0
y1
=
5.0
test_type
=
collections
.
namedtuple
(
'NestedScalars'
,
[
'x'
,
'y'
])
value
=
call_federated_sample
(
result
=
call_federated_sample
(
[
test_type
(
x0
,
y0
),
test_type
(
x1
,
y1
),
test_type
(
2.0
,
-
10.0
)])
result
=
value
.
_asdict
()
i0
=
list
(
result
[
'x'
]).
index
(
x0
)
i1
=
list
(
result
[
'y'
]).
index
(
y1
)
...
...
@@ -334,11 +333,10 @@ class FederatedSampleTest(tf.test.TestCase):
x
=
0.0
y
=
5.0
test_type
=
collections
.
namedtuple
(
'NestedScalars'
,
[
'x'
,
'y'
])
value
=
call_federated_sample
(
result
=
call_federated_sample
(
[
test_type
(
x
,
y
),
test_type
(
3.4
,
5.6
),
test_type
(
1.0
,
1.0
)])
result
=
value
.
_asdict
()
self
.
assertIn
(
y
,
result
[
'y'
])
self
.
assertIn
(
x
,
result
[
'x'
])
...
...
@@ -360,13 +358,13 @@ class FederatedSampleTest(tf.test.TestCase):
tuple_type
=
collections
.
namedtuple
(
'NestedScalars'
,
[
'x'
,
'y'
])
dict_type
=
collections
.
namedtuple
(
'NestedScalars'
,
[
'a'
,
'b'
])
value
=
call_federated_sample
([
result
=
call_federated_sample
([
nested_test_type
(
tuple_type
(
1.2
,
2.2
),
dict_type
(
1.3
,
8.8
)),
nested_test_type
(
tuple_type
(
-
9.1
,
3.1
),
dict_type
(
1.2
,
-
5.4
))
])
.
_asdict
(
recursive
=
True
)
])
self
.
assertIn
(
1.2
,
value
[
'tuple_1'
][
'x'
])
self
.
assertIn
(
8.8
,
value
[
'tuple_2'
][
'b'
])
self
.
assertIn
(
1.2
,
result
[
'tuple_1'
][
'x'
])
self
.
assertIn
(
8.8
,
result
[
'tuple_2'
][
'b'
])
class
SecureQuantizedSumStaticAssertsTest
(
tf
.
test
.
TestCase
,
...
...
tensorflow_federated/python/aggregators/sampling.py
浏览文件 @
73a6e4ff
...
...
@@ -70,11 +70,11 @@ def _build_reservoir_type(
# TODO(b/181365504): relax this to allow `StructType` once a `Struct` can be
# returned from `tf.function` decorated methods.
def
is_te
s
nor_or_struct_with_py_type
(
t
:
computation_types
.
Type
)
->
bool
:
def
is_ten
s
or_or_struct_with_py_type
(
t
:
computation_types
.
Type
)
->
bool
:
return
t
.
is_tensor
()
or
t
.
is_struct_with_python
()
if
not
type_analysis
.
contains_only
(
sample_value_type
,
is_te
s
nor_or_struct_with_py_type
):
is_ten
s
or_or_struct_with_py_type
):
raise
TypeError
(
'Cannot create a reservoir for type structure. Sample type '
'must only contain `TensorType` or `StructWithPythonType`, '
f
'got a
{
sample_value_type
!
r
}
.'
)
...
...
tensorflow_federated/python/common_libs/structure.py
浏览文件 @
73a6e4ff
...
...
@@ -260,7 +260,7 @@ def iter_elements(struct: Struct) -> Iterator[Tuple[Optional[str], Any]]:
# pylint: enable=protected-access
def
to_odict
(
struct
:
Struct
,
recursive
=
False
):
def
to_odict
(
struct
:
Struct
,
recursive
=
False
)
->
collections
.
OrderedDict
:
"""Returns `struct` as an `OrderedDict`, if possible.
Args:
...
...
@@ -285,7 +285,9 @@ def to_odict(struct: Struct, recursive=False):
return
_to_odict
(
to_elements
(
struct
))
def
to_odict_or_tuple
(
struct
:
Struct
,
recursive
=
True
):
def
to_odict_or_tuple
(
struct
:
Struct
,
recursive
=
True
)
->
Union
[
collections
.
OrderedDict
,
Tuple
[
Any
,
...]]:
"""Returns `struct` as an `OrderedDict` or `tuple`, if possible.
If all elements of `struct` have names, convert `struct` to an
...
...
@@ -304,16 +306,13 @@ def to_odict_or_tuple(struct: Struct, recursive=True):
def
_to_odict_or_tuple
(
elements
):
field_is_named
=
tuple
(
name
is
not
None
for
name
,
_
in
elements
)
has_names
=
any
(
field_is_named
)
is_all_named
=
all
(
field_is_named
)
if
is_all_named
:
if
any
(
field_is_named
):
if
not
all
(
field_is_named
):
raise
ValueError
(
'Cannot convert a `Struct` with both named and unnamed '
'entries to an OrderedDict or tuple: {!r}'
.
format
(
struct
))
return
collections
.
OrderedDict
(
elements
)
elif
not
has_names
:
return
tuple
(
value
for
_
,
value
in
elements
)
else
:
raise
ValueError
(
'Cannot convert an `Struct` with both named and unnamed '
'entries to an OrderedDict or tuple: {!r}'
.
format
(
struct
))
return
tuple
(
value
for
_
,
value
in
elements
)
if
recursive
:
return
to_container_recursive
(
struct
,
_to_odict_or_tuple
)
...
...
@@ -657,16 +656,17 @@ def update_struct(structure, **kwargs):
# In Python 3.8 and later `_asdict` no longer return OrdereDict, rather a
# regular `dict`, so we wrap here to get consistent types across Python
# version.s
d
=
collections
.
OrderedDict
(
structure
.
_asdict
())
d
ictionary
=
collections
.
OrderedDict
(
structure
.
_asdict
())
elif
py_typecheck
.
is_attrs
(
structure
):
d
=
attr
.
asdict
(
structure
,
dict_factory
=
collections
.
OrderedDict
)
d
ictionary
=
attr
.
asdict
(
structure
,
dict_factory
=
collections
.
OrderedDict
)
else
:
for
key
in
kwargs
:
if
key
not
in
structure
:
raise
KeyError
(
'structure does not contain a field named "{!s}"'
.
format
(
key
))
d
=
structure
d
.
update
(
kwargs
)
# Create a copy to prevent mutation of the original `structure`
dictionary
=
type
(
structure
)(
**
structure
)
dictionary
.
update
(
kwargs
)
if
isinstance
(
structure
,
collections
.
abc
.
Mapping
):
return
d
return
type
(
structure
)(
**
d
)
return
d
ictionary
return
type
(
structure
)(
**
d
ictionary
)
tensorflow_federated/python/common_libs/structure_test.py
浏览文件 @
73a6e4ff
...
...
@@ -14,13 +14,16 @@
import
collections
from
absl.testing
import
parameterized
import
attr
import
tensorflow
as
tf
from
tensorflow_federated.python.common_libs
import
structure
ODict
=
collections
.
OrderedDict
class
StructTest
(
tf
.
test
.
TestCase
):
class
StructTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_new_named
(
self
):
x
=
structure
.
Struct
.
named
(
a
=
1
,
b
=
4
)
...
...
@@ -68,7 +71,7 @@ class StructTest(tf.test.TestCase):
self
.
assertNotEqual
(
x
,
structure
.
Struct
([(
'foo'
,
10
)]))
self
.
assertEqual
(
structure
.
to_elements
(
x
),
v
)
self
.
assertEqual
(
structure
.
to_odict
(
x
),
collections
.
OrderedDict
())
self
.
assertEqual
(
structure
.
to_odict_or_tuple
(
x
),
collections
.
OrderedDict
())
self
.
assertEqual
(
structure
.
to_odict_or_tuple
(
x
),
())
self
.
assertEqual
(
repr
(
x
),
'Struct([])'
)
self
.
assertEqual
(
str
(
x
),
'<>'
)
...
...
@@ -468,28 +471,17 @@ class StructTest(tf.test.TestCase):
tf
.
SparseTensor
(
indices
=
[[
1
]],
values
=
[
2
],
dense_shape
=
[
5
]))
self
.
assertEqual
(
str
(
x
),
'<indices=[[1]],values=[2],dense_shape=[5]>'
)
def
test_to_container_recursive
(
self
):
def
odict
(
**
kwargs
):
return
collections
.
OrderedDict
(
sorted
(
list
(
kwargs
.
items
())))
# Nested OrderedDicts.
s
=
odict
(
a
=
1
,
b
=
2
,
c
=
odict
(
d
=
3
,
e
=
odict
(
f
=
4
,
g
=
5
)))
x
=
structure
.
from_container
(
s
,
recursive
=
True
)
s2
=
x
.
_asdict
(
recursive
=
True
)
self
.
assertEqual
(
s
,
s2
)
# Single OrderedDict.
s
=
odict
(
a
=
1
,
b
=
2
)
x
=
structure
.
from_container
(
s
)
self
.
assertEqual
(
x
.
_asdict
(
recursive
=
True
),
s
)
# Single empty OrderedDict.
s
=
odict
()
x
=
structure
.
from_container
(
s
)
self
.
assertEqual
(
x
.
_asdict
(
recursive
=
True
),
s
)
# Invalid argument.
@
parameterized
.
named_parameters
(
(
'empty'
,
ODict
()),
(
'flat'
,
ODict
(
a
=
1
,
b
=
2
)),
(
'nested'
,
ODict
(
a
=
1
,
b
=
2
,
c
=
ODict
(
d
=
3
,
e
=
ODict
(
f
=
4
,
g
=
5
)))),
)
def
test_from_container_asdict_roundtrip
(
self
,
dict_in
):
structure_repr
=
structure
.
from_container
(
dict_in
,
recursive
=
True
)
dict_out
=
structure_repr
.
_asdict
(
recursive
=
True
)
self
.
assertEqual
(
dict_in
,
dict_out
)
def
test_from_container_raises_on_non_container_argument
(
self
):
with
self
.
assertRaises
(
TypeError
):
structure
.
from_container
(
3
)
...
...
@@ -554,6 +546,12 @@ class StructTest(tf.test.TestCase):
state3
=
structure
.
update_struct
(
state2
,
a
=
8
)
self
.
assertEqual
(
state3
,
{
'a'
:
8
,
'b'
:
2
,
'c'
:
7
})
def
test_update_struct_on_dict_does_not_mutate_original
(
self
):
state
=
collections
.
OrderedDict
(
a
=
1
,
b
=
2
,
c
=
3
)
state2
=
structure
.
update_struct
(
state
,
c
=
7
)
del
state2
self
.
assertEqual
(
state
,
collections
.
OrderedDict
(
a
=
1
,
b
=
2
,
c
=
3
))
def
test_update_struct_ordereddict
(
self
):
state
=
collections
.
OrderedDict
([(
'a'
,
1
),
(
'b'
,
2
),
(
'c'
,
3
)])
state2
=
structure
.
update_struct
(
state
,
c
=
7
)
...
...
@@ -585,53 +583,30 @@ class StructTest(tf.test.TestCase):
with
self
.
assertRaisesRegex
(
KeyError
,
'does not contain a field'
):
structure
.
update_struct
({
'z'
:
1
},
a
=
8
)
def
test_to_ordered_dict_or_tuple
(
self
):
def
odict
(
**
kwargs
):
return
collections
.
OrderedDict
(
sorted
(
list
(
kwargs
.
items
())))
# Nested OrderedDicts.
s
=
odict
(
a
=
1
,
b
=
2
,
c
=
odict
(
d
=
3
,
e
=
odict
(
f
=
4
,
g
=
5
)))
x
=
structure
.
from_container
(
s
,
recursive
=
True
)
self
.
assertEqual
(
s
,
structure
.
to_odict_or_tuple
(
x
))
# Single OrderedDict.
s
=
odict
(
a
=
1
,
b
=
2
)
x
=
structure
.
from_container
(
s
)
self
.
assertEqual
(
structure
.
to_odict_or_tuple
(
x
),
s
)
# Single empty OrderedDict.
s
=
odict
()
x
=
structure
.
from_container
(
s
)
self
.
assertEqual
(
structure
.
to_odict_or_tuple
(
x
),
s
)
# Nested tuples.
s
=
tuple
([
1
,
2
,
tuple
([
3
,
tuple
([
4
,
5
])])])
x
=
structure
.
from_container
(
s
,
recursive
=
True
)
self
.
assertEqual
(
s
,
structure
.
to_odict_or_tuple
(
x
))
# Single tuple.
s
=
tuple
([
1
,
2
])
@
parameterized
.
named_parameters
(
(
'empty_tuple'
,
()),
(
'flat_tuple'
,
(
1
,
2
)),
(
'nested_tuple'
,
(
1
,
2
,
(
3
,
(
4
,
5
)))),
(
'flat_dict'
,
ODict
(
a
=
1
,
b
=
2
)),
(
'nested_dict'
,
ODict
(
a
=
1
,
b
=
2
,
c
=
ODict
(
d
=
3
,
e
=
ODict
(
f
=
4
,
g
=
5
)))),
(
'mixed'
,
ODict
(
a
=
1
,
b
=
2
,
c
=
(
3
,
ODict
(
d
=
4
,
e
=
5
)))),
)
def
test_to_odict_or_tuple_from_container_roundtrip
(
self
,
original
):
structure_repr
=
structure
.
from_container
(
original
,
recursive
=
True
)
out
=
structure
.
to_odict_or_tuple
(
structure_repr
)
self
.
assertEqual
(
original
,
out
)
def
test_to_odict_or_tuple_empty_dict_becomes_empty_tuple
(
self
):
s
=
collections
.
OrderedDict
()
x
=
structure
.
from_container
(
s
)
self
.
assertEqual
(
structure
.
to_odict_or_tuple
(
x
),
s
)
# Struct from a single empty tuple should be converted to an empty
# OrderedDict.
s
=
tuple
()
x
=
structure
.
from_container
(
s
)
self
.
assertEqual
(
structure
.
to_odict_or_tuple
(
x
),
collections
.
OrderedDict
())
# Mixed OrderedDicts and tuples.
s
=
odict
(
a
=
1
,
b
=
2
,
c
=
tuple
([
3
,
odict
(
d
=
4
,
e
=
5
)]))
x
=
structure
.
from_container
(
s
,
recursive
=
True
)
self
.
assertEqual
(
s
,
structure
.
to_odict_or_tuple
(
x
))
self
.
assertEqual
(
structure
.
to_odict_or_tuple
(
x
),
())
# Mixed OrderedDicts and tuples with
recursive
=False.
s
=
od
ict
(
a
=
1
,
b
=
2
,
c
=
tuple
([
3
,
od
ict
(
d
=
4
,
e
=
5
)
]
))
def
test_to_odict_or_tuple_mixed_non
recursive
(
self
):
s
=
OD
ict
(
a
=
1
,
b
=
2
,
c
=
(
3
,
OD
ict
(
d
=
4
,
e
=
5
)))
x
=
structure
.
from_container
(
s
,
recursive
=
False
)
self
.
assertEqual
(
s
,
structure
.
to_odict_or_tuple
(
x
,
recursive
=
False
))
# Struct with
named
and
unnamed
elements should raise error.
def
test_to_odict_or_tuple_raises_on_mixed_
named
_
and
_
unnamed
(
self
):
s
=
[(
None
,
10
),
(
'foo'
,
20
),
(
'bar'
,
30
)]
x
=
structure
.
Struct
(
s
)
with
self
.
assertRaisesRegex
(
ValueError
,
'named and unnamed'
):
...
...
tensorflow_federated/python/core/backends/mapreduce/BUILD
浏览文件 @
73a6e4ff
...
...
@@ -85,7 +85,6 @@ py_test(
":form_utils"
,
":forms"
,
":test_utils"
,
"//tensorflow_federated/python/common_libs:structure"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/api:test_case"
,
"//tensorflow_federated/python/core/backends/reference:reference_context"
,
...
...
tensorflow_federated/python/core/backends/mapreduce/form_utils_test.py
浏览文件 @
73a6e4ff
...
...
@@ -18,7 +18,6 @@ from absl.testing import parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow_federated.python.common_libs
import
structure
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.api
import
test_case
from
tensorflow_federated.python.core.backends.mapreduce
import
form_utils
...
...
@@ -638,10 +637,10 @@ class GetMapReduceFormForIterativeProcessTest(MapReduceFormTestCase,
mrf
=
form_utils
.
get_map_reduce_form_for_iterative_process
(
it
)
new_it
=
form_utils
.
get_iterative_process_for_map_reduce_form
(
mrf
)
state
=
new_it
.
initialize
()
self
.
assertEqual
(
state
.
num_rounds
,
0
)
self
.
assertEqual
(
state
[
'
num_rounds
'
]
,
0
)
state
,
metrics
=
new_it
.
next
(
state
,
[[
28.0
],
[
30.0
,
33.0
,
29.0
]])
self
.
assertEqual
(
state
.
num_rounds
,
1
)
self
.
assertEqual
(
state
[
'
num_rounds
'
]
,
1
)
self
.
assertAllClose
(
metrics
,
collections
.
OrderedDict
(
ratio_over_threshold
=
0.5
))
...
...
@@ -743,18 +742,17 @@ class GetMapReduceFormForIterativeProcessTest(MapReduceFormTestCase,
def
test_returns_map_reduce_form_with_secure_sum_bitwidth
(
self
):
mrf
=
self
.
get_map_reduce_form_for_client_to_server_fn
(
lambda
data
:
intrinsics
.
federated_secure_sum_bitwidth
(
data
,
7
))
self
.
assertEqual
(
mrf
.
secure_sum_bitwidth
(),
structure
.
Struct
.
unnamed
(
7
))
self
.
assertEqual
(
mrf
.
secure_sum_bitwidth
(),
(
7
,
))
def
test_returns_map_reduce_form_with_secure_sum_max_input
(
self
):
mrf
=
self
.
get_map_reduce_form_for_client_to_server_fn
(
lambda
data
:
intrinsics
.
federated_secure_sum
(
data
,
12
))
self
.
assertEqual
(
mrf
.
secure_sum_max_input
(),
structure
.
Struct
.
unnamed
(
12
))
self
.
assertEqual
(
mrf
.
secure_sum_max_input
(),
(
12
,
))
def
test_returns_map_reduce_form_with_secure_modular_sum_modulus
(
self
):
mrf
=
self
.
get_map_reduce_form_for_client_to_server_fn
(
lambda
data
:
intrinsics
.
federated_secure_modular_sum
(
data
,
22
))
self
.
assertEqual
(
mrf
.
secure_modular_sum_modulus
(),
structure
.
Struct
.
unnamed
(
22
))
self
.
assertEqual
(
mrf
.
secure_modular_sum_modulus
(),
(
22
,))
class
BroadcastFormTest
(
test_case
.
TestCase
):
...
...
tensorflow_federated/python/core/backends/reference/reference_context_test.py
浏览文件 @
73a6e4ff
...
...
@@ -1012,7 +1012,7 @@ class ReferenceContextTest(test_case.TestCase, parameterized.TestCase):
return
zero_for
([(
'A'
,
tf
.
int32
),
(
'B'
,
tf
.
float32
)])
self
.
assertEqual
(
str
(
foo
.
type_signature
),
'( -> <A=int32,B=float32>)'
)
self
.
assertEqual
(
str
(
foo
()
)
,
'<
A=0,B=0.0
>'
)
self
.
assertEqual
(
foo
(),
collections
.
OrderedDict
(
A
=
0
,
B
=
0.0
)
)
def
test_generic_zero_with_federated_int_on_server
(
self
):
...
...
@@ -1046,10 +1046,10 @@ class ReferenceContextTest(test_case.TestCase, parameterized.TestCase):
'(<x=<A=int32,B=float32>,y=<A=int32,B=float32>> -> <A=int32,B=float32>)'
)
foo_result
=
foo
([
2
,
0.1
],
[
3
,
0.2
])
self
.
assertIsInstance
(
foo_result
,
structure
.
Stru
ct
)
self
.
assertSameElements
(
dir
(
foo_result
),
[
'A'
,
'B'
])
self
.
assertEqual
(
foo_result
.
A
,
5
)
self
.
assertAlmostEqual
(
foo_result
.
B
,
0.3
,
places
=
2
)
self
.
assertIsInstance
(
foo_result
,
collections
.
OrderedDi
ct
)
self
.
assertSameElements
(
foo_result
.
keys
(
),
[
'A'
,
'B'
])
self
.
assertEqual
(
foo_result
[
'A'
],
5
)
# pylint: disable=invalid-sequence-index
self
.
assertAlmostEqual
(
foo_result
[
'B'
]
,
0.3
,
places
=
2
)
# pylint: disable=invalid-sequence-index
def
test_sequence_map_with_list_of_integers
(
self
):
...
...
@@ -1123,17 +1123,16 @@ class ReferenceContextTest(test_case.TestCase, parameterized.TestCase):
str
(
foo
.
type_signature
),
'({<A=float32,B=float32>}@CLIENTS -> <A=float32,B=float32>@SERVER)'
)
self
.
assertEqual
(
str
(
foo
([{
'A'
:
1.0
,
'B'
:
5.0
},
{
'A'
:
2.0
,
'B'
:
6.0
},
{
'A'
:
3.0
,
'B'
:
7.0
}])),
'<A=2.0,B=6.0>'
)
foo
([{
'A'
:
1.0
,
'B'
:
5.0
},
{
'A'
:
2.0
,
'B'
:
6.0
},
{
'A'
:
3.0
,
'B'
:
7.0
}]),
collections
.
OrderedDict
(
A
=
2.0
,
B
=
6.0
))
def
test_federated_zip_at_server
(
self
):
...
...
tensorflow_federated/python/core/impl/types/type_conversions.py
浏览文件 @
73a6e4ff
...
...
@@ -13,7 +13,7 @@
"""Utilities for type conversion, type checking, type inference, etc."""
import
collections
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Type
import
attr
import
numpy
as
np
...
...
@@ -348,6 +348,19 @@ def type_from_tensors(tensors):
return
computation_types
.
to_type
(
type_spec
)
def
is_container_type_without_names
(
container_type
:
Type
[
Any
])
->
bool
:
"""Returns whether `container_type`'s elements are unnamed."""
return
(
issubclass
(
container_type
,
(
list
,
tuple
))
and
not
py_typecheck
.
is_named_tuple
(
container_type
))
def
is_container_type_with_names
(
container_type
:
Type
[
Any
])
->
bool
:
"""Returns whether `container_type`'s elements are named."""
return
(
py_typecheck
.
is_named_tuple
(
container_type
)
or
py_typecheck
.
is_attrs
(
container_type
)
or
issubclass
(
container_type
,
dict
))
def
type_to_py_container
(
value
,
type_spec
):
"""Recursively convert `structure.Struct`s to Python containers.
...
...
@@ -366,7 +379,8 @@ def type_to_py_container(value, type_spec):
Raises:
ValueError: If the conversion is not possible due to a mix of named
and unnamed values.
and unnamed values, or if `value` contains names that are mismatched or
not present in the corresponding index of `type_spec`.
"""
if
type_spec
.
is_federated
():
if
type_spec
.
all_equal
:
...
...
@@ -398,44 +412,50 @@ def type_to_py_container(value, type_spec):
return
value
if
not
isinstance
(
value
,
structure
.
Struct
):
# NOTE: When encountering non-
anonymous tuple
s, we assume that
# NOTE: When encountering non-
`structure.Struct`
s, we assume that
# this means that we're attempting to re-convert a value that
# already has the proper containers, and we short-circuit to
# avoid re-converting. This is a possibly dangerous assumption.
return
value
anon_tuple
=
value
def
is_container_type_without_names
(
container_type
):
return
(
issubclass
(
container_type
,
(
list
,
tuple
))
and
not
py_typecheck
.
is_named_tuple
(
container_type
))
def
is_container_type_with_names
(
container_type
):
return
(
py_typecheck
.
is_named_tuple
(
container_type
)
or
py_typecheck
.
is_attrs
(
container_type
)
or
issubclass
(
container_type
,
dict
))
# TODO(b/133228705): Consider requiring StructWithPythonType.
container_type
=
structure_type_spec
.
python_container
or
structure
.
Struct
container_is_anon_tuple
=
structure_type_spec
.
python_container
is
None
container_type
=
structure_type_spec
.
python_container
# Ensure that names are only added, not mismatched or removed
names_from_value
=
structure
.
name_list_with_nones
(
value
)
names_from_type_spec
=
structure
.
name_list_with_nones
(
structure_type_spec
)
for
value_name
,
type_name
in
zip
(
names_from_value
,
names_from_type_spec
):
if
value_name
is
not
None
:
if
value_name
!=
type_name
:
raise
ValueError
(
f
'Cannot convert value with field name `
{
value_name
}
` into a '
f
'type with field name `
{
type_name
}
`.'
)
num_named_elements
=
len
(
dir
(
structure_type_spec
))
num_unnamed_elements
=
len
(
structure_type_spec
)
-
num_named_elements
if
num_named_elements
>
0
and
num_unnamed_elements
>
0
:
raise
ValueError
(
f
'Cannot represent value
{
value
}
with a Python container because it '
'contains a mix of named and unnamed elements.
\n\n
Note: this was '
'previously allowed when using the `tff.structure.Struct` container. '
'This support has been removed: please change to use structures with '
'either all-named or all-unnamed fields.'
)
if
container_type
is
None
:
if
num_named_elements
:
container_type
=
collections
.
OrderedDict
else
:
container_type
=
tuple
# Avoid projecting the `structure.StructType`d TFF value into a Python
# container that is not supported.
if
not
container_is_anon_tuple
:
num_named_elements
=
len
(
dir
(
anon_tuple
))
num_unnamed_elements
=
len
(
anon_tuple
)
-
num_named_elements
if
num_named_elements
>
0
and
num_unnamed_elements
>
0
:
raise
ValueError
(
'Cannot represent value {} with container type {}, '
'because value contains a mix of named and unnamed '
'elements.'
.
format
(
anon_tuple
,
container_type
))
if
(
num_named_elements
>
0
and
is_container_type_without_names
(
container_type
)):
raise
ValueError
(
'Cannot represent value {} with named elements '
'using container type {} which does not support names. In TFF
\'
s '
'typesystem, this corresponds to an implicit downcast'
.
format
(
anon_tuple
,
container_type
))
if
(
num_named_elements
>
0
and
is_container_type_without_names
(
container_type
)):
raise
ValueError
(
'Cannot represent value {} with named elements '
'using container type {} which does not support names. In TFF
\'
s '
'typesystem, this corresponds to an implicit downcast'
.
format
(
value
,
container_type
))
if
(
is_container_type_with_names
(
container_type
)
and
len
(
dir
(
structure_type_spec
))
!=
len
(
anon_tupl
e
)):
len
(
dir
(
structure_type_spec
))
!=
len
(
valu
e
)):
# If the type specifies the names, we have all the information we need.
# Otherwise we must raise here.
raise
ValueError
(
'When packaging as a Python value which requires names, '
...
...
@@ -443,17 +463,17 @@ def type_to_py_container(value, type_spec):
'{} names in type spec {} of length {}, with requested'
'python type {}.'
.
format
(
len
(
dir
(
structure_type_spec
)),
structure_type_spec
,
len
(
anon_tupl
e
),
container_type
))
len
(
valu
e
),
container_type
))
elements
=
[]
for
index
,
(
elem_name
,
elem_type
)
in
enumerate
(
structure
.
iter_elements
(
structure_type_spec
)):
value
=
type_to_py_container
(
anon_tupl
e
[
index
],
elem_type
)
element
=
type_to_py_container
(
valu
e
[
index
],
elem_type
)
if
elem_name
is
None
and
not
container_is_anon_tuple
:
elements
.
append
(
value
)
if
elem_name
is
None
:
elements
.
append
(
element
)
else
:
elements
.
append
((
elem_name
,
value
))
elements
.
append
((
elem_name
,
element
))
if
(
py_typecheck
.
is_named_tuple
(
container_type
)
or
py_typecheck
.
is_attrs
(
container_type
)
or
...
...
tensorflow_federated/python/core/impl/types/type_conversions_test.py
浏览文件 @
73a6e4ff
...
...
@@ -516,7 +516,7 @@ class TypeFromTensorsTest(test_case.TestCase):
class
TypeToPyContainerTest
(
test_case
.
TestCase
):
def
test_
not_anon_
tuple_passthrough
(
self
):
def
test_tuple_passthrough
(
self
):
value
=
(
1
,
2.0
)
result
=
type_conversions
.
type_to_py_container
(
(
1
,
2.0
),
...
...
@@ -524,12 +524,27 @@ class TypeToPyContainerTest(test_case.TestCase):
container_type
=
list
))
self
.
assertEqual
(
result
,
value
)
def
test_anon_tuple_return
(
self
):
anon_tuple
=
structure
.
Struct
([(
None
,
1
),
(
None
,
2.0
)])
def
test_represents_unnamed_fields_as_tuple
(
self
):
input_value
=
structure
.
Struct
([(
None
,
1
),
(
None
,
2.0
)])
input_type
=
computation_types
.
StructType
([
tf
.
int32
,
tf
.
float32
])
self
.
assertEqual
(
type_conversions
.
type_to_py_container
(
anon_tuple
,
computation_types
.
StructType
([
tf
.
int32
,
tf
.
float32
])),
anon_tuple
)
type_conversions
.
type_to_py_container
(
input_value
,
input_type
),
(
1
,
2.0
))
def
test_represents_named_fields_as_odict
(
self
):
input_value
=
structure
.
Struct
([(
'a'
,
1
),
(
'b'
,
2.0
)])
input_type
=
computation_types
.
StructType
([(
'a'
,
tf
.
int32
),
(
'b'
,
tf
.
float32
)])
self
.
assertEqual
(
type_conversions
.
type_to_py_container
(
input_value
,
input_type
),
collections
.
OrderedDict
(
a
=
1
,
b
=
2.0
))
def
test_raises_on_mixed_named_unnamed
(
self
):
input_value
=
structure
.
Struct
([(
'a'
,
1
),
(
None
,
2.0
)])
input_type
=
computation_types
.
StructType
([(
'a'
,
tf
.
int32
),
(
None
,
tf
.
float32
)])