diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
index e381ec1281651d25f23b58f215aae063a5313a3a..937bf1c5eac4b23d5d824e86d5fe6746eb901ae0 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/BUILD
@@ -8,6 +8,7 @@ py_binary(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
+        "//tensorflow_federated",
         "//tensorflow_federated/python/research/optimization/shared:fed_avg_schedule",
         "//tensorflow_federated/python/research/optimization/shared:iterative_process_builder",
         "//tensorflow_federated/python/research/optimization/shared:keras_metrics",
diff --git a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
index abfef5db1c7b83453342161e1d8cc2e58cb6edee..f5b83c59092c6f744009a5c0347cd56cd7caf993 100644
--- a/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
+++ b/tensorflow_federated/python/research/optimization/stackoverflow/run_federated.py
@@ -20,6 +20,7 @@ from absl import flags
 from absl import logging
 
 import tensorflow as tf
+import tensorflow_federated as tff
 
 from tensorflow_federated.python.research.optimization.shared import fed_avg_schedule
 from tensorflow_federated.python.research.optimization.shared import iterative_process_builder
@@ -96,12 +97,30 @@ def main(argv):
       keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
   ]
 
-  train_set, validation_set, test_set = stackoverflow_dataset.construct_word_level_datasets(
-      FLAGS.vocab_size, FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
-      FLAGS.sequence_length, FLAGS.max_elements_per_user,
-      FLAGS.num_validation_examples)
-
-  input_spec = validation_set.element_spec
+  dataset_vocab = stackoverflow_dataset.create_vocab(FLAGS.vocab_size)
+
+  train_clientdata, _, test_clientdata = (
+      tff.simulation.datasets.stackoverflow.load_data())
+
+  # Split the test data into test and validation sets.
+  # TODO(b/161914546): consider moving evaluation to use
+  # `tff.learning.build_federated_evaluation` to get metrics over client
+  # distributions, as well as the example weight means from this centralized
+  # evaluation.
+  base_test_dataset = test_clientdata.create_tf_dataset_from_all_clients()
+  preprocess_val_and_test = stackoverflow_dataset.create_test_dataset_preprocess_fn(
+      dataset_vocab, FLAGS.sequence_length)
+  test_set = preprocess_val_and_test(
+      base_test_dataset.skip(FLAGS.num_validation_examples))
+  validation_set = preprocess_val_and_test(
+      base_test_dataset.take(FLAGS.num_validation_examples))
+
+  train_dataset_preprocess_comp = stackoverflow_dataset.create_train_dataset_preprocess_fn(
+      vocab=dataset_vocab,
+      client_batch_size=FLAGS.client_batch_size,
+      client_epochs_per_round=FLAGS.client_epochs_per_round,
+      max_seq_len=FLAGS.sequence_length,
+      max_training_elements_per_user=FLAGS.max_elements_per_user)
 
   def client_weight_fn(local_outputs):
     # Num_tokens is a tensor with type int64[1], to use as a weight need
@@ -109,14 +128,15 @@ def main(argv):
     return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
 
   training_process = iterative_process_builder.from_flags(
-      input_spec=input_spec,
+      input_spec=None,  # type pulled from train_dataset_preprocess_comp.
       model_builder=model_builder,
       loss_builder=loss_builder,
       metrics_builder=metrics_builder,
-      client_weight_fn=client_weight_fn)
+      client_weight_fn=client_weight_fn,
+      dataset_preprocess_comp=train_dataset_preprocess_comp)
 
   client_datasets_fn = training_utils.build_client_datasets_fn(
-      train_dataset=train_set,
+      train_dataset=train_clientdata,
       train_clients_per_round=FLAGS.clients_per_round,
       random_seed=FLAGS.client_datasets_random_seed)
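
This change replaces the single construct_word_level_datasets call with an explicit load/split/preprocess pipeline. Below is a minimal, self-contained sketch of just the centralized take/skip validation split the new code performs, using only the public tff.simulation API; NUM_VALIDATION_EXAMPLES is a hypothetical stand-in for FLAGS.num_validation_examples, and the research module's vocab and preprocess helpers are omitted.

    import tensorflow_federated as tff

    # Load the federated StackOverflow splits; each is a ClientData object.
    train_clientdata, _, test_clientdata = (
        tff.simulation.datasets.stackoverflow.load_data())

    # Flatten all test clients into a single centralized tf.data.Dataset.
    base_test_dataset = test_clientdata.create_tf_dataset_from_all_clients()

    NUM_VALIDATION_EXAMPLES = 10000  # stand-in for FLAGS.num_validation_examples

    # The first N examples become the validation set; skipping the same N
    # yields the held-out test set, so the two sets stay disjoint.
    validation_set = base_test_dataset.take(NUM_VALIDATION_EXAMPLES)
    test_set = base_test_dataset.skip(NUM_VALIDATION_EXAMPLES)

Pairing take(N) with skip(N) on the same dataset keeps the split disjoint without materializing the full centralized dataset; in the actual change, both halves are then mapped through the preprocess function built by create_test_dataset_preprocess_fn.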