Commit f058a1ac authored by Zachary Garrett, committed by tensorflow-copybara

Add back `tff.tf_computation` decorator to the preprocessing function for
stackoverflow so it can be passed to the iterative process.

Make shuffling conditional on the buffer size.

PiperOrigin-RevId: 321846124
parent 96737fde
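Why the decorator matters: wrapping the preprocessing function in `tff.tf_computation` with an explicit `tff.SequenceType` argument gives it a concrete TFF type signature, making it a serializable `tff.Computation` that other federated computations (such as the iterative training process) can consume, rather than a plain Python callable. A minimal sketch of the pattern, assuming a simplified one-field element spec in place of the full six-feature Stack Overflow tuple used in the diff below:

import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical one-field element spec standing in for the six-feature
# Stack Overflow tuple used in the real signature below.
element_type = tff.NamedTupleType(
    [('tokens', tff.TensorType(dtype=tf.string, shape=()))])


@tff.tf_computation(tff.SequenceType(element_type))
def truncate(dataset):
  # An ordinary tf.data transformation; the decorator gives it a
  # sequence-to-sequence TFF type signature.
  return dataset.take(16)


# Now a tff.Computation with a printable type,
# roughly (<tokens=string>* -> <tokens=string>*).
print(truncate.type_signature)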
@@ -14,6 +14,7 @@
 """Data loader for Stackoverflow."""
 from typing import List
+from absl import logging
 import numpy as np
 import tensorflow as tf
 import tensorflow_federated as tff
@@ -133,12 +134,25 @@ def create_train_dataset_preprocess_fn(vocab: List[str],
   else:
     shuffle_buffer_size = max_training_elements_per_user
   # TODO(b/155408842): need further investigation on why `tff.tf_computation`
   # decorator causes b/153363900 for `to_ids`, and large memory consumption.
+  feature_dtypes = [
+      ('creation_date', tf.string),
+      ('title', tf.string),
+      ('score', tf.int64),
+      ('tags', tf.string),
+      ('tokens', tf.string),
+      ('type', tf.string),
+  ]
+
+  @tff.tf_computation(
+      tff.SequenceType(
+          tff.NamedTupleType([(name, tff.TensorType(dtype=dtype, shape=()))
+                              for name, dtype in feature_dtypes])))
   def preprocess_train(dataset):
     to_ids = build_to_ids_fn(vocab, max_seq_len)
     dataset = dataset.take(max_training_elements_per_user)
-    dataset = dataset.shuffle(shuffle_buffer_size)
+    if shuffle_buffer_size > 0:
+      logging.info('Adding shuffle with buffer size: %d', shuffle_buffer_size)
+      dataset = dataset.shuffle(shuffle_buffer_size)
     dataset = dataset.repeat(client_epochs_per_round)
     dataset = dataset.map(
         to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
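The guard exists because tf.data rejects a non-positive shuffle buffer: calling `Dataset.shuffle(0)` fails rather than acting as a no-op, so with this change a buffer size of 0 cleanly disables shuffling (useful for deterministic runs). A small self-contained illustration:

import tensorflow as tf

ds = tf.data.Dataset.range(5)
shuffle_buffer_size = 0  # 0 now means "skip shuffling" instead of an error

if shuffle_buffer_size > 0:
  ds = ds.shuffle(shuffle_buffer_size)

print(list(ds.as_numpy_iterator()))  # [0, 1, 2, 3, 4], order preserved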
@@ -230,17 +244,15 @@ def construct_word_level_datasets(vocab_size: int,
   vocab = create_vocab(vocab_size)
-  raw_test_dataset = stackoverflow_test.create_tf_dataset_from_all_clients()
   preprocess_train = create_train_dataset_preprocess_fn(
       vocab, client_batch_size, client_epochs_per_round, max_seq_len,
       max_training_elements_per_user, max_batches_per_user,
       max_shuffle_buffer_size)
-  stackoverflow_train = stackoverflow_train.preprocess(preprocess_train)
+  raw_test_dataset = stackoverflow_test.create_tf_dataset_from_all_clients()
   preprocess_val_and_test = create_test_dataset_preprocess_fn(
       vocab, max_seq_len)
+  stackoverflow_train = stackoverflow_train.preprocess(preprocess_train)
   stackoverflow_val = preprocess_val_and_test(
       raw_test_dataset.take(num_validation_examples))
   stackoverflow_test = preprocess_val_and_test(
......
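For orientation, a hedged sketch of exercising the decorated preprocess function on a single client's raw dataset; the `load_data` call and choice of client ID are illustrative and not part of this change:

import tensorflow_federated as tff

# Standard TFF simulation splits for Stack Overflow (train / held-out / test).
train_data, _, _ = tff.simulation.datasets.stackoverflow.load_data()

raw = train_data.create_tf_dataset_for_client(train_data.client_ids[0])

# `preprocess_train` is the decorated computation built above; invoking a
# tff.Computation on an eager dataset runs it through the TFF runtime,
# type-checking the sequence element spec along the way.
processed = preprocess_train(raw)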