提交 0b5abd0d 编辑于 作者: Zachary Charles's avatar Zachary Charles 提交者: tensorflow-copybara
浏览文件

Add `get_model_weights` tf_computation to FedAvg and FedSGD processes created by tff.learning.

PiperOrigin-RevId: 343939365
上级 5965c17c
......@@ -44,6 +44,7 @@ py_library(
":model_utils",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
......@@ -113,6 +114,7 @@ py_library(
":model_utils",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
......
......@@ -28,6 +28,7 @@ import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import model as model_lib
......@@ -151,7 +152,7 @@ def build_federated_averaging_process(
This function creates a `tff.templates.IterativeProcess` that performs
federated averaging on client models. The iterative process has the following
methods:
methods inherited from `tff.templates.IterativeProcess`:
* `initialize`: A `tff.Computation` with the functional type signature
`( -> S@SERVER)`, where `S` is a `tff.learning.framework.ServerState`
......@@ -166,6 +167,13 @@ def build_federated_averaging_process(
`tff.learning.Model.federated_output_computation` during client training
and any other metrics from broadcast and aggregation processes.
The iterative process also has the following method not inherited from
`tff.templates.IterativeProcess`:
* `get_model_weights`: A `tff.Computation` that takes as input a
`tff.learning.framework.ServerState`, and returns a
`tff.learning.ModelWeights` containing the state's model weights.
Each time the `next` method is called, the server model is broadcast to each
client using a broadcast function. For each client, one epoch of local
training is performed via the `tf.keras.optimizers.Optimizer.apply_gradients`
......@@ -220,10 +228,19 @@ def build_federated_averaging_process(
return ClientFedAvg(model_fn(), client_optimizer_fn(), client_weight_fn,
use_experimental_simulation_loop)
return optimizer_utils.build_model_delta_optimizer_process(
iter_proc = optimizer_utils.build_model_delta_optimizer_process(
model_fn,
model_to_client_delta_fn=client_fed_avg,
server_optimizer_fn=server_optimizer_fn,
broadcast_process=broadcast_process,
aggregation_process=aggregation_process,
model_update_aggregation_factory=model_update_aggregation_factory)
server_state_type = iter_proc.state_type.member
@computations.tf_computation(server_state_type)
def get_model_weights(server_state):
  """Returns the `model` attribute of the given server state."""
  return server_state.model
iter_proc.get_model_weights = get_model_weights
return iter_proc
......@@ -261,6 +261,33 @@ class FederatedAveragingModelTffTest(test_case.TestCase,
self.assertEqual(metric_outputs['train']['num_examples'], 0)
self.assertTrue(tf.math.is_nan(metric_outputs['train']['loss']))
@test_utils.skip_test_for_multi_gpu
def test_get_model_weights(self):
  """Tests that `get_model_weights` mirrors `state.model` across rounds."""
  process = federated_averaging.build_federated_averaging_process(
      model_fn=model_examples.LinearRegression,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))

  # All three clients share the same dataset object: one batch of two
  # examples.
  examples = collections.OrderedDict(
      x=[[1.0, 2.0], [3.0, 4.0]],
      y=[[5.0], [6.0]],
  )
  client_data = [tf.data.Dataset.from_tensor_slices(examples).batch(2)] * 3

  def assert_weights_match(server_state):
    # The extracted weights must be a `ModelWeights` whose trainable
    # values agree with those stored directly on the server state.
    self.assertIsInstance(
        process.get_model_weights(server_state), model_utils.ModelWeights)
    self.assertAllClose(server_state.model.trainable,
                        process.get_model_weights(server_state).trainable)

  # Check the freshly initialized state, then re-check after each round.
  state = process.initialize()
  assert_weights_match(state)
  for _ in range(3):
    state, _ = process.next(state, client_data)
    assert_weights_match(state)
if __name__ == '__main__':
execution_contexts.set_local_execution_context()
......
......@@ -27,6 +27,7 @@ import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import model as model_lib
......@@ -159,8 +160,8 @@ def build_federated_sgd_process(
"""Builds the TFF computations for optimization using federated SGD.
This function creates a `tff.templates.IterativeProcess` that performs
federated averaging on client models. The iterative process has the following
methods:
federated SGD on client models. The iterative process has the following
methods inherited from `tff.templates.IterativeProcess`:
* `initialize`: A `tff.Computation` with the functional type signature
`( -> S@SERVER)`, where `S` is a `tff.learning.framework.ServerState`
......@@ -175,6 +176,13 @@ def build_federated_sgd_process(
`tff.learning.Model.federated_output_computation` during client training
and any other metrics from broadcast and aggregation processes.
The iterative process also has the following method not inherited from
`tff.templates.IterativeProcess`:
* `get_model_weights`: A `tff.Computation` that takes as input a
`tff.learning.framework.ServerState`, and returns a
`tff.learning.ModelWeights` containing the state's model weights.
Each time the `next` method is called, the server model is broadcast to each
client using a broadcast function. Each client sums the gradients at each
batch in the client's local dataset. These gradient sums are then aggregated
......@@ -228,10 +236,19 @@ def build_federated_sgd_process(
client_weight_fn,
use_experimental_simulation_loop=use_experimental_simulation_loop)
return optimizer_utils.build_model_delta_optimizer_process(
iter_proc = optimizer_utils.build_model_delta_optimizer_process(
model_fn,
model_to_client_delta_fn=client_sgd_avg,
server_optimizer_fn=server_optimizer_fn,
broadcast_process=broadcast_process,
aggregation_process=aggregation_process,
model_update_aggregation_factory=model_update_aggregation_factory)
server_state_type = iter_proc.state_type.member
@computations.tf_computation(server_state_type)
def get_model_weights(server_state):
  """Returns the `model` attribute of the given server state."""
  return server_state.model
iter_proc.get_model_weights = get_model_weights
return iter_proc
......@@ -224,6 +224,32 @@ class FederatedSGDTffTest(test_case.TestCase, parameterized.TestCase):
self.assertEqual(metric_outputs['train']['num_examples'], 0)
self.assertTrue(tf.math.is_nan(metric_outputs['train']['loss']))
@test_utils.skip_test_for_multi_gpu
def test_get_model_weights(self):
  """Tests that `get_model_weights` mirrors `state.model` across rounds."""
  process = federated_sgd.build_federated_sgd_process(
      model_fn=model_examples.LinearRegression)

  # All three clients share the same dataset object: one batch of two
  # examples.
  examples = collections.OrderedDict(
      x=[[1.0, 2.0], [3.0, 4.0]],
      y=[[5.0], [6.0]],
  )
  client_data = [tf.data.Dataset.from_tensor_slices(examples).batch(2)] * 3

  def assert_weights_match(server_state):
    # The extracted weights must be a `ModelWeights` whose trainable
    # values agree with those stored directly on the server state.
    self.assertIsInstance(
        process.get_model_weights(server_state), model_utils.ModelWeights)
    self.assertAllClose(server_state.model.trainable,
                        process.get_model_weights(server_state).trainable)

  # Check the freshly initialized state, then re-check after each round.
  state = process.initialize()
  assert_weights_match(state)
  for _ in range(3):
    state, _ = process.next(state, client_data)
    assert_weights_match(state)
if __name__ == '__main__':
execution_contexts.set_local_execution_context()
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册