Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
Commits
3814d455
Commit
3814d455
authored
Dec 16, 2020
by
Galen Andrew
Committed by
tensorflow-copybara
Dec 16, 2020
Browse files
Allows unweighted aggregation for federated_averaging and federated_sgd.
PiperOrigin-RevId: 347907220
parent
bfebce02
Changes
1
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/learning/framework/optimizer_utils.py
View file @
3814d455
...
...
@@ -39,6 +39,8 @@ from tensorflow_federated.python.tensorflow_libs import tensor_utils
# Type aliases.
_ModelConstructor
=
Callable
[[],
model_lib
.
Model
]
_OptimizerConstructor
=
Callable
[[],
tf
.
keras
.
optimizers
.
Optimizer
]
AggregationFactory
=
Union
[
factory
.
WeightedAggregationFactory
,
factory
.
UnweightedAggregationFactory
]
class
ProcessTypeError
(
Exception
):
...
...
@@ -292,10 +294,9 @@ def _build_one_round_computation(
*
,
model_fn
:
_ModelConstructor
,
server_optimizer_fn
:
_OptimizerConstructor
,
model_to_client_delta_fn
:
Callable
[[
Callable
[[],
model_lib
.
Model
]],
ClientDeltaFn
],
compute_client_delta
:
computation_base
.
Computation
,
broadcast_process
:
measured_process
.
MeasuredProcess
,
aggregation_process
:
measured_process
.
MeasuredProcess
,
aggregation_process
:
Optional
[
measured_process
.
MeasuredProcess
]
,
)
->
computation_base
.
Computation
:
"""Builds the `next` computation for a model delta averaging process.
...
...
@@ -307,14 +308,14 @@ def _build_one_round_computation(
`tf.keras.optimizers.Optimizer`. *Must* construct and return a new
optimizer when called. Returning captured optimizers from other scopes
will raise errors.
model_to_client_delta_fn: a callable that takes a single no-arg callable
that returns `tff.learning.Model` as an argument and returns a
`ClientDeltaFn` which performs the local training loop and model delta
computation.
compute_client_delta: A `tff.tf_computation` that takes a `tf.data.Dataset`
that provides training examples and `model_utils.ModelWeights` for initial
model weights and returns a ClientOutput.
broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the global
model to the clients.
aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
model deltas.
model deltas. Must be None if model_update_aggregation_factory is
non-None.
Returns:
A `tff.Computation` that initializes the process. The computation takes
...
...
@@ -369,24 +370,6 @@ def _build_one_round_computation(
dataset_type
=
computation_types
.
SequenceType
(
dummy_model_for_metadata
.
input_spec
)
@
computations
.
tf_computation
(
dataset_type
,
model_weights_type
)
@
tf
.
function
def
_compute_local_training_and_client_delta
(
dataset
,
initial_model_weights
):
"""Performs client local model optimization.
Args:
dataset: a `tf.data.Dataset` that provides training examples.
initial_model_weights: a `model_utils.ModelWeights` containing the
starting weights.
Returns:
A `ClientOutput` structure.
"""
with
tf
.
init_scope
():
client_delta_fn
=
model_to_client_delta_fn
(
model_fn
)
client_output
=
client_delta_fn
(
dataset
,
initial_model_weights
)
return
client_output
broadcast_state
=
broadcast_process
.
initialize
.
type_signature
.
result
.
member
aggregation_state
=
aggregation_process
.
initialize
.
type_signature
.
result
.
member
...
...
@@ -396,6 +379,9 @@ def _build_one_round_computation(
delta_aggregate_state
=
aggregation_state
,
model_broadcast_state
=
broadcast_state
)
weight_type
=
compute_client_delta
.
type_signature
.
result
.
weights_delta_weight
unweighted_aggregation
=
weight_type
.
is_struct
()
and
not
weight_type
@
computations
.
federated_computation
(
computation_types
.
FederatedType
(
server_state_type
,
placements
.
SERVER
),
computation_types
.
FederatedType
(
dataset_type
,
placements
.
CLIENTS
))
...
...
@@ -414,11 +400,14 @@ def _build_one_round_computation(
broadcast_output
=
broadcast_process
.
next
(
server_state
.
model_broadcast_state
,
server_state
.
model
)
client_outputs
=
intrinsics
.
federated_map
(
_compute_local_training_and_client_delta
,
(
federated_dataset
,
broadcast_output
.
result
))
aggregation_output
=
aggregation_process
.
next
(
server_state
.
delta_aggregate_state
,
client_outputs
.
weights_delta
,
client_outputs
.
weights_delta_weight
)
compute_client_delta
,
(
federated_dataset
,
broadcast_output
.
result
))
if
unweighted_aggregation
:
aggregation_output
=
aggregation_process
.
next
(
server_state
.
delta_aggregate_state
,
client_outputs
.
weights_delta
)
else
:
aggregation_output
=
aggregation_process
.
next
(
server_state
.
delta_aggregate_state
,
client_outputs
.
weights_delta
,
client_outputs
.
weights_delta_weight
)
new_global_model
,
new_optimizer_state
=
intrinsics
.
federated_map
(
server_update
,
(
server_state
.
model
,
aggregation_output
.
result
,
server_state
.
optimizer_state
))
...
...
@@ -561,8 +550,7 @@ def build_model_delta_optimizer_process(
*
,
broadcast_process
:
Optional
[
measured_process
.
MeasuredProcess
]
=
None
,
aggregation_process
:
Optional
[
measured_process
.
MeasuredProcess
]
=
None
,
model_update_aggregation_factory
:
Optional
[
factory
.
WeightedAggregationFactory
]
=
None
,
model_update_aggregation_factory
:
Optional
[
AggregationFactory
]
=
None
,
)
->
iterative_process
.
IterativeProcess
:
"""Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.
...
...
@@ -624,16 +612,56 @@ def build_model_delta_optimizer_process(
'Must specify only one of `model_update_aggregation_factory` and '
'`AggregationProcess`.'
)
if
model_update_aggregation_factory
is
None
and
aggregation_process
is
None
:
model_update_aggregation_factory
=
mean_factory
.
MeanFactory
()
with
tf
.
Graph
().
as_default
():
model_for_metadata
=
model_fn
()
dataset_type
=
computation_types
.
SequenceType
(
model_for_metadata
.
input_spec
)
if
model_update_aggregation_factory
is
not
None
:
aggregation_process
=
model_update_aggregation_factory
.
create_weighted
(
model_weights_type
.
trainable
,
computation_types
.
TensorType
(
tf
.
float32
))
@
computations
.
tf_computation
(
dataset_type
,
model_weights_type
)
@
tf
.
function
def
_compute_client_delta
(
dataset
,
initial_model_weights
):
"""Performs client local model optimization.
Args:
dataset: a `tf.data.Dataset` that provides training examples.
initial_model_weights: a `model_utils.ModelWeights` containing the
starting weights.
Returns:
A `ClientOutput` structure.
"""
with
tf
.
init_scope
():
client_delta_fn
=
model_to_client_delta_fn
(
model_fn
)
client_output
=
client_delta_fn
(
dataset
,
initial_model_weights
)
return
client_output
weight_type
=
_compute_client_delta
.
type_signature
.
result
.
weights_delta_weight
if
aggregation_process
is
None
:
aggregation_process
=
build_stateless_mean
(
model_delta_type
=
model_weights_type
.
trainable
)
if
model_update_aggregation_factory
is
None
:
model_update_aggregation_factory
=
mean_factory
.
MeanFactory
()
# If weights_delta_weight are empty, use unweighted aggregation.
if
weight_type
.
is_struct
()
and
not
weight_type
:
if
not
isinstance
(
model_update_aggregation_factory
,
factory
.
UnweightedAggregationFactory
):
raise
TypeError
(
f
'Unweighted aggregation is requested, but '
f
'model_update_aggregation_factory is of type '
f
'
{
type
(
model_update_aggregation_factory
)
}
which is not an '
f
'instance of `tff.aggregators.UnweightedAggregationFactory`.'
)
aggregation_process
=
model_update_aggregation_factory
.
create_unweighted
(
model_weights_type
.
trainable
)
else
:
if
not
isinstance
(
model_update_aggregation_factory
,
factory
.
WeightedAggregationFactory
):
raise
TypeError
(
f
'Weighted aggregation is requested, but '
f
'model_update_aggregation_factory is of type '
f
'
{
type
(
model_update_aggregation_factory
)
}
which is not an '
f
'instance of `tff.aggregators.WeightedAggregationFactory`.'
)
aggregation_process
=
model_update_aggregation_factory
.
create_weighted
(
model_weights_type
.
trainable
,
weight_type
)
if
not
_is_valid_aggregation_process
(
aggregation_process
):
raise
ProcessTypeError
(
'aggregation_process type signature does not conform to expected '
...
...
@@ -649,7 +677,7 @@ def build_model_delta_optimizer_process(
run_one_round_computation
=
_build_one_round_computation
(
model_fn
=
model_fn
,
server_optimizer_fn
=
server_optimizer_fn
,
model_to
_client_delta
_fn
=
model_to
_client_delta
_fn
,
compute
_client_delta
=
_compute
_client_delta
,
broadcast_process
=
broadcast_process
,
aggregation_process
=
aggregation_process
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment