# Copyright 2020, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An implementation of the Federated Averaging algorithm.

This is intended to be a minimal stand-alone implementation of Federated
Averaging, suitable for branching as a starting point for algorithm
modifications; see `tff.learning.build_federated_averaging_process` for a
more full-featured implementation.

Based on the paper:

Communication-Efficient Learning of Deep Networks from Decentralized Data
    H. Brendan McMahan, Eider Moore, Daniel Ramage,
    Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017.
    https://arxiv.org/abs/1602.05629
"""

import attr
import tensorflow as tf
import tensorflow_federated as tff


@attr.s(eq=False, frozen=True, slots=True)
class ClientOutput(object):
  """Structure for outputs returned from clients during federated optimization.

  Attributes:
    weights_delta: A dictionary of updates to the model's trainable variables.
    client_weight: Weight to be used in a weighted mean when aggregating
      `weights_delta`.
    model_output: A structure matching
      `tff.learning.Model.report_local_outputs`, reflecting the results of
      training on the input dataset.
  """
  weights_delta = attr.ib()
  client_weight = attr.ib()
  model_output = attr.ib()


@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
  """Structure for state on the server.

  Attributes:
    model_weights: A dictionary of the model's trainable variables.
    optimizer_state: The variables of the server optimizer.
    round_num: The current round in the training process.
  """
  model_weights = attr.ib()
  optimizer_state = attr.ib()
  round_num = attr.ib()


@attr.s(eq=False, frozen=True, slots=True)
class BroadcastMessage(object):
  """Structure for tensors broadcasted by server during federated optimization.

  Attributes:
    model_weights: A dictionary of the model's trainable tensors.
    round_num: Round index to broadcast. We use `round_num` as an example of
      how to broadcast auxiliary information that can be helpful on clients.
      It is not explicitly used here, but could be used, e.g., to enable
      learning rate scheduling; a sketch follows this class.
  """
  model_weights = attr.ib()
  round_num = attr.ib()
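
# A learning-rate scheduling sketch (hypothetical; this file does not
# implement scheduling): a client could derive its learning rate from the
# broadcast `round_num`, e.g.
#
#   lr = 0.1 * (0.99 ** tf.cast(server_message.round_num, tf.float32))
#   client_optimizer = tf.keras.optimizers.SGD(learning_rate=lr)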


@tf.function
def server_update(model, server_optimizer, server_state, weights_delta):
  """Updates `server_state` based on `weights_delta`.

  Args:
    model: A `KerasModelWrapper` or `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`. If the optimizer
      creates variables, they must have already been created.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: A nested structure of tensors holding the updates to the
      trainable variables of the model.

  Returns:
    An updated `ServerState`.
  """
  # Initialize the model with the current state.
  model_weights = tff.learning.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        server_state.model_weights)
  tf.nest.map_structure(lambda v, t: v.assign(t), server_optimizer.variables(),
                        server_state.optimizer_state)

  # Apply the update to the model.
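  # `apply_gradients` performs gradient *descent*, so the delta is negated
  # below to step the weights toward the averaged client update; with plain
  # SGD at learning rate 1.0 this recovers vanilla FedAvg.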
  neg_weights_delta = [-1.0 * x for x in weights_delta]
  server_optimizer.apply_gradients(
      zip(neg_weights_delta, model_weights.trainable), name='server_update')

  # Create a new state based on the updated model.
  return tff.structure.update_struct(
      server_state,
      model_weights=model_weights,
      optimizer_state=server_optimizer.variables(),
      round_num=server_state.round_num + 1)


@tf.function
def build_server_broadcast_message(server_state):
  """Builds `BroadcastMessage` for broadcasting.

  This method can be used to post-process `ServerState` before broadcasting.
  For example, one could perform model compression on `ServerState` to obtain
  a compressed state that is sent in a `BroadcastMessage`; a sketch follows
  this function.

  Args:
    server_state: A `ServerState`.

  Returns:
    A `BroadcastMessage`.
  """
  return BroadcastMessage(
      model_weights=server_state.model_weights,
      round_num=server_state.round_num)
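
# A compression sketch (hypothetical, not implemented here): a variant of
# `build_server_broadcast_message` could shrink the broadcast by casting
# weights to float16 before constructing the message:
#
#   compressed = tf.nest.map_structure(
#       lambda t: tf.cast(t, tf.float16), server_state.model_weights)
#   return BroadcastMessage(model_weights=compressed,
#                           round_num=server_state.round_num)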


@tf.function
def client_update(model, dataset, server_message, client_optimizer):
  """Performans client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model` to train locally on the client.
    dataset: A `tf.data.Dataset` representing the client's local dataset.
    server_message: A `BroadcastMessage` from the server containing the
      initial model weights for training.
    client_optimizer: A `tf.keras.optimizers.Optimizer` used to update the local
      model during training.

  Returns:
    A `ClientOutput` instance with a model update to aggregate on the server.
  """
  model_weights = tff.learning.ModelWeights.from_model(model)
  initial_weights = server_message.model_weights
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        initial_weights)

  num_examples = tf.constant(0, dtype=tf.int32)
  loss_sum = tf.constant(0, dtype=tf.float32)
  # Explicitly using `iter` on the dataset makes TFF more robust in GPU
  # simulation and slightly more performant in the unconventional case of
  # training on a large number of small datasets.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)
    grads = tape.gradient(outputs.loss, model_weights.trainable)
    client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
    batch_size = tf.shape(batch['y'])[0]
    num_examples += batch_size
    loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)

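  # Report the difference between the locally trained weights and the
  # broadcast initial weights; the server aggregates and applies these deltas.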
  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        model_weights.trainable,
                                        initial_weights.trainable)
  client_weight = tf.cast(num_examples, tf.float32)
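  # `loss_sum / client_weight` is the example-weighted average training loss.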
  return ClientOutput(weights_delta, client_weight, loss_sum / client_weight)
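

# A minimal local usage sketch (assumed names: `keras_model`, `input_spec`,
# `client_dataset`, `server_state`; in practice these functions are wired
# into a federated `tff.templates.IterativeProcess`):
#
#   model = tff.learning.from_keras_model(
#       keras_model,
#       input_spec=input_spec,
#       loss=tf.keras.losses.SparseCategoricalCrossentropy())
#   message = build_server_broadcast_message(server_state)
#   output = client_update(model, client_dataset, message,
#                          tf.keras.optimizers.SGD(learning_rate=0.1))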