From c60c0e427d2d823681ee11d6f10ee26cf6e141c0 Mon Sep 17 00:00:00 2001
From: Zachary Garrett <zachgarrett@google.com>
Date: Wed, 22 Jul 2020 17:32:36 -0700
Subject: [PATCH] Rework run_federated.py to allow running in a multi-machine
 setting.

- Migrate the train dataset preprocessing function into the iterative
  process, rather than as part of building datasets from the client data.
  This is necessary to enable serialization of datasets for sending to
  remote workers.

PiperOrigin-RevId: 322691112
---
 .../research/optimization/stackoverflow/BUILD |  1 +
 .../stackoverflow/run_federated.py            | 38 ++++++++++++++-----
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
index e381ec128..937bf1c5e 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
@@ -8,6 +8,7 @@ py_binary(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
+        "//tensorflow_federated",
         "//tensorflow_federated/python/research/optimization/shared:fed_avg_schedule",
         "//tensorflow_federated/python/research/optimization/shared:iterative_process_builder",
         "//tensorflow_federated/python/research/optimization/shared:keras_metrics",
diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
index abfef5db1..f5b83c590 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
@@ -20,6 +20,7 @@
 from absl import flags
 from absl import logging
 import tensorflow as tf
+import tensorflow_federated as tff
 
 from tensorflow_federated.python.research.optimization.shared import fed_avg_schedule
 from tensorflow_federated.python.research.optimization.shared import iterative_process_builder
@@ -96,12 +97,30 @@ def main(argv):
       keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
   ]
 
-  train_set, validation_set, test_set = stackoverflow_dataset.construct_word_level_datasets(
-      FLAGS.vocab_size, FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
-      FLAGS.sequence_length, FLAGS.max_elements_per_user,
-      FLAGS.num_validation_examples)
-
-  input_spec = validation_set.element_spec
+  dataset_vocab = stackoverflow_dataset.create_vocab(FLAGS.vocab_size)
+
+  train_clientdata, _, test_clientdata = (
+      tff.simulation.datasets.stackoverflow.load_data())
+
+  # Split the test data into test and validation sets.
+  # TODO(b/161914546): consider moving evaluation to use
+  # `tff.learning.build_federated_evaluation` to get metrics over client
+  # distributions, as well as the example weight means from this centralized
+  # evaluation.
+  base_test_dataset = test_clientdata.create_tf_dataset_from_all_clients()
+  preprocess_val_and_test = stackoverflow_dataset.create_test_dataset_preprocess_fn(
+      dataset_vocab, FLAGS.sequence_length)
+  test_set = preprocess_val_and_test(
+      base_test_dataset.skip(FLAGS.num_validation_examples))
+  validation_set = preprocess_val_and_test(
+      base_test_dataset.take(FLAGS.num_validation_examples))
+
+  train_dataset_preprocess_comp = stackoverflow_dataset.create_train_dataset_preprocess_fn(
+      vocab=stackoverflow_dataset.create_vocab(FLAGS.vocab_size),
+      client_batch_size=FLAGS.client_batch_size,
+      client_epochs_per_round=FLAGS.client_epochs_per_round,
+      max_seq_len=FLAGS.sequence_length,
+      max_training_elements_per_user=FLAGS.max_elements_per_user)
 
   def client_weight_fn(local_outputs):
     # Num_tokens is a tensor with type int64[1], to use as a weight need
@@ -109,14 +128,15 @@ def main(argv):
     return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
 
   training_process = iterative_process_builder.from_flags(
-      input_spec=input_spec,
+      input_spec=None,  # type pulled from train_dataset_preprocess_comp.
       model_builder=model_builder,
       loss_builder=loss_builder,
       metrics_builder=metrics_builder,
-      client_weight_fn=client_weight_fn)
+      client_weight_fn=client_weight_fn,
+      dataset_preprocess_comp=train_dataset_preprocess_comp)
 
   client_datasets_fn = training_utils.build_client_datasets_fn(
-      train_dataset=train_set,
+      train_dataset=train_clientdata,
       train_clients_per_round=FLAGS.clients_per_round,
       random_seed=FLAGS.client_datasets_random_seed)
-- 
GitLab
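
Note for reviewers: the point of passing `train_dataset_preprocess_comp` into
the iterative process is that remote workers receive a lightweight reference
to each raw client dataset and run the preprocessing graph locally, instead of
the driver materializing and serializing fully preprocessed datasets. As a
rough picture of what such a builder returns, below is a minimal sketch,
assuming raw StackOverflow examples carry a space-delimited 'tokens' string.
The function name, the toy lookup-table tokenizer, and the parameter handling
are illustrative assumptions, not the actual stackoverflow_dataset
implementation.

import tensorflow as tf


def build_train_preprocess_fn(vocab, client_batch_size,
                              client_epochs_per_round, max_seq_len,
                              max_training_elements_per_user):
  """Returns a callable from a raw client dataset to a training dataset.

  Illustrative sketch only; not the real stackoverflow_dataset code.
  """
  # Map each word to an integer id; out-of-vocabulary words map to 0.
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(
          keys=vocab,
          values=tf.range(1, len(vocab) + 1, dtype=tf.int64)),
      default_value=0)

  def to_ids(example):
    # Assumes each raw example carries a space-delimited 'tokens' string;
    # split it, look up ids, and truncate to `max_seq_len`.
    sentence = tf.reshape(example['tokens'], shape=[1])
    words = tf.strings.split(sentence, sep=' ').values
    return table.lookup(words)[:max_seq_len]

  def preprocess(raw_dataset):
    # Cap, shuffle, and repeat per-client data, then tokenize and batch.
    return (raw_dataset
            .take(max_training_elements_per_user)
            .shuffle(buffer_size=max_training_elements_per_user)
            .repeat(client_epochs_per_round)
            .map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .padded_batch(client_batch_size, padded_shapes=[max_seq_len]))

  return preprocess

Example usage against one simulated client, reusing `train_clientdata` from
the script above (the tiny vocabulary is a placeholder):

raw_ds = train_clientdata.create_tf_dataset_for_client(
    train_clientdata.client_ids[0])
train_ds = build_train_preprocess_fn(
    vocab=['the', 'quick', 'brown', 'fox'],
    client_batch_size=16,
    client_epochs_per_round=1,
    max_seq_len=20,
    max_training_elements_per_user=128)(raw_ds)

Because `preprocess` closes over only hyperparameters and a lookup table, it
can be traced into a serializable computation and shipped to workers, which is
the property this patch relies on.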