From c60c0e427d2d823681ee11d6f10ee26cf6e141c0 Mon Sep 17 00:00:00 2001
From: Zachary Garrett <zachgarrett@google.com>
Date: Wed, 22 Jul 2020 17:32:36 -0700
Subject: [PATCH] Rework run_federated.py to allow running in a multi-machine
 setting.

-  Migrate the train dataset preprocessing function into the iterative process,
   rather than applying it when building the client datasets. This is necessary
   to enable serialization of datasets for sending to remote workers.
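
A minimal sketch of the pattern (illustrative only: the raw element
structure is simplified to a single string feature, and the invocation at
the end stands in for handing the computation to
iterative_process_builder.from_flags via dataset_preprocess_comp):

    import tensorflow as tf
    import tensorflow_federated as tff

    # Expressed as a tff.tf_computation, the preprocessing travels with
    # the iterative process and runs on the remote worker, so only the
    # raw, serializable client dataset crosses the machine boundary.
    @tff.tf_computation(tff.SequenceType(tf.string))
    def preprocess(raw_dataset):
      return raw_dataset.batch(2)

    # Local invocation for illustration.
    batched = preprocess(tf.data.Dataset.from_tensor_slices(['a', 'b']))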

PiperOrigin-RevId: 322691112
---
 .../research/optimization/stackoverflow/BUILD |  1 +
 .../stackoverflow/run_federated.py            | 38 ++++++++++++++-----
 2 files changed, 30 insertions(+), 9 deletions(-)
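
Note for reviewers: with preprocessing inside the process, the driver only
needs a remote execution context to run across machines. A sketch, assuming
a TFF worker service listening on localhost:8000 (the exact context-setting
helper has moved between TFF releases, so treat the name below as an
assumption for the TFF version of this patch):

    import grpc
    import tensorflow_federated as tff

    # Route federated computations to the remote worker; client datasets
    # must be serializable to cross this boundary, which is why the
    # preprocessing moved into the iterative process.
    channels = [grpc.insecure_channel('localhost:8000')]
    tff.backends.native.set_remote_execution_context(channels)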

diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
index e381ec128..937bf1c5e 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
@@ -8,6 +8,7 @@ py_binary(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
+        "//tensorflow_federated",
         "//tensorflow_federated/python/research/optimization/shared:fed_avg_schedule",
         "//tensorflow_federated/python/research/optimization/shared:iterative_process_builder",
         "//tensorflow_federated/python/research/optimization/shared:keras_metrics",
diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
index abfef5db1..f5b83c590 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
@@ -20,6 +20,7 @@ from absl import flags
 from absl import logging
 
 import tensorflow as tf
+import tensorflow_federated as tff
 
 from tensorflow_federated.python.research.optimization.shared import fed_avg_schedule
 from tensorflow_federated.python.research.optimization.shared import iterative_process_builder
@@ -96,12 +97,30 @@ def main(argv):
         keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
     ]
 
-  train_set, validation_set, test_set = stackoverflow_dataset.construct_word_level_datasets(
-      FLAGS.vocab_size, FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
-      FLAGS.sequence_length, FLAGS.max_elements_per_user,
-      FLAGS.num_validation_examples)
-
-  input_spec = validation_set.element_spec
+  dataset_vocab = stackoverflow_dataset.create_vocab(FLAGS.vocab_size)
+
+  train_clientdata, _, test_clientdata = (
+      tff.simulation.datasets.stackoverflow.load_data())
+
+  # Split the test data into test and validation sets.
+  # TODO(b/161914546): consider moving evaluation to use
+  # `tff.learning.build_federated_evaluation` to get metrics over client
+  # distributions, as well as the example weight means from this centralized
+  # evaluation.
+  base_test_dataset = test_clientdata.create_tf_dataset_from_all_clients()
+  preprocess_val_and_test = stackoverflow_dataset.create_test_dataset_preprocess_fn(
+      dataset_vocab, FLAGS.sequence_length)
+  test_set = preprocess_val_and_test(
+      base_test_dataset.skip(FLAGS.num_validation_examples))
+  validation_set = preprocess_val_and_test(
+      base_test_dataset.take(FLAGS.num_validation_examples))
+
+  train_dataset_preprocess_comp = stackoverflow_dataset.create_train_dataset_preprocess_fn(
+      vocab=dataset_vocab,
+      client_batch_size=FLAGS.client_batch_size,
+      client_epochs_per_round=FLAGS.client_epochs_per_round,
+      max_seq_len=FLAGS.sequence_length,
+      max_training_elements_per_user=FLAGS.max_elements_per_user)
 
   def client_weight_fn(local_outputs):
     # Num_tokens is a tensor with type int64[1], to use as a weight need
@@ -109,14 +128,15 @@ def main(argv):
     return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
 
   training_process = iterative_process_builder.from_flags(
-      input_spec=input_spec,
+      input_spec=None,  # type pulled from train_dataset_preprocess_comp.
       model_builder=model_builder,
       loss_builder=loss_builder,
       metrics_builder=metrics_builder,
-      client_weight_fn=client_weight_fn)
+      client_weight_fn=client_weight_fn,
+      dataset_preprocess_comp=train_dataset_preprocess_comp)
 
   client_datasets_fn = training_utils.build_client_datasets_fn(
-      train_dataset=train_set,
+      train_dataset=train_clientdata,
       train_clients_per_round=FLAGS.clients_per_round,
       random_seed=FLAGS.client_datasets_random_seed)
 
 
-- 
GitLab