diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
index e381ec1281651d25f23b58f215aae063a5313a3a..937bf1c5eac4b23d5d824e86d5fe6746eb901ae0 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
@@ -8,6 +8,7 @@ py_binary(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
+        "//tensorflow_federated",
         "//tensorflow_federated/python/research/optimization/shared:fed_avg_schedule",
         "//tensorflow_federated/python/research/optimization/shared:iterative_process_builder",
         "//tensorflow_federated/python/research/optimization/shared:keras_metrics",
diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
index abfef5db1c7b83453342161e1d8cc2e58cb6edee..f5b83c59092c6f744009a5c0347cd56cd7caf993 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
@@ -20,6 +20,7 @@ from absl import flags
 from absl import logging
 
 import tensorflow as tf
+import tensorflow_federated as tff
 
 from tensorflow_federated.python.research.optimization.shared import fed_avg_schedule
 from tensorflow_federated.python.research.optimization.shared import iterative_process_builder
@@ -96,12 +97,30 @@ def main(argv):
         keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
     ]
 
-  train_set, validation_set, test_set = stackoverflow_dataset.construct_word_level_datasets(
-      FLAGS.vocab_size, FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
-      FLAGS.sequence_length, FLAGS.max_elements_per_user,
-      FLAGS.num_validation_examples)
-
-  input_spec = validation_set.element_spec
+  dataset_vocab = stackoverflow_dataset.create_vocab(FLAGS.vocab_size)
+
+  train_clientdata, _, test_clientdata = (
+      tff.simulation.datasets.stackoverflow.load_data())
+
+  # Split the test data into test and validation sets.
+  # TODO(b/161914546): consider moving evaluation to use
+  # `tff.learning.build_federated_evaluation` to get metrics over client
+  # distributions, as well as the example weight means from this centralized
+  # evaluation.
+  base_test_dataset = test_clientdata.create_tf_dataset_from_all_clients()
+  preprocess_val_and_test = stackoverflow_dataset.create_test_dataset_preprocess_fn(
+      dataset_vocab, FLAGS.sequence_length)
+  test_set = preprocess_val_and_test(
+      base_test_dataset.skip(FLAGS.num_validation_examples))
+  validation_set = preprocess_val_and_test(
+      base_test_dataset.take(FLAGS.num_validation_examples))
+
+  train_dataset_preprocess_comp = stackoverflow_dataset.create_train_dataset_preprocess_fn(
+      vocab=stackoverflow_dataset.create_vocab(FLAGS.vocab_size),
+      client_batch_size=FLAGS.client_batch_size,
+      client_epochs_per_round=FLAGS.client_epochs_per_round,
+      max_seq_len=FLAGS.sequence_length,
+      max_training_elements_per_user=FLAGS.max_elements_per_user)
 
   def client_weight_fn(local_outputs):
     # Num_tokens is a tensor with type int64[1], to use as a weight need
@@ -109,14 +128,15 @@ def main(argv):
     return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
 
   training_process = iterative_process_builder.from_flags(
-      input_spec=input_spec,
+      input_spec=None,  # type pulled from train_dataset_preproces_comp.
       model_builder=model_builder,
       loss_builder=loss_builder,
       metrics_builder=metrics_builder,
-      client_weight_fn=client_weight_fn)
+      client_weight_fn=client_weight_fn,
+      dataset_preprocess_comp=train_dataset_preprocess_comp)
 
   client_datasets_fn = training_utils.build_client_datasets_fn(
-      train_dataset=train_set,
+      train_dataset=train_clientdata,
       train_clients_per_round=FLAGS.clients_per_round,
       random_seed=FLAGS.client_datasets_random_seed)