optimizer_utils.py 27.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2018, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Michael Reneer's avatar
Michael Reneer committed
14
"""Common building blocks for federated optimization algorithms."""
15
16
17

import abc
import collections
18
from typing import Callable, List, Optional, Tuple, Union
19

20
import attr
21
import numpy as np
22
import tensorflow as tf
Michael Reneer's avatar
Michael Reneer committed
23

24
25
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import mean_factory
26
from tensorflow_federated.python.common_libs import py_typecheck
27
28
29
30
31
32
33
34
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
from tensorflow_federated.python.core.api import placements
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.core.templates import measured_process
35
from tensorflow_federated.python.learning import model as model_lib
36
from tensorflow_federated.python.learning import model_utils
37
from tensorflow_federated.python.tensorflow_libs import tensor_utils
38

39
40
41
42
43
# Type aliases.
# No-argument factory that constructs and returns a new `tff.learning.Model`
# on each call (see the `model_fn` arg docs below: captured models from other
# scopes will raise errors).
_ModelConstructor = Callable[[], model_lib.Model]
# No-argument factory that constructs and returns a new Keras optimizer on
# each call (same freshness requirement as `_ModelConstructor`).
_OptimizerConstructor = Callable[[], tf.keras.optimizers.Optimizer]


44
45
46
47
48
49
50
51
52
53
class ProcessTypeError(Exception):
  """Error raised when a `MeasuredProcess` does not have the correct type signature."""


class DisjointArgumentError(Exception):
  """Error raised when two disjoint arguments are specified (only one allowed)."""


54
# Frozen attrs class: instances are immutable; `eq=False` means comparison is
# by identity, not field values.
@attr.s(eq=False, frozen=True)
class ClientOutput(object):
  """Structure for outputs returned from clients during federated optimization.

  Fields:
  -   `weights_delta`: a dictionary of updates to the model's trainable
      variables.
  -   `weights_delta_weight`: weight to use in a weighted mean when aggregating
      `weights_delta`.
  -   `model_output`: a structure matching
      `tff.learning.Model.report_local_outputs`, reflecting the results of
      training on the input dataset.
  -   `optimizer_output`: additional metrics or other outputs defined by the
      optimizer.
  """
  weights_delta = attr.ib()
  weights_delta_weight = attr.ib()
  model_output = attr.ib()
  optimizer_output = attr.ib()
73
74


75
class ClientDeltaFn(object, metaclass=abc.ABCMeta):
  """Represents a client computation that produces an update to a model."""

  # `abc.abstractproperty` is deprecated since Python 3.3; the supported
  # spelling is `@property` stacked over `@abc.abstractmethod`.
  @property
  @abc.abstractmethod
  def variables(self):
    """Returns all the variables of this object.

    Note this only includes variables that are part of the state of this object,
    and not the model variables themselves.

    Returns:
      An iterable of `tf.Variable` objects.
    """
    pass

  @abc.abstractmethod
  def __call__(self, dataset, initial_weights):
    """Defines the complete client computation.

    Typically implementations should be decorated with `tf.function`.

    Args:
      dataset: a `tf.data.Dataset` producing batches than can be fed to
        `tff.learning.Model.forward_pass`.
      initial_weights: a dictionary of initial values for all trainable and
        non-trainable model variables, keyed by name. This will be supplied by
        the server in Federated Averaging.

    Returns:
      An `optimizer_utils.ClientOutput` namedtuple.
    """
    pass
107
108


109
# Frozen attrs class: instances are immutable; `eq=False` means comparison is
# by identity, not field values.
@attr.s(eq=False, frozen=True)
class ServerState(object):
  """Represents the state of the server carried between rounds.

  Attributes:
    model: a `ModelWeights` structure, containing Tensors or Variables.
    optimizer_state: a list of Tensors or Variables, in the order returned by
      `optimizer.variables()`
    delta_aggregate_state: state (possibly empty) of the delta_aggregate_fn.
    model_broadcast_state: state (possibly empty) of the model_broadcast_fn.
  """
  model = attr.ib()
  optimizer_state = attr.ib()
  delta_aggregate_state = attr.ib()
  model_broadcast_state = attr.ib()
124
125


126
127
128
129
130
def state_with_new_model_weights(
    server_state: ServerState,
    trainable_weights: List[np.ndarray],
    non_trainable_weights: List[np.ndarray],
) -> ServerState:
  """Returns a `ServerState` with updated model weights.

  Args:
    server_state: a server state object returned by an iterative training
      process like `tff.learning.build_federated_averaging_process`.
    trainable_weights: a list of `numpy` arrays in the order of the original
      model's `trainable_variables`.
    non_trainable_weights: a list of `numpy` arrays in the order of the original
      model's `non_trainable_variables`.

  Returns:
    A new server `ServerState` object which can be passed to the `next` method
    of the iterative process.

  Raises:
    TypeError: if the new weights do not structurally match the weights
      currently held in `server_state.model`.
  """
  py_typecheck.check_type(server_state, ServerState)
  leaf_types = (int, float, np.ndarray, tf.Tensor)

  def assert_weight_lists_match(old_value, new_value):
    """Assert two flat lists of ndarrays or tensors match."""
    if isinstance(new_value, leaf_types) and isinstance(old_value, leaf_types):
      # Plain Python `int`/`float` leaves (allowed by `leaf_types`) have no
      # `dtype`/`shape` attributes; normalize them through `np.asarray` so the
      # comparison below is well-defined for every permitted leaf type.
      old_leaf = old_value if hasattr(old_value, 'dtype') else np.asarray(
          old_value)
      new_leaf = new_value if hasattr(new_value, 'dtype') else np.asarray(
          new_value)
      if (old_leaf.dtype != new_leaf.dtype or
          old_leaf.shape != new_leaf.shape):
        raise TypeError('Element is not the same tensor type. old '
                        f'({old_leaf.dtype}, {old_leaf.shape}) != '
                        f'new ({new_leaf.dtype}, {new_leaf.shape})')
    elif (isinstance(new_value, collections.abc.Sequence) and
          isinstance(old_value, collections.abc.Sequence)):
      if len(old_value) != len(new_value):
        raise TypeError('Model weights have different lengths: '
                        f'(old) {len(old_value)} != (new) {len(new_value)})\n'
                        f'Old values: {old_value}\nNew values: {new_value}')
      for old, new in zip(old_value, new_value):
        assert_weight_lists_match(old, new)
    else:
      raise TypeError('Model weights structures contains types that cannot be '
                      'handled.\nOld weights structure: {old}\n'
                      'New weights structure: {new}\n'
                      'Must be one of (int, float, np.ndarray, tf.Tensor, '
                      'collections.abc.Sequence)'.format(
                          old=tf.nest.map_structure(type, old_value),
                          new=tf.nest.map_structure(type, new_value)))

  assert_weight_lists_match(server_state.model.trainable, trainable_weights)
  assert_weight_lists_match(server_state.model.non_trainable,
                            non_trainable_weights)
  # Only the model weights are replaced; optimizer and process state carry over
  # unchanged from the input `server_state`.
  new_server_state = ServerState(
      model=model_utils.ModelWeights(
          trainable=trainable_weights, non_trainable=non_trainable_weights),
      optimizer_state=server_state.optimizer_state,
      delta_aggregate_state=server_state.delta_aggregate_state,
      model_broadcast_state=server_state.model_broadcast_state)
  return new_server_state
183
184


185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def _apply_delta(
    *,
    optimizer: tf.keras.optimizers.Optimizer,
    model: model_lib.Model,
    delta,
) -> None:
  """Applies `delta` to `model` using `optimizer`."""
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.assert_same_structure(delta, model_weights.trainable)
  flat_deltas = tf.nest.flatten(delta)
  flat_variables = tf.nest.flatten(model_weights.trainable)
  # The delta is applied as a pseudo-gradient, hence the sign flip.
  grads_and_vars = [
      (-1.0 * d, v) for d, v in zip(flat_deltas, flat_variables)
  ]
  # Note: this may create variables inside `optimizer`, for example if this is
  # the first usage of Adam or momentum optimizers.
  optimizer.apply_gradients(grads_and_vars)


def _eagerly_create_optimizer_variables(
    *, model: model_lib.Model,
    optimizer: tf.keras.optimizers.Optimizer) -> List[tf.Variable]:
  """Forces eager construction of the optimizer variables.

  This code is needed both in `server_init` and `server_update` (to introduce
  variables so we can read their initial values for the initial state).

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A list of optimizer variables.
  """
  # Build TensorSpecs matching the model's trainable weights so that
  # `_apply_delta` can be traced with a delta of the right structure.
  delta_tensor_spec = tf.nest.map_structure(
      lambda v: tf.TensorSpec.from_tensor(v.read_value()),
      model_utils.ModelWeights.from_model(model).trainable)
  # Trace the function, which forces eager variable creation.
  # Only tracing happens here — no update is actually applied; the side effect
  # is that `optimizer` creates whatever internal variables it needs.
  tf.function(_apply_delta).get_concrete_function(
      optimizer=optimizer, model=model, delta=delta_tensor_spec)
  return optimizer.variables()


# ==============================================================================
# Federated Computations
#
# These constructors setup the system level orchestration logic.
# ==============================================================================


233
def _build_initialize_computation(
    *,
    model_fn: _ModelConstructor,
    server_optimizer_fn: _OptimizerConstructor,
    broadcast_process: measured_process.MeasuredProcess,
    aggregation_process: measured_process.MeasuredProcess,
) -> computation_base.Computation:
  """Builds the `initialize` computation for a model delta averaging process.

  Args:
    model_fn: a no-argument callable that constructs and returns a
      `tff.learning.Model`. *Must* construct and return a new model when called.
      Returning captured models from other scopes will raise errors.
    server_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer`. *Must* construct and return a new
      optimizer when called. Returning captured optimizers from other scopes
      will raise errors.
    broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the global
      model to the clients.
    aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
      model deltas.

  Returns:
    A `tff.Computation` that initializes the process. The computation takes no
    arguments and returns a `tuple` of global model weights and server state
    with `tff.SERVER` placement.
  """

  @computations.tf_computation
  def server_init() -> Tuple[model_utils.ModelWeights, List[tf.Variable]]:
    """Returns initial `tff.learning.framework.ServerState`.

    Returns:
      A `tuple` of `tff.learning.framework.ModelWeights` and a `list` of
      `tf.Variable`s for the global optimizer state.
    """
    model = model_fn()
    optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    optimizer_vars = _eagerly_create_optimizer_variables(
        model=model, optimizer=optimizer)
    return model_utils.ModelWeights.from_model(model), optimizer_vars,

  @computations.federated_computation()
  def initialize_computation():
    """Orchestration logic for server model initialization."""
    # Evaluate `server_init` at the server placement, then zip the pieces
    # together with the (possibly empty) broadcast/aggregation process states
    # into a single server-placed `ServerState`.
    initial_global_model, initial_global_optimizer_state = intrinsics.federated_eval(
        server_init, placements.SERVER)
    return intrinsics.federated_zip(
        ServerState(
            model=initial_global_model,
            optimizer_state=initial_global_optimizer_state,
            delta_aggregate_state=aggregation_process.initialize(),
            model_broadcast_state=broadcast_process.initialize()))

  return initialize_computation


def _build_one_round_computation(
    *,
    model_fn: _ModelConstructor,
    server_optimizer_fn: _OptimizerConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    broadcast_process: measured_process.MeasuredProcess,
    aggregation_process: measured_process.MeasuredProcess,
) -> computation_base.Computation:
  """Builds the `next` computation for a model delta averaging process.

  Args:
    model_fn: a no-argument callable that constructs and returns a
      `tff.learning.Model`. *Must* construct and return a new model when called.
      Returning captured models from other scopes will raise errors.
    server_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer`. *Must* construct and return a new
      optimizer when called. Returning captured optimizers from other scopes
      will raise errors.
    model_to_client_delta_fn: a callable that takes a single no-arg callable
      that returns `tff.learning.Model` as an argument and returns a
      `ClientDeltaFn` which performs the local training loop and model delta
      computation.
    broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the global
      model to the clients.
    aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
      model deltas.

  Returns:
    A `tff.Computation` that initializes the process. The computation takes
    a tuple of `(ServerState@SERVER, tf.data.Dataset@CLIENTS)` argument, and
    returns a tuple of `(ServerState@SERVER, metrics@SERVER)`.
  """
  # TODO(b/124477628): would be nice not to have the construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  # TODO(b/144382142): Keras name uniquification is probably the main reason we
  # still need this.
  with tf.Graph().as_default():
    # The throwaway model/optimizer exist only to derive TFF type signatures;
    # they are built inside a temporary graph to isolate their variables/ops.
    dummy_model_for_metadata = model_fn()
    model_weights_type = model_utils.weights_type_from_model(
        dummy_model_for_metadata)

    dummy_optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(
        model=dummy_model_for_metadata, optimizer=dummy_optimizer)
    optimizer_variable_type = type_conversions.type_from_tensors(
        dummy_optimizer.variables())

  @computations.tf_computation(model_weights_type, model_weights_type.trainable,
                               optimizer_variable_type)
  @tf.function
  def server_update(global_model, mean_model_delta, optimizer_state):
    """Updates the global model with the mean model update from clients."""
    with tf.init_scope():
      # Variable creation must happen outside the traced function body.
      model = model_fn()
      optimizer = server_optimizer_fn()
      # We must force variable creation for momentum and adaptive optimizers.
      _eagerly_create_optimizer_variables(model=model, optimizer=optimizer)
    model_variables = model_utils.ModelWeights.from_model(model)
    optimizer_variables = optimizer.variables()
    # Set the variables to the current global model, the optimizer will
    # update these variables.
    tf.nest.map_structure(lambda a, b: a.assign(b),
                          (model_variables, optimizer_variables),
                          (global_model, optimizer_state))
    # We might have a NaN value e.g. if all of the clients processed had no
    # data, so the denominator in the federated_mean is zero. If we see any
    # NaNs, zero out the whole update.
    # TODO(b/124538167): We should increment a server counter to
    # track the fact a non-finite weights_delta was encountered.
    finite_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
        mean_model_delta)
    # Update the global model variables with the delta as a pseudo-gradient.
    _apply_delta(optimizer=optimizer, model=model, delta=finite_weights_delta)
    return model_variables, optimizer_variables

  dataset_type = computation_types.SequenceType(
      dummy_model_for_metadata.input_spec)

  @computations.tf_computation(dataset_type, model_weights_type)
  @tf.function
  def _compute_local_training_and_client_delta(dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    with tf.init_scope():
      # Construct the client delta function (and any variables it creates)
      # outside the traced function body.
      client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(dataset, initial_model_weights)
    return client_output

  # `.member` strips the federated placement, yielding the unplaced state type.
  broadcast_state = broadcast_process.initialize.type_signature.result.member
  aggregation_state = aggregation_process.initialize.type_signature.result.member

  server_state_type = ServerState(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      delta_aggregate_state=aggregation_state,
      model_broadcast_state=broadcast_state)

  @computations.federated_computation(
      computation_types.FederatedType(server_state_type, placements.SERVER),
      computation_types.FederatedType(dataset_type, placements.CLIENTS))
  def one_round_computation(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`, both having
      `tff.SERVER` placement.
    """
    # 1. Broadcast the current global model to the clients.
    broadcast_output = broadcast_process.next(
        server_state.model_broadcast_state, server_state.model)
    # 2. Each client trains locally and produces a `ClientOutput`.
    client_outputs = intrinsics.federated_map(
        _compute_local_training_and_client_delta,
        (federated_dataset, broadcast_output.result))
    # 3. Aggregate the (weighted) client model deltas back to the server.
    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta,
        client_outputs.weights_delta_weight)
    # 4. Apply the aggregated delta to the global model on the server.
    new_global_model, new_optimizer_state = intrinsics.federated_map(
        server_update, (server_state.model, aggregation_output.result,
                        server_state.optimizer_state))
    new_server_state = intrinsics.federated_zip(
        ServerState(new_global_model, new_optimizer_state,
                    aggregation_output.state, broadcast_output.state))
    # 5. Collect metrics from broadcast, aggregation, and client training.
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(
            broadcast=broadcast_output.measurements,
            aggregation=aggregation_output.measurements,
            train=aggregated_outputs))
    return new_server_state, measurements

  return one_round_computation
438
439


440
441
def _is_valid_stateful_process(
    process: measured_process.MeasuredProcess) -> bool:
  """Validates whether a `MeasuredProcess` is valid for model delta processes.

  Valid processes must have `state` and `measurements` placed on the server.
  This method is intended to be used with additional validation on the non-state
  parameters, inputs and result.

  Args:
    process: A measured process to validate.

  Returns:
    `True` iff `process` is a valid stateful process, `False` otherwise.
  """
  init_signature = process.initialize.type_signature
  next_signature = process.next.type_signature
  # Initial state, input state, result state and measurements must all be
  # placed at the server.
  server_placed = (
      init_signature.result.placement,
      next_signature.parameter[0].placement,
      next_signature.result.state.placement,
      next_signature.result.measurements.placement,
  )
  return all(p is placements.SERVER for p in server_placed)
460
461


462
463
def _is_valid_broadcast_process(
    process: measured_process.MeasuredProcess) -> bool:
  """Validates a `MeasuredProcess` adheres to the broadcast signature.

  A valid broadcast process is one whose argument is placed at `SERVER` and
  whose output is placed at `CLIENTS`.

  Args:
    process: A measured process to validate.

  Returns:
    `True` iff the process is a valid broadcast process, otherwise `False`.
  """
  # Guard the attribute access: previously `process.next` was read before the
  # `isinstance` check, so a non-`MeasuredProcess` argument could raise an
  # `AttributeError` instead of returning `False`.
  if not isinstance(process, measured_process.MeasuredProcess):
    return False
  next_type = process.next.type_signature
  return (_is_valid_stateful_process(process) and
          next_type.parameter[1].placement is placements.SERVER and
          next_type.result.result.placement is placements.CLIENTS)
480
481
482


def _is_valid_aggregation_process(
    process: measured_process.MeasuredProcess) -> bool:
  """Validates a `MeasuredProcess` adheres to the aggregation signature.

  A valid aggregation process is one whose argument is placed at `CLIENTS` and
  whose output is placed at `SERVER`.

  Args:
    process: A measured process to validate.

  Returns:
    `True` iff the process is a valid aggregation process, otherwise `False`.
  """
  # Guard the attribute access: previously `process.next` was read before the
  # `isinstance` check, so a non-`MeasuredProcess` argument could raise an
  # `AttributeError` instead of returning `False`.
  if not isinstance(process, measured_process.MeasuredProcess):
    return False
  next_type = process.next.type_signature
  return (_is_valid_stateful_process(process) and
          next_type.parameter[1].placement is placements.CLIENTS and
          next_type.result.result.placement is placements.SERVER)
500
501
502
503


# ============================================================================

504
NONE_SERVER_TYPE = computation_types.FederatedType((), placements.SERVER)
505
506


507
@computations.federated_computation()
def _empty_server_initialization():
  """Returns an empty tuple placed at `tff.SERVER` as trivial process state."""
  return intrinsics.federated_value((), placements.SERVER)
510
511


512
def build_stateless_mean(
    *, model_delta_type: Union[computation_types.StructType,
                               computation_types.TensorType]
) -> measured_process.MeasuredProcess:
  """Builds a `MeasuredProcess` that wraps `tff.federated_mean`.

  The returned process is stateless: its state and measurements are empty
  tuples placed at `tff.SERVER`.

  Args:
    model_delta_type: the `tff.Type` of the client-placed values to average.

  Returns:
    A `tff.templates.MeasuredProcess` computing a weighted mean of client
    values.
  """

  @computations.federated_computation(
      NONE_SERVER_TYPE,
      computation_types.FederatedType(model_delta_type, placements.CLIENTS),
      computation_types.FederatedType(tf.float32, placements.CLIENTS))
  def stateless_mean(state, value, weight):
    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    # State passes through untouched; the only work is the weighted mean.
    return measured_process.MeasuredProcessOutput(
        state=state,
        result=intrinsics.federated_mean(value, weight=weight),
        measurements=empty_metrics)

  return measured_process.MeasuredProcess(
      initialize_fn=_empty_server_initialization, next_fn=stateless_mean)


def build_stateless_broadcaster(
    *, model_weights_type: Union[computation_types.StructType,
                                 computation_types.TensorType]
) -> measured_process.MeasuredProcess:
  """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`.

  The returned process is stateless: its state and measurements are empty
  tuples placed at `tff.SERVER`.

  Args:
    model_weights_type: the `tff.Type` of the server-placed value to broadcast.

  Returns:
    A `tff.templates.MeasuredProcess` broadcasting the server value to clients.
  """

  @computations.federated_computation(
      NONE_SERVER_TYPE,
      computation_types.FederatedType(model_weights_type, placements.SERVER),
  )
  def stateless_broadcast(state, value):
    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    # State passes through untouched; the only work is the broadcast.
    return measured_process.MeasuredProcessOutput(
        state=state,
        result=intrinsics.federated_broadcast(value),
        measurements=empty_metrics)

  return measured_process.MeasuredProcess(
      initialize_fn=_empty_server_initialization, next_fn=stateless_broadcast)
552
553


554
555
# TODO(b/170208719): remove `aggregation_process` after migration to
# `model_update_aggregation_factory`.
556
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationProcessFactory] = None,
) -> iterative_process.IterativeProcess:
  """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENT)`.
    aggregation_process: A `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS-> output_values@SERVER)`. Must be
      `None` if `model_update_aggregation_factory` is not `None.`
    model_update_aggregation_factory: An optional
      `tff.aggregators.AggregationProcessFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation. Must
      be `None` if `aggregation_process` is not `None.`

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
    DisjointArgumentError: if both `aggregation_process` and
      `model_update_aggregation_factory` are not `None`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  model_weights_type = model_utils.weights_type_from_model(model_fn)

  # Default to a stateless broadcast, then validate the (given or default)
  # process against the expected broadcast signature.
  if broadcast_process is None:
    broadcast_process = build_stateless_broadcaster(
        model_weights_type=model_weights_type)
  if not _is_valid_broadcast_process(broadcast_process):
    raise ProcessTypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))

  # `aggregation_process` and `model_update_aggregation_factory` are mutually
  # exclusive ways to specify the aggregation.
  if (model_update_aggregation_factory is not None and
      aggregation_process is not None):
    raise DisjointArgumentError(
        'Must specify only one of `model_update_aggregation_factory` and '
        '`AggregationProcess`.')

  if model_update_aggregation_factory is None and aggregation_process is None:
    model_update_aggregation_factory = mean_factory.MeanFactory()

  if model_update_aggregation_factory is not None:
    aggregation_process = model_update_aggregation_factory.create(
        model_weights_type.trainable)

  # NOTE(review): with the `MeanFactory` default above, `aggregation_process`
  # appears to always be set by this point, so this fallback looks unreachable;
  # kept for safety.
  if aggregation_process is None:
    aggregation_process = build_stateless_mean(
        model_delta_type=model_weights_type.trainable)
  if not _is_valid_aggregation_process(aggregation_process):
    raise ProcessTypeError(
        'aggregation_process type signature does not conform to expected '
        'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
        ' Got: {t}'.format(t=aggregation_process.next.type_signature))

  initialize_computation = _build_initialize_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  run_one_round_computation = _build_one_round_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      model_to_client_delta_fn=model_to_client_delta_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  return iterative_process.IterativeProcess(
      initialize_fn=initialize_computation, next_fn=run_one_round_computation)