Skip to content
Snippets Groups Projects
Commit 9cf02534 authored by Scott Wegner's avatar Scott Wegner Committed by tensorflow-copybara
Browse files

Add example and test case for building an IterativeProcess compatible with CanonicalForm.

PiperOrigin-RevId: 270088547
parent 4e835ced
Branches
Tags
No related merge requests found
......@@ -61,7 +61,10 @@ py_library(
":canonical_form",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
"//tensorflow_federated/python/core/api:placements",
"//tensorflow_federated/python/core/impl/compiler:building_blocks",
"//tensorflow_federated/python/core/utils:computation_utils",
"//tensorflow_federated/python/learning",
],
)
......
......@@ -86,8 +86,7 @@ class CanonicalFormUtilsTest(absltest.TestCase):
self.assertCountEqual([x.num_readings for x in stats], [1, 1, 1, 1])
def test_get_canonical_form_for_iterative_process(self):
cf = test_utils.get_temperature_sensor_example()
it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
it = test_utils.get_iterative_process_for_canonical_form_example()
cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
self.assertIsInstance(cf, canonical_form.CanonicalForm)
......@@ -97,7 +96,7 @@ class CanonicalFormUtilsTest(absltest.TestCase):
cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
self.assertIsInstance(cf, canonical_form.CanonicalForm)
def test_temperature_example_round_trip_(self):
def test_temperature_example_round_trip(self):
it = canonical_form_utils.get_iterative_process_for_canonical_form(
test_utils.get_temperature_sensor_example())
cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
......
......@@ -26,8 +26,11 @@ import tensorflow as tf
from tensorflow_federated.python import learning
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
from tensorflow_federated.python.core.api import placements
from tensorflow_federated.python.core.backends.mapreduce import canonical_form
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.utils import computation_utils
def get_temperature_sensor_example():
......@@ -313,3 +316,37 @@ def construct_example_training_comp():
def computation_to_building_block(comp):
return building_blocks.ComputationBuildingBlock.from_proto(
comp._computation_proto) # pylint: disable=protected-access
def get_iterative_process_for_canonical_form_example():
"""Construct a simple `IterativeProcess` compatible with `CanonicalForm`.
The computation itself is non-sensical; but demonstrates the required type
signatures for `CanonicalForm```.
Returns:
An `IterativeProcess` compatible with `CanonicalForm`.
"""
@computations.tf_computation(tf.int32, tf.float32)
def add_two(x_int, y_float):
return tf.cast(x_int, tf.float32) + y_float
@computations.federated_computation
def init_fn():
return intrinsics.federated_value(1.234, placements.SERVER)
@computations.federated_computation([
computation_types.FederatedType(tf.float32, placements.SERVER),
computation_types.FederatedType(tf.int32, placements.CLIENTS)
])
def next_fn(server_val, client_val):
"""Defines a series of federated computations compatible with CanonicalForm."""
broadcast_val = intrinsics.federated_broadcast(server_val)
values_on_clients = intrinsics.federated_zip((client_val, broadcast_val))
result_on_clients = intrinsics.federated_map(add_two, values_on_clients)
aggregated_result = intrinsics.federated_mean(result_on_clients)
side_output = intrinsics.federated_value([1, 2, 3, 4, 5], placements.SERVER)
return aggregated_result, side_output
return computation_utils.IterativeProcess(init_fn, next_fn)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment