Commit 3814d455 authored by Galen Andrew's avatar Galen Andrew Committed by tensorflow-copybara
Browse files

Allows unweighted aggregation for federated_averaging and federated_sgd.

PiperOrigin-RevId: 347907220
parent bfebce02
......@@ -39,6 +39,8 @@ from tensorflow_federated.python.tensorflow_libs import tensor_utils
# Type aliases.
_ModelConstructor = Callable[[], model_lib.Model]
_OptimizerConstructor = Callable[[], tf.keras.optimizers.Optimizer]
AggregationFactory = Union[factory.WeightedAggregationFactory,
factory.UnweightedAggregationFactory]
class ProcessTypeError(Exception):
......@@ -292,10 +294,9 @@ def _build_one_round_computation(
*,
model_fn: _ModelConstructor,
server_optimizer_fn: _OptimizerConstructor,
model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
ClientDeltaFn],
compute_client_delta: computation_base.Computation,
broadcast_process: measured_process.MeasuredProcess,
aggregation_process: measured_process.MeasuredProcess,
aggregation_process: Optional[measured_process.MeasuredProcess],
) -> computation_base.Computation:
"""Builds the `next` computation for a model delta averaging process.
......@@ -307,14 +308,14 @@ def _build_one_round_computation(
`tf.keras.optimizers.Optimizer`. *Must* construct and return a new
optimizer when called. Returning captured optimizers from other scopes
will raise errors.
model_to_client_delta_fn: a callable that takes a single no-arg callable
that returns `tff.learning.Model` as an argument and returns a
`ClientDeltaFn` which performs the local training loop and model delta
computation.
compute_client_delta: A `tff.tf_computation` that takes a `tf.data.Dataset`
that provides training examples and `model_utils.ModelWeights` for initial
model weights and returns a ClientOutput.
broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the global
model to the clients.
aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
model deltas.
model deltas. Must be None if model_update_aggregation_factory is
non-None.
Returns:
A `tff.Computation` that initializes the process. The computation takes
......@@ -369,24 +370,6 @@ def _build_one_round_computation(
dataset_type = computation_types.SequenceType(
dummy_model_for_metadata.input_spec)
@computations.tf_computation(dataset_type, model_weights_type)
@tf.function
def _compute_local_training_and_client_delta(dataset, initial_model_weights):
"""Performs client local model optimization.
Args:
dataset: a `tf.data.Dataset` that provides training examples.
initial_model_weights: a `model_utils.ModelWeights` containing the
starting weights.
Returns:
A `ClientOutput` structure.
"""
with tf.init_scope():
client_delta_fn = model_to_client_delta_fn(model_fn)
client_output = client_delta_fn(dataset, initial_model_weights)
return client_output
broadcast_state = broadcast_process.initialize.type_signature.result.member
aggregation_state = aggregation_process.initialize.type_signature.result.member
......@@ -396,6 +379,9 @@ def _build_one_round_computation(
delta_aggregate_state=aggregation_state,
model_broadcast_state=broadcast_state)
weight_type = compute_client_delta.type_signature.result.weights_delta_weight
unweighted_aggregation = weight_type.is_struct() and not weight_type
@computations.federated_computation(
computation_types.FederatedType(server_state_type, placements.SERVER),
computation_types.FederatedType(dataset_type, placements.CLIENTS))
......@@ -414,11 +400,14 @@ def _build_one_round_computation(
broadcast_output = broadcast_process.next(
server_state.model_broadcast_state, server_state.model)
client_outputs = intrinsics.federated_map(
_compute_local_training_and_client_delta,
(federated_dataset, broadcast_output.result))
aggregation_output = aggregation_process.next(
server_state.delta_aggregate_state, client_outputs.weights_delta,
client_outputs.weights_delta_weight)
compute_client_delta, (federated_dataset, broadcast_output.result))
if unweighted_aggregation:
aggregation_output = aggregation_process.next(
server_state.delta_aggregate_state, client_outputs.weights_delta)
else:
aggregation_output = aggregation_process.next(
server_state.delta_aggregate_state, client_outputs.weights_delta,
client_outputs.weights_delta_weight)
new_global_model, new_optimizer_state = intrinsics.federated_map(
server_update, (server_state.model, aggregation_output.result,
server_state.optimizer_state))
......@@ -561,8 +550,7 @@ def build_model_delta_optimizer_process(
*,
broadcast_process: Optional[measured_process.MeasuredProcess] = None,
aggregation_process: Optional[measured_process.MeasuredProcess] = None,
model_update_aggregation_factory: Optional[
factory.WeightedAggregationFactory] = None,
model_update_aggregation_factory: Optional[AggregationFactory] = None,
) -> iterative_process.IterativeProcess:
"""Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.
......@@ -624,16 +612,56 @@ def build_model_delta_optimizer_process(
'Must specify only one of `model_update_aggregation_factory` and '
'`AggregationProcess`.')
if model_update_aggregation_factory is None and aggregation_process is None:
model_update_aggregation_factory = mean_factory.MeanFactory()
with tf.Graph().as_default():
model_for_metadata = model_fn()
dataset_type = computation_types.SequenceType(model_for_metadata.input_spec)
if model_update_aggregation_factory is not None:
aggregation_process = model_update_aggregation_factory.create_weighted(
model_weights_type.trainable, computation_types.TensorType(tf.float32))
@computations.tf_computation(dataset_type, model_weights_type)
@tf.function
def _compute_client_delta(dataset, initial_model_weights):
"""Performs client local model optimization.
Args:
dataset: a `tf.data.Dataset` that provides training examples.
initial_model_weights: a `model_utils.ModelWeights` containing the
starting weights.
Returns:
A `ClientOutput` structure.
"""
with tf.init_scope():
client_delta_fn = model_to_client_delta_fn(model_fn)
client_output = client_delta_fn(dataset, initial_model_weights)
return client_output
weight_type = _compute_client_delta.type_signature.result.weights_delta_weight
if aggregation_process is None:
aggregation_process = build_stateless_mean(
model_delta_type=model_weights_type.trainable)
if model_update_aggregation_factory is None:
model_update_aggregation_factory = mean_factory.MeanFactory()
# If weights_delta_weight are empty, use unweighted aggregation.
if weight_type.is_struct() and not weight_type:
if not isinstance(model_update_aggregation_factory,
factory.UnweightedAggregationFactory):
raise TypeError(
f'Unweighted aggregation is requested, but '
f'model_update_aggregation_factory is of type '
f'{type(model_update_aggregation_factory)} which is not an '
f'instance of `tff.aggregators.UnweightedAggregationFactory`.')
aggregation_process = model_update_aggregation_factory.create_unweighted(
model_weights_type.trainable)
else:
if not isinstance(model_update_aggregation_factory,
factory.WeightedAggregationFactory):
raise TypeError(
f'Weighted aggregation is requested, but '
f'model_update_aggregation_factory is of type '
f'{type(model_update_aggregation_factory)} which is not an '
f'instance of `tff.aggregators.WeightedAggregationFactory`.')
aggregation_process = model_update_aggregation_factory.create_weighted(
model_weights_type.trainable, weight_type)
if not _is_valid_aggregation_process(aggregation_process):
raise ProcessTypeError(
'aggregation_process type signature does not conform to expected '
......@@ -649,7 +677,7 @@ def build_model_delta_optimizer_process(
run_one_round_computation = _build_one_round_computation(
model_fn=model_fn,
server_optimizer_fn=server_optimizer_fn,
model_to_client_delta_fn=model_to_client_delta_fn,
compute_client_delta=_compute_client_delta,
broadcast_process=broadcast_process,
aggregation_process=aggregation_process)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment