Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
Commits
7681cec7
Commit
7681cec7
authored
Dec 02, 2021
by
Keith Rush
Committed by
tensorflow-copybara
Dec 02, 2021
Browse files
Automatically zips if necessary on ingestion into federated context.
PiperOrigin-RevId: 413745813
parent
0d736a45
Changes
7
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/core/impl/compiler/BUILD
View file @
7681cec7
...
@@ -74,6 +74,7 @@ py_library(
...
@@ -74,6 +74,7 @@ py_library(
"//tensorflow_federated/python/core/impl/types:type_analysis"
,
"//tensorflow_federated/python/core/impl/types:type_analysis"
,
"//tensorflow_federated/python/core/impl/types:type_conversions"
,
"//tensorflow_federated/python/core/impl/types:type_conversions"
,
"//tensorflow_federated/python/core/impl/types:type_serialization"
,
"//tensorflow_federated/python/core/impl/types:type_serialization"
,
"//tensorflow_federated/python/core/impl/types:type_transformations"
,
"//tensorflow_federated/python/core/impl/utils:tensorflow_utils"
,
"//tensorflow_federated/python/core/impl/utils:tensorflow_utils"
,
],
],
)
)
...
...
tensorflow_federated/python/core/impl/compiler/building_block_factory.py
View file @
7681cec7
...
@@ -33,6 +33,7 @@ from tensorflow_federated.python.core.impl.types import placements
...
@@ -33,6 +33,7 @@ from tensorflow_federated.python.core.impl.types import placements
from
tensorflow_federated.python.core.impl.types
import
type_analysis
from
tensorflow_federated.python.core.impl.types
import
type_analysis
from
tensorflow_federated.python.core.impl.types
import
type_conversions
from
tensorflow_federated.python.core.impl.types
import
type_conversions
from
tensorflow_federated.python.core.impl.types
import
type_serialization
from
tensorflow_federated.python.core.impl.types
import
type_serialization
from
tensorflow_federated.python.core.impl.types
import
type_transformations
from
tensorflow_federated.python.core.impl.utils
import
tensorflow_utils
from
tensorflow_federated.python.core.impl.utils
import
tensorflow_utils
...
@@ -1997,3 +1998,115 @@ def apply_binary_operator_with_upcast(
...
@@ -1997,3 +1998,115 @@ def apply_binary_operator_with_upcast(
called
=
building_blocks
.
Call
(
tf_representing_op
,
arg
)
called
=
building_blocks
.
Call
(
tf_representing_op
,
arg
)
return
called
return
called
def
zip_to_match_type
(
*
,
comp_to_zip
:
building_blocks
.
ComputationBuildingBlock
,
target_type
:
computation_types
.
Type
)
->
Optional
[
building_blocks
.
ComputationBuildingBlock
]:
"""Zips computation argument to match target type.
This function will apply the appropriate federated zips to match `comp_to_zip`
to the requested type `target_type`, subject to a few caveats. We will
traverse `computation_types.StructTypes` to match types, so for example we
would zip `<<T@P, R@P>>` to match `<<T, R>@P>`, but we will not traverse
`computation_types.FunctionTypes`. Therefore we would not apply a zip to the
parameter of `(<<T@P, R@P>> -> Q)` to match (<<T, R>@P> -> Q).
If zipping in this manner cannot match the type of `comp_to_zip` to
`target_type`, `None` will be returned.
Args:
comp_to_zip: Instance of `building_blocks.ComputationBuildingBlock` to
traverse and attempt to zip to match `target_type`.
target_type: The type to target when traversing and zipping `comp_to_zip`.
Returns:
Either a potentially transformed version of `comp_to_zip` or `None`,
depending on whether inserting a zip according to the semantics above
can transformed `comp_to_zip` to the requested type.
"""
py_typecheck
.
check_type
(
comp_to_zip
,
building_blocks
.
ComputationBuildingBlock
)
py_typecheck
.
check_type
(
target_type
,
computation_types
.
Type
)
def
_can_be_zipped_into
(
source_type
:
computation_types
.
Type
,
target_type
:
computation_types
.
Type
)
->
bool
:
"""Indicates possibility of the transformation `zip_to_match_type`."""
def
_struct_can_be_zipped_to_federated
(
struct_type
:
computation_types
.
StructType
,
federated_type
:
computation_types
.
FederatedType
)
->
bool
:
placements_encountered
=
set
()
def
_remove_placement
(
subtype
:
computation_types
.
Type
)
->
Tuple
[
computation_types
.
Type
,
bool
]:
if
subtype
.
is_federated
():
placements_encountered
.
add
(
subtype
.
placement
)
return
subtype
.
member
,
True
return
subtype
,
False
unplaced_struct
,
_
=
type_transformations
.
transform_type_postorder
(
struct_type
,
_remove_placement
)
if
not
(
all
(
x
is
federated_type
.
placement
for
x
in
placements_encountered
)):
return
False
if
(
federated_type
.
placement
is
placements
.
CLIENTS
and
federated_type
.
all_equal
):
# There is no all-equal clients zip; return false.
return
False
return
federated_type
.
member
.
is_assignable_from
(
unplaced_struct
)
def
_struct_elem_zippable
(
source_name
,
source_element
,
target_name
,
target_element
):
return
_can_be_zipped_into
(
source_element
,
target_element
)
and
source_name
in
(
target_name
,
None
)
if
source_type
.
is_struct
():
if
target_type
.
is_federated
():
return
_struct_can_be_zipped_to_federated
(
source_type
,
target_type
)
elif
target_type
.
is_struct
():
elements_zippable
=
[]
for
(
s_name
,
s_el
),
(
t_name
,
t_el
)
in
zip
(
structure
.
iter_elements
(
source_type
),
structure
.
iter_elements
(
target_type
)):
elements_zippable
.
append
(
_struct_elem_zippable
(
s_name
,
s_el
,
t_name
,
t_el
))
return
all
(
elements_zippable
)
else
:
return
target_type
.
is_assignable_from
(
source_type
)
def
_zip_to_match
(
*
,
source
:
building_blocks
.
ComputationBuildingBlock
,
target_type
:
computation_types
.
Type
)
->
building_blocks
.
ComputationBuildingBlock
:
if
target_type
.
is_federated
()
and
source
.
type_signature
.
is_struct
():
return
create_federated_zip
(
source
)
elif
target_type
.
is_struct
()
and
source
.
type_signature
.
is_struct
():
zipped_elements
=
[]
# Bind a reference to the source to prevent duplication in the AST.
ref_name
=
next
(
unique_name_generator
(
source
))
ref_to_source
=
building_blocks
.
Reference
(
ref_name
,
source
.
type_signature
)
for
idx
,
((
_
,
t_el
),
(
s_name
,
_
))
in
enumerate
(
zip
(
structure
.
iter_elements
(
target_type
),
structure
.
iter_elements
(
source
.
type_signature
))):
s_selection
=
building_blocks
.
Selection
(
ref_to_source
,
index
=
idx
)
zipped_elements
.
append
(
(
s_name
,
_zip_to_match
(
source
=
s_selection
,
target_type
=
t_el
)))
# Insert binding above the constructed structure.
return
building_blocks
.
Block
([(
ref_name
,
source
)],
building_blocks
.
Struct
(
zipped_elements
))
else
:
# No zipping to be done here.
return
source
if
target_type
.
is_assignable_from
(
comp_to_zip
.
type_signature
):
# No zipping needs to be done; return directly.
return
comp_to_zip
elif
_can_be_zipped_into
(
comp_to_zip
.
type_signature
,
target_type
):
return
_zip_to_match
(
source
=
comp_to_zip
,
target_type
=
target_type
)
else
:
# Zipping cannot be performed here.
return
None
tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py
View file @
7681cec7
...
@@ -1932,5 +1932,117 @@ class SelectOutputFromLambdaTest(test_case.TestCase):
...
@@ -1932,5 +1932,117 @@ class SelectOutputFromLambdaTest(test_case.TestCase):
self
.
assertEqual
(
str
(
tuple_selected
),
'(x -> <x.a.inner,x.b>)'
)
self
.
assertEqual
(
str
(
tuple_selected
),
'(x -> <x.a.inner,x.b>)'
)
class
ZipUpToTest
(
test_case
.
TestCase
):
def
test_zips_struct_of_federated_values
(
self
):
comp
=
building_blocks
.
Struct
([
building_blocks
.
Reference
(
'x'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
)),
building_blocks
.
Reference
(
'y'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
))
])
zippable_type
=
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
None
,
tf
.
int32
),
(
None
,
tf
.
int32
)]),
placements
.
CLIENTS
)
zipped
=
building_block_factory
.
zip_to_match_type
(
comp_to_zip
=
comp
,
target_type
=
zippable_type
)
self
.
assert_types_equivalent
(
zipped
.
type_signature
,
zippable_type
)
def
test_does_not_zip_different_placement_target
(
self
):
comp
=
building_blocks
.
Struct
([
building_blocks
.
Reference
(
'x'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
)),
building_blocks
.
Reference
(
'y'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
))
])
non_zippable_type
=
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
None
,
tf
.
int32
),
(
None
,
tf
.
int32
)]),
placements
.
SERVER
)
zipped
=
building_block_factory
.
zip_to_match_type
(
comp_to_zip
=
comp
,
target_type
=
non_zippable_type
)
self
.
assertIsNone
(
zipped
)
def
test_zips_struct_of_federated_values_under_struct
(
self
):
comp
=
building_blocks
.
Struct
([
building_blocks
.
Struct
([
building_blocks
.
Reference
(
'x'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
)),
building_blocks
.
Reference
(
'y'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
))
])
])
zippable_type
=
computation_types
.
StructType
([
(
None
,
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
None
,
tf
.
int32
),
(
None
,
tf
.
int32
)]),
placements
.
CLIENTS
))
])
zipped
=
building_block_factory
.
zip_to_match_type
(
comp_to_zip
=
comp
,
target_type
=
zippable_type
)
self
.
assert_types_equivalent
(
zipped
.
type_signature
,
zippable_type
)
def
test_assignability_with_names
(
self
):
# This would correspond to an implicit downcast in TFF's typesystem; the
# result would not be assignable to the requested type.
comp
=
building_blocks
.
Struct
([
building_blocks
.
Struct
([
(
'a'
,
building_blocks
.
Reference
(
'x'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
))),
(
'b'
,
building_blocks
.
Reference
(
'y'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
)))
])
])
unnamed_zippable_type
=
computation_types
.
StructType
([
(
None
,
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
None
,
tf
.
int32
),
(
None
,
tf
.
int32
)]),
placements
.
CLIENTS
))
])
named_zippable_type
=
computation_types
.
StructType
([
(
None
,
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
'a'
,
tf
.
int32
),
(
'b'
,
tf
.
int32
)]),
placements
.
CLIENTS
))
])
not_zipped
=
building_block_factory
.
zip_to_match_type
(
comp_to_zip
=
comp
,
target_type
=
unnamed_zippable_type
)
zipped
=
building_block_factory
.
zip_to_match_type
(
comp_to_zip
=
comp
,
target_type
=
named_zippable_type
)
self
.
assertFalse
(
unnamed_zippable_type
.
is_assignable_from
(
named_zippable_type
))
self
.
assertIsNone
(
not_zipped
)
self
.
assert_types_equivalent
(
zipped
.
type_signature
,
named_zippable_type
)
def
test_does_not_zip_under_function
(
self
):
result_comp
=
building_blocks
.
Struct
([
building_blocks
.
Reference
(
'x'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
)),
building_blocks
.
Reference
(
'y'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
))
])
lam
=
building_blocks
.
Lambda
(
None
,
None
,
result_comp
)
zippable_function_type
=
computation_types
.
FunctionType
(
None
,
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
None
,
tf
.
int32
),
(
None
,
tf
.
int32
)]),
placements
.
CLIENTS
))
zipped
=
building_block_factory
.
zip_to_match_type
(
comp_to_zip
=
lam
,
target_type
=
zippable_function_type
)
self
.
assertIsNone
(
zipped
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_case
.
main
()
test_case
.
main
()
tensorflow_federated/python/core/impl/federated_context/BUILD
View file @
7681cec7
...
@@ -74,8 +74,11 @@ py_test(
...
@@ -74,8 +74,11 @@ py_test(
":federated_computation_context"
,
":federated_computation_context"
,
":value_impl"
,
":value_impl"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/api:test_case"
,
"//tensorflow_federated/python/core/impl/compiler:building_blocks"
,
"//tensorflow_federated/python/core/impl/compiler:building_blocks"
,
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl"
,
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl"
,
"//tensorflow_federated/python/core/impl/types:computation_types"
,
"//tensorflow_federated/python/core/impl/types:placements"
,
],
],
)
)
...
...
tensorflow_federated/python/core/impl/federated_context/federated_computation_context.py
View file @
7681cec7
...
@@ -98,12 +98,11 @@ class FederatedComputationContext(symbol_binding_context.SymbolBindingContext):
...
@@ -98,12 +98,11 @@ class FederatedComputationContext(symbol_binding_context.SymbolBindingContext):
return
self
.
_symbol_bindings
return
self
.
_symbol_bindings
def
ingest
(
self
,
val
,
type_spec
):
def
ingest
(
self
,
val
,
type_spec
):
val
=
value_impl
.
to_value
(
val
,
type_spec
,
self
.
_context_stack
)
val
=
value_impl
.
to_value
(
val
,
type_spec
,
zip_if_needed
=
True
)
type_analysis
.
check_type
(
val
,
type_spec
)
return
val
return
val
def
invoke
(
self
,
comp
,
arg
):
def
invoke
(
self
,
comp
,
arg
):
fn
=
value_impl
.
to_value
(
comp
,
None
,
self
.
_context_stack
)
fn
=
value_impl
.
to_value
(
comp
,
None
)
tys
=
fn
.
type_signature
tys
=
fn
.
type_signature
py_typecheck
.
check_type
(
tys
,
computation_types
.
FunctionType
)
py_typecheck
.
check_type
(
tys
,
computation_types
.
FunctionType
)
if
arg
is
not
None
:
if
arg
is
not
None
:
...
...
tensorflow_federated/python/core/impl/federated_context/federated_computation_context_test.py
View file @
7681cec7
...
@@ -12,17 +12,19 @@
...
@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
absl.testing
import
absltest
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.api
import
test_case
from
tensorflow_federated.python.core.impl.compiler
import
building_blocks
from
tensorflow_federated.python.core.impl.compiler
import
building_blocks
from
tensorflow_federated.python.core.impl.context_stack
import
context_stack_impl
from
tensorflow_federated.python.core.impl.context_stack
import
context_stack_impl
from
tensorflow_federated.python.core.impl.federated_context
import
federated_computation_context
from
tensorflow_federated.python.core.impl.federated_context
import
federated_computation_context
from
tensorflow_federated.python.core.impl.federated_context
import
value_impl
from
tensorflow_federated.python.core.impl.federated_context
import
value_impl
from
tensorflow_federated.python.core.impl.types
import
computation_types
from
tensorflow_federated.python.core.impl.types
import
placements
class
FederatedComputationContextTest
(
absl
test
.
TestCase
):
class
FederatedComputationContextTest
(
test
_case
.
TestCase
):
def
test_invoke_returns_value_with_correct_type
(
self
):
def
test_invoke_returns_value_with_correct_type
(
self
):
context
=
federated_computation_context
.
FederatedComputationContext
(
context
=
federated_computation_context
.
FederatedComputationContext
(
...
@@ -33,6 +35,46 @@ class FederatedComputationContextTest(absltest.TestCase):
...
@@ -33,6 +35,46 @@ class FederatedComputationContextTest(absltest.TestCase):
self
.
assertIsInstance
(
result
,
value_impl
.
Value
)
self
.
assertIsInstance
(
result
,
value_impl
.
Value
)
self
.
assertEqual
(
str
(
result
.
type_signature
),
'int32'
)
self
.
assertEqual
(
str
(
result
.
type_signature
),
'int32'
)
def
test_ingest_zips_value_when_necessary_to_match_federated_type
(
self
):
context
=
federated_computation_context
.
FederatedComputationContext
(
context_stack_impl
.
context_stack
)
# This thing will be <{int}@C, {int}@C>
comp
=
building_blocks
.
Struct
([
building_blocks
.
Reference
(
'x'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
)),
building_blocks
.
Reference
(
'y'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
))
])
# The type of comp can be zipped to the below.
zippable_type
=
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
None
,
tf
.
int32
),
(
None
,
tf
.
int32
)]),
placements
.
CLIENTS
)
ingested
=
context
.
ingest
(
comp
,
type_spec
=
zippable_type
)
self
.
assert_types_equivalent
(
ingested
.
type_signature
,
zippable_type
)
def
test_ingest_zips_federated_under_struct
(
self
):
context
=
federated_computation_context
.
FederatedComputationContext
(
context_stack_impl
.
context_stack
)
comp
=
building_blocks
.
Struct
([
building_blocks
.
Struct
([
building_blocks
.
Reference
(
'x'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
)),
building_blocks
.
Reference
(
'y'
,
computation_types
.
FederatedType
(
tf
.
int32
,
placements
.
CLIENTS
))
])
])
# The type of comp can be zipped to the below.
zippable_type
=
computation_types
.
StructType
([
(
None
,
computation_types
.
FederatedType
(
computation_types
.
StructType
([(
None
,
tf
.
int32
),
(
None
,
tf
.
int32
)]),
placements
.
CLIENTS
))
])
ingested
=
context
.
ingest
(
comp
,
type_spec
=
zippable_type
)
self
.
assert_types_equivalent
(
ingested
.
type_signature
,
zippable_type
)
def
test_construction_populates_name
(
self
):
def
test_construction_populates_name
(
self
):
context
=
federated_computation_context
.
FederatedComputationContext
(
context
=
federated_computation_context
.
FederatedComputationContext
(
context_stack_impl
.
context_stack
)
context_stack_impl
.
context_stack
)
...
@@ -94,4 +136,4 @@ class FederatedComputationContextTest(absltest.TestCase):
...
@@ -94,4 +136,4 @@ class FederatedComputationContextTest(absltest.TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
absl
test
.
main
()
test
_case
.
main
()
tensorflow_federated/python/core/impl/federated_context/value_impl.py
View file @
7681cec7
...
@@ -298,6 +298,8 @@ def to_value(
...
@@ -298,6 +298,8 @@ def to_value(
arg
:
Any
,
arg
:
Any
,
type_spec
,
type_spec
,
parameter_type_hint
=
None
,
parameter_type_hint
=
None
,
*
,
zip_if_needed
:
bool
=
False
,
)
->
Value
:
)
->
Value
:
"""Converts the argument into an instance of the abstract class `tff.Value`.
"""Converts the argument into an instance of the abstract class `tff.Value`.
...
@@ -339,6 +341,9 @@ def to_value(
...
@@ -339,6 +341,9 @@ def to_value(
parameter_type_hint: An optional `tff.Type` or value convertible to it by
parameter_type_hint: An optional `tff.Type` or value convertible to it by
`tff.to_type()` which specifies an argument type to use in the case that
`tff.to_type()` which specifies an argument type to use in the case that
`arg` is a `function_utils.PolymorphicComputation`.
`arg` is a `function_utils.PolymorphicComputation`.
zip_if_needed: If `True`, attempt to coerce the result of `to_value` to
match `type_spec` by applying `intrinsics.federated_zip` to appropriate
elements.
Returns:
Returns:
An instance of `tff.Value` as described above.
An instance of `tff.Value` as described above.
...
@@ -409,6 +414,12 @@ def to_value(
...
@@ -409,6 +414,12 @@ def to_value(
py_typecheck
.
check_type
(
result
,
Value
)
py_typecheck
.
check_type
(
result
,
Value
)
if
(
type_spec
is
not
None
and
if
(
type_spec
is
not
None
and
not
type_spec
.
is_assignable_from
(
result
.
type_signature
)):
not
type_spec
.
is_assignable_from
(
result
.
type_signature
)):
if
zip_if_needed
:
# Returns `None` if such a zip can't be performed.
zipped_comp
=
building_block_factory
.
zip_to_match_type
(
comp_to_zip
=
result
.
comp
,
target_type
=
type_spec
)
if
zipped_comp
is
not
None
:
return
Value
(
zipped_comp
)
raise
TypeError
(
raise
TypeError
(
'The supplied argument maps to TFF type {}, which is incompatible with '
'The supplied argument maps to TFF type {}, which is incompatible with '
'the requested type {}.'
.
format
(
result
.
type_signature
,
type_spec
))
'the requested type {}.'
.
format
(
result
.
type_signature
,
type_spec
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment