Commit c60c0e42 authored by Zachary Garrett, committed by tensorflow-copybara

Rework run_federated.py to allow running in a multimachine setting.

-  Migrate the train dataset preprocessing function into the iterative process,
   rather than applying it while building the client data. This is necessary
   to enable serialization of datasets for sending to remote workers.

PiperOrigin-RevId: 322691112
parent 2410acda
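Background on the motivating constraint: in a multimachine deployment the client datasets live on remote TFF workers, so any preprocessing done eagerly in the driver script would require fully materialized, preprocessed datasets to be serialized and shipped over the wire. Wrapping the preprocessing in a `tff.tf_computation` instead serializes only the transformation graph, which then executes on the worker that already holds the raw data. A minimal sketch of the idea, assuming a hypothetical raw element type of scalar strings (the real StackOverflow data is a dict of features) and an arbitrary batch size:

import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical raw element type; illustrative only.
RAW_DATASET_TYPE = tff.SequenceType(tf.string)

@tff.tf_computation(RAW_DATASET_TYPE)
def preprocess(raw_dataset):
  # Only this serialized graph crosses the wire; it runs on the remote
  # worker that owns the client dataset, not in the driver process.
  return raw_dataset.map(tf.strings.lower).batch(16)

Composing such a computation into the iterative process, as the diff below does via dataset_preprocess_comp, keeps dataset construction lazy on the workers.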
@@ -8,6 +8,7 @@ py_binary(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
+        "//tensorflow_federated",
         "//tensorflow_federated/python/research/optimization/shared:fed_avg_schedule",
         "//tensorflow_federated/python/research/optimization/shared:iterative_process_builder",
         "//tensorflow_federated/python/research/optimization/shared:keras_metrics",
@@ -20,6 +20,7 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
+import tensorflow_federated as tff
 from tensorflow_federated.python.research.optimization.shared import fed_avg_schedule
 from tensorflow_federated.python.research.optimization.shared import iterative_process_builder
@@ -96,12 +97,30 @@ def main(argv):
       keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
   ]
 
-  train_set, validation_set, test_set = stackoverflow_dataset.construct_word_level_datasets(
-      FLAGS.vocab_size, FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
-      FLAGS.sequence_length, FLAGS.max_elements_per_user,
-      FLAGS.num_validation_examples)
-  input_spec = validation_set.element_spec
+  dataset_vocab = stackoverflow_dataset.create_vocab(FLAGS.vocab_size)
+
+  train_clientdata, _, test_clientdata = (
+      tff.simulation.datasets.stackoverflow.load_data())
+
+  # Split the test data into test and validation sets.
+  # TODO(b/161914546): consider moving evaluation to use
+  # `tff.learning.build_federated_evaluation` to get metrics over client
+  # distributions, as well as the example weight means from this centralized
+  # evaluation.
+  base_test_dataset = test_clientdata.create_tf_dataset_from_all_clients()
+  preprocess_val_and_test = stackoverflow_dataset.create_test_dataset_preprocess_fn(
+      dataset_vocab, FLAGS.sequence_length)
+  test_set = preprocess_val_and_test(
+      base_test_dataset.skip(FLAGS.num_validation_examples))
+  validation_set = preprocess_val_and_test(
+      base_test_dataset.take(FLAGS.num_validation_examples))
+
+  train_dataset_preprocess_comp = stackoverflow_dataset.create_train_dataset_preprocess_fn(
+      vocab=stackoverflow_dataset.create_vocab(FLAGS.vocab_size),
+      client_batch_size=FLAGS.client_batch_size,
+      client_epochs_per_round=FLAGS.client_epochs_per_round,
+      max_seq_len=FLAGS.sequence_length,
+      max_training_elements_per_user=FLAGS.max_elements_per_user)
 
   def client_weight_fn(local_outputs):
     # Num_tokens is a tensor with type int64[1], to use as a weight need
@@ -109,14 +128,15 @@ def main(argv):
     return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
 
   training_process = iterative_process_builder.from_flags(
-      input_spec=input_spec,
+      input_spec=None,  # type pulled from train_dataset_preprocess_comp.
       model_builder=model_builder,
       loss_builder=loss_builder,
       metrics_builder=metrics_builder,
-      client_weight_fn=client_weight_fn)
+      client_weight_fn=client_weight_fn,
+      dataset_preprocess_comp=train_dataset_preprocess_comp)
 
   client_datasets_fn = training_utils.build_client_datasets_fn(
-      train_dataset=train_set,
+      train_dataset=train_clientdata,
       train_clients_per_round=FLAGS.clients_per_round,
       random_seed=FLAGS.client_datasets_random_seed)
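For orientation, a hedged sketch of the shape of the helper called above: `create_train_dataset_preprocess_fn` returns a plain dataset-to-dataset function, which is what lets `iterative_process_builder.from_flags` absorb it into the serialized computation. The body below is illustrative only, not the actual implementation in `stackoverflow_dataset`; the parameter names come from the call site in the diff, and the tokens-to-ids details are assumptions:

import tensorflow as tf

def create_train_dataset_preprocess_fn(vocab, client_batch_size,
                                       client_epochs_per_round, max_seq_len,
                                       max_training_elements_per_user):
  # Map in-vocabulary words to ids, everything else to an OOV bucket.
  table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(
          keys=vocab, values=tf.range(len(vocab), dtype=tf.int64)),
      num_oov_buckets=1)

  def to_ids(example):
    # Assumes each raw element is a dict with a space-delimited 'tokens'
    # string, as in the StackOverflow dataset.
    words = tf.strings.split(example['tokens'])
    return table.lookup(words)[:max_seq_len]

  def preprocess(dataset):
    # Cap per-user data, convert to ids, shuffle, repeat for local epochs,
    # and pad every sequence to a fixed length for batching.
    return (dataset.take(max_training_elements_per_user)
            .map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .shuffle(buffer_size=max_training_elements_per_user)
            .repeat(client_epochs_per_round)
            .padded_batch(client_batch_size, padded_shapes=[max_seq_len]))

  return preprocess

Because the returned function closes only over TensorFlow ops, it can be traced into a `tff.tf_computation` and executed wherever the client data lives, which is the property this commit relies on.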