Skip to content
Snippets Groups Projects
Commit 96737fde authored by Zachary Garrett's avatar Zachary Garrett Committed by tensorflow-copybara
Browse files

Extend iterative process builder and the LR scheduling iterative process to...

Extend iterative process builder and the LR scheduling iterative process to accept an optional dataset preprocessing computation.

This allows for pushing the dataset preprocessing methods down to the client
executors, which is required for multimachine simulations since stateful
datasets (e.g. datasets which use shuffling) cannot be serialized.

PiperOrigin-RevId: 321835968
parent 641558c9
No related branches found
No related tags found
No related merge requests found
......@@ -301,7 +301,8 @@ def build_fed_avg_process(
client_lr: Union[float, LRScheduleFn] = 0.1,
server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
server_lr: Union[float, LRScheduleFn] = 1.0,
client_weight_fn: Optional[ClientWeightFn] = None
client_weight_fn: Optional[ClientWeightFn] = None,
dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> FederatedAveragingProcessAdapter:
"""Builds the TFF computations for optimization using federated averaging.
......@@ -319,6 +320,10 @@ def build_fed_avg_process(
`model.report_local_outputs` and returns a tensor that provides the weight
in the federated average of model deltas. If not provided, the default is
the total number of examples processed on device.
dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
pipeline on the clients. The computation must take a sequence of values
and return a sequence of values, or in TFF type shorthand `(U* -> V*)`. If
`None`, no dataset preprocessing is applied.
Returns:
A `FederatedAveragingProcessAdapter`.
......@@ -342,9 +347,21 @@ def build_fed_avg_process(
model_weights_type = server_state_type.model
round_num_type = server_state_type.round_num
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
@tff.tf_computation(tf_dataset_type, model_weights_type, round_num_type)
if dataset_preprocess_comp is not None:
tf_dataset_type = dataset_preprocess_comp.type_signature.parameter
model_input_type = tff.SequenceType(dummy_model.input_spec)
preprocessed_dataset_type = dataset_preprocess_comp.type_signature.result
if not model_input_type.is_assignable_from(preprocessed_dataset_type):
raise TypeError('Supplied `dataset_preprocess_comp` does not yield '
'batches that are compatible with the model constructed '
'by `model_fn`. Model expects type {m}, but dataset '
'yields type {d}.'.format(
m=model_input_type, d=preprocessed_dataset_type))
else:
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
model_input_type = tff.SequenceType(dummy_model.input_spec)
@tff.tf_computation(model_input_type, model_weights_type, round_num_type)
def client_update_fn(tf_dataset, initial_model_weights, round_num):
client_lr = client_lr_schedule(round_num)
client_optimizer = client_optimizer_fn(client_lr)
......@@ -378,6 +395,9 @@ def build_fed_avg_process(
"""
client_model = tff.federated_broadcast(server_state.model)
client_round_num = tff.federated_broadcast(server_state.round_num)
if dataset_preprocess_comp is not None:
federated_dataset = tff.federated_map(dataset_preprocess_comp,
federated_dataset)
client_outputs = tff.federated_map(
client_update_fn,
(federated_dataset, client_model, client_round_num))
......
......@@ -50,12 +50,12 @@ def _uncompiled_model_builder():
class ModelDeltaProcessTest(tf.test.TestCase):
def _run_rounds(self, iterative_process, federated_data, num_rounds):
def _run_rounds(self, iterproc_adapter, federated_data, num_rounds):
train_outputs = []
initial_state = iterative_process.initialize()
initial_state = iterproc_adapter.initialize()
state = initial_state
for round_num in range(num_rounds):
iteration_result = iterative_process.next(state, federated_data)
iteration_result = iterproc_adapter.next(state, federated_data)
train_outputs.append(iteration_result.metrics)
logging.info('Round %d: %s', round_num, iteration_result.metrics)
state = iteration_result.state
......@@ -64,12 +64,12 @@ class ModelDeltaProcessTest(tf.test.TestCase):
def test_fed_avg_without_schedule_decreases_loss(self):
federated_data = [[_batch_fn()]]
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
server_optimizer_fn=tf.keras.optimizers.SGD)
_, train_outputs, _ = self._run_rounds(iterative_process, federated_data, 5)
_, train_outputs, _ = self._run_rounds(iterproc_adapter, federated_data, 5)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
def test_fed_avg_with_custom_client_weight_fn(self):
......@@ -78,13 +78,13 @@ class ModelDeltaProcessTest(tf.test.TestCase):
def client_weight_fn(local_outputs):
return 1.0/(1.0 + local_outputs['loss'][-1])
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
server_optimizer_fn=tf.keras.optimizers.SGD,
client_weight_fn=client_weight_fn)
_, train_outputs, _ = self._run_rounds(iterative_process, federated_data, 5)
_, train_outputs, _ = self._run_rounds(iterproc_adapter, federated_data, 5)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
def test_client_update_with_finite_delta(self):
......@@ -112,27 +112,27 @@ class ModelDeltaProcessTest(tf.test.TestCase):
def test_server_update_with_nan_data_is_noop(self):
federated_data = [[_batch_fn(has_nan=True)]]
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
server_optimizer_fn=tf.keras.optimizers.SGD)
state, _, initial_state = self._run_rounds(iterative_process,
federated_data, 1)
state, _, initial_state = self._run_rounds(iterproc_adapter, federated_data,
1)
self.assertAllClose(state.model, initial_state.model, 1e-8)
def test_server_update_with_inf_weight_is_noop(self):
federated_data = [[_batch_fn()]]
client_weight_fn = lambda x: np.inf
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
server_optimizer_fn=tf.keras.optimizers.SGD,
client_weight_fn=client_weight_fn)
state, _, initial_state = self._run_rounds(iterative_process,
federated_data, 1)
state, _, initial_state = self._run_rounds(iterproc_adapter, federated_data,
1)
self.assertAllClose(state.model, initial_state.model, 1e-8)
def test_fed_avg_with_client_schedule(self):
......@@ -142,13 +142,13 @@ class ModelDeltaProcessTest(tf.test.TestCase):
def lr_schedule(x):
return 0.1 if x < 1.5 else 0.0
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
client_lr=lr_schedule,
server_optimizer_fn=tf.keras.optimizers.SGD)
_, train_outputs, _ = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs, _ = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[1]['loss'], train_outputs[0]['loss'])
self.assertNear(
train_outputs[2]['loss'], train_outputs[3]['loss'], err=1e-4)
......@@ -160,13 +160,13 @@ class ModelDeltaProcessTest(tf.test.TestCase):
def lr_schedule(x):
return 1.0 if x < 1.5 else 0.0
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
server_optimizer_fn=tf.keras.optimizers.SGD,
server_lr=lr_schedule)
_, train_outputs, _ = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs, _ = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[1]['loss'], train_outputs[0]['loss'])
self.assertNear(
train_outputs[2]['loss'], train_outputs[3]['loss'], err=1e-4)
......@@ -174,14 +174,14 @@ class ModelDeltaProcessTest(tf.test.TestCase):
def test_fed_avg_with_client_and_server_schedules(self):
federated_data = [[_batch_fn()]]
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
client_lr=lambda x: 0.1 / (x + 1)**2,
server_optimizer_fn=tf.keras.optimizers.SGD,
server_lr=lambda x: 1.0 / (x + 1)**2)
_, train_outputs, _ = self._run_rounds(iterative_process, federated_data, 6)
_, train_outputs, _ = self._run_rounds(iterproc_adapter, federated_data, 6)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
train_gap_first_half = train_outputs[0]['loss'] - train_outputs[2]['loss']
train_gap_second_half = train_outputs[3]['loss'] - train_outputs[5]['loss']
......@@ -190,16 +190,86 @@ class ModelDeltaProcessTest(tf.test.TestCase):
def test_conversion_from_tff_result(self):
federated_data = [[_batch_fn()]]
iterative_process = fed_avg_schedule.build_fed_avg_process(
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
server_optimizer_fn=tf.keras.optimizers.SGD)
state, _, _ = self._run_rounds(iterative_process, federated_data, 1)
state, _, _ = self._run_rounds(iterproc_adapter, federated_data, 1)
converted_state = fed_avg_schedule.ServerState.from_tff_result(state)
self.assertIsInstance(converted_state, fed_avg_schedule.ServerState)
self.assertIsInstance(converted_state.model, fed_avg_schedule.ModelWeights)
def test_build_with_preprocess_function(self):
  """Checks that the built process's `next` has the expected TFF type.

  When a `dataset_preprocess_comp` is supplied, the client-side input of
  `next` must be typed after the *raw* (pre-preprocessing) dataset, since
  the preprocessing pipeline is pushed down into the iterative process.
  """
  raw_dataset = tf.data.Dataset.range(5)
  raw_element_type = tff.SequenceType(raw_dataset.element_spec)
  client_datasets_type = tff.FederatedType(raw_element_type, tff.CLIENTS)

  @tff.tf_computation(raw_element_type)
  def preprocess_dataset(ds):

    # Expand each scalar into a (features, label) pair the model accepts.
    def _as_batch(value):
      features = tf.fill(dims=(784,), value=float(value) * 2.0)
      label = tf.expand_dims(tf.cast(value + 1, dtype=tf.int64), axis=0)
      return _Batch(features, label)

    return ds.map(_as_batch).batch(2)

  iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)

  # Build a throwaway model in a scratch graph purely to derive types.
  with tf.Graph().as_default():
    test_model_for_types = _uncompiled_model_builder()

  iterproc = iterproc_adapter._iterative_process

  model_weights_type = tff.framework.type_from_tensors(
      fed_avg_schedule.ModelWeights(
          test_model_for_types.trainable_variables,
          test_model_for_types.non_trainable_variables))
  server_state_type = tff.FederatedType(
      fed_avg_schedule.ServerState(
          model=model_weights_type,
          optimizer_state=(tf.int64,),
          round_num=tf.float32), tff.SERVER)
  metrics_type = (
      test_model_for_types.federated_output_computation.type_signature.result)
  expected_type = tff.FunctionType(
      parameter=(server_state_type, client_datasets_type),
      result=(server_state_type, metrics_type))

  self.assertEqual(
      iterproc.next.type_signature,
      expected_type,
      msg='{s}\n!={t}'.format(
          s=iterproc.next.type_signature, t=expected_type))
def test_execute_with_preprocess_function(self):
  """Runs training end-to-end when a `dataset_preprocess_comp` is supplied.

  The raw client data is a bare `tf.data.Dataset.range(1)`; the preprocess
  computation maps it into model-compatible `_Batch` examples, so training
  should proceed and the loss should decrease over rounds.
  """
  test_dataset = tf.data.Dataset.range(1)

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):

    # Replace each raw element with a constant all-ones example; the raw
    # value itself is irrelevant for this test.
    def to_example(x):
      del x  # Unused.
      return _Batch(
          x=np.ones([784], dtype=np.float32), y=np.ones([1], dtype=np.int64))

    return ds.map(to_example).batch(1)

  iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)

  # Note: the *raw* dataset is passed in; preprocessing happens inside the
  # iterative process on the clients.
  _, train_outputs, _ = self._run_rounds(iterproc_adapter, [test_dataset], 6)
  self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
  # Loss improvement should slow down as training converges.
  train_gap_first_half = train_outputs[0]['loss'] - train_outputs[2]['loss']
  train_gap_second_half = train_outputs[3]['loss'] - train_outputs[5]['loss']
  self.assertLess(train_gap_second_half, train_gap_first_half)
if __name__ == '__main__':
tf.test.main()
......@@ -49,7 +49,9 @@ def from_flags(
model_builder: ModelBuilder,
loss_builder: LossBuilder,
metrics_builder: MetricsBuilder,
client_weight_fn: Optional[ClientWeightFn] = None
client_weight_fn: Optional[ClientWeightFn] = None,
*,
dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> fed_avg_schedule.FederatedAveragingProcessAdapter:
"""Builds a `tff.templates.IterativeProcess` instance from flags.
......@@ -70,6 +72,11 @@ def from_flags(
`tff.learning.Model.report_local_outputs` from the model returned by
`model_builder`, and returns a scalar client weight. If `None`, defaults
to the number of examples processed over all batches.
dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
pipeline on the clients. The computation must take a sequence of values
and return a sequence of values, or in TFF type shorthand `(U* -> V*)`. If
`None`, no dataset preprocessing is applied. If specified, `input_spec` is
optional, as the necessary type signatures will be taken from the computation.
Returns:
A `fed_avg_schedule.FederatedAveragingProcessAdapter`.
......@@ -82,10 +89,19 @@ def from_flags(
client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')
if dataset_preprocess_comp is not None:
if input_spec is not None:
print('Specified both `dataset_preprocess_comp` and `input_spec` when '
'only one is necessary. Ignoring `input_spec` and using type '
'signature of `dataset_preprocess_comp`.')
model_input_spec = dataset_preprocess_comp.type_signature.result.element
else:
model_input_spec = input_spec
def tff_model_fn() -> tff.learning.Model:
return tff.learning.from_keras_model(
keras_model=model_builder(),
input_spec=input_spec,
input_spec=model_input_spec,
loss=loss_builder(),
metrics=metrics_builder())
......@@ -95,4 +111,5 @@ def from_flags(
client_lr=client_lr_schedule,
server_optimizer_fn=server_optimizer_fn,
server_lr=server_lr_schedule,
client_weight_fn=client_weight_fn)
client_weight_fn=client_weight_fn,
dataset_preprocess_comp=dataset_preprocess_comp)
......@@ -63,11 +63,11 @@ def metrics_builder():
class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
def _run_rounds(self, iterative_process, client_datasets, num_rounds):
def _run_rounds(self, iterproc_adapter, client_datasets, num_rounds):
train_outputs = []
state = iterative_process.initialize()
state = iterproc_adapter.initialize()
for round_num in range(num_rounds):
iteration_result = iterative_process.next(state, client_datasets)
iteration_result = iterproc_adapter.next(state, client_datasets)
train_outputs.append(iteration_result.metrics)
logging.info('Round %d: %s', round_num, iteration_result.metrics)
state = iteration_result.state
......@@ -78,9 +78,9 @@ class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
FLAGS.server_lr_schedule = 'constant'
federated_data = [[_batch_fn()]]
input_spec = _get_input_spec()
iterative_process = iterative_process_builder.from_flags(
iterproc_adapter = iterative_process_builder.from_flags(
input_spec, model_builder, loss_builder, metrics_builder)
_, train_outputs = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
def test_iterative_process_with_custom_client_weight_fn_decreases_loss(self):
......@@ -92,13 +92,13 @@ class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
def client_weight_fn(local_outputs):
return 1.0 / (1.0 + local_outputs['loss'][-1])
iterative_process = iterative_process_builder.from_flags(
iterproc_adapter = iterative_process_builder.from_flags(
input_spec,
model_builder,
loss_builder,
metrics_builder,
client_weight_fn=client_weight_fn)
_, train_outputs = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
@parameterized.named_parameters(('inv_lin_decay', 'inv_lin_decay'),
......@@ -112,9 +112,9 @@ class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
FLAGS.client_lr_schedule = sched_type
federated_data = [[_batch_fn()]]
input_spec = _get_input_spec()
iterative_process = iterative_process_builder.from_flags(
iterproc_adapter = iterative_process_builder.from_flags(
input_spec, model_builder, loss_builder, metrics_builder)
_, train_outputs = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
def test_iterative_process_with_exp_decay_client_schedule(self):
......@@ -126,9 +126,9 @@ class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
federated_data = [[_batch_fn()]]
input_spec = _get_input_spec()
iterative_process = iterative_process_builder.from_flags(
iterproc_adapter = iterative_process_builder.from_flags(
input_spec, model_builder, loss_builder, metrics_builder)
_, train_outputs = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
@parameterized.named_parameters(('inv_lin_decay', 'inv_lin_decay'),
......@@ -141,9 +141,9 @@ class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
FLAGS.server_lr_schedule = sched_type
federated_data = [[_batch_fn()]]
input_spec = _get_input_spec()
iterative_process = iterative_process_builder.from_flags(
iterproc_adapter = iterative_process_builder.from_flags(
input_spec, model_builder, loss_builder, metrics_builder)
_, train_outputs = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
def test_iterative_process_with_exp_decay_server_schedule(self):
......@@ -155,9 +155,9 @@ class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
federated_data = [[_batch_fn()]]
input_spec = _get_input_spec()
iterative_process = iterative_process_builder.from_flags(
iterproc_adapter = iterative_process_builder.from_flags(
input_spec, model_builder, loss_builder, metrics_builder)
_, train_outputs = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
def test_decay_factor_0_does_not_decrease_loss(self):
......@@ -169,9 +169,9 @@ class IterativeProcessBuilderTest(tf.test.TestCase, parameterized.TestCase):
federated_data = [[_batch_fn()]]
input_spec = _get_input_spec()
iterative_process = iterative_process_builder.from_flags(
iterproc_adapter = iterative_process_builder.from_flags(
input_spec, model_builder, loss_builder, metrics_builder)
_, train_outputs = self._run_rounds(iterative_process, federated_data, 4)
_, train_outputs = self._run_rounds(iterproc_adapter, federated_data, 4)
self.assertLess(train_outputs[1]['loss'], train_outputs[0]['loss'])
self.assertNear(
train_outputs[2]['loss'], train_outputs[3]['loss'], err=1e-5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment