Skip to content
Snippets Groups Projects
Commit 418e6997 authored by Keith Rush's avatar Keith Rush Committed by Zachary Garrett
Browse files

Moves sequence model functions into research/baselines and

creates centralized experiment binary for stackoverlow NWP problem.

PiperOrigin-RevId: 272554184
parent 3c57477e
No related branches found
No related tags found
No related merge requests found
package(default_visibility = ["//tensorflow_federated/python/research"])
licenses(["notice"])
py_library(
name = "models",
srcs = ["models.py"],
)
py_test(
name = "models_test",
srcs = ["models_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":models"],
)
py_binary(
name = "non_federated_stackoverflow",
srcs = ["non_federated_stackoverflow.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":models",
"//tensorflow_federated",
],
)
# Lint as: python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence model functions for research baselines."""
import tensorflow as tf
def create_recurrent_model(vocab_size,
embedding_size,
num_layers,
recurrent_layer_fn,
name='rnn'):
"""Constructs zero-padded keras model with the given parameters and cell.
Args:
vocab_size: Size of vocabulary.
embedding_size: Size of embedding.
num_layers: Number of LSTM layers to sequentially stack.
recurrent_layer_fn: No-arg function which returns an instance of a
subclass of `tf.keras.layers.RNN`, creating the cells of the recurrent
model.
name: (Optional) string to name the returned `tf.keras.Model`.
Returns:
`tf.keras.Model`.
"""
inputs = tf.keras.layers.Input(shape=(None,))
embedded = tf.keras.layers.Embedding(
input_dim=vocab_size + 1, # Add 1 for padding.
output_dim=embedding_size,
mask_zero=True)(
inputs)
projected = embedded
for _ in range(num_layers):
layer = recurrent_layer_fn()
if not isinstance(layer, tf.keras.layers.RNN):
raise ValueError('The `recurrent_layer_fn` parameter to '
'`create_recurrent_model` should return an instance of '
'`tf.keras.layers.Layer` which inherits from '
'`tf.keras.layers.RNN`; you passed a function returning '
'{}'.format(layer))
processed = layer(projected)
# A projection changes dimension from rnn_layer_size to input_embedding_size
projected = tf.keras.layers.Dense(embedding_size)(processed)
logits = tf.keras.layers.Dense(vocab_size + 1)(projected)
return tf.keras.Model(inputs=inputs, outputs=logits, name=name)
# Lint as: python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for stackoverflow models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import tensorflow as tf
from tensorflow_federated.python.research.baselines.stackoverflow import models
class KerasSequenceModelsTest(absltest.TestCase):
def test_dense_fn_raises(self):
def _dense_layer_fn():
return tf.keras.layers.Dense(10)
with self.assertRaisesRegex(ValueError, 'tf.keras.layers.RNN'):
models.create_recurrent_model(10, 10, 1, _dense_layer_fn, 'dense')
def test_lstm_constructs(self):
def _recurrent_layer_fn():
return tf.keras.layers.LSTM(10, return_sequences=True)
model = models.create_recurrent_model(10, 10, 2, _recurrent_layer_fn,
'rnn-lstm')
self.assertIsInstance(model, tf.keras.Model)
self.assertEqual('rnn-lstm', model.name)
def test_gru_constructs(self):
def _recurrent_layer_fn():
return tf.keras.layers.GRU(10, return_sequences=True)
model = models.create_recurrent_model(10, 10, 2, _recurrent_layer_fn,
'rnn-gru')
self.assertIsInstance(model, tf.keras.Model)
self.assertEqual('rnn-gru', model.name)
def test_gru_fewer_parameters_than_lstm(self):
def _gru_fn():
return tf.keras.layers.GRU(10, return_sequences=True)
def _lstm_fn():
return tf.keras.layers.LSTM(10, return_sequences=True)
gru_model = models.create_recurrent_model(10, 10, 1, _gru_fn, 'gru')
lstm_model = models.create_recurrent_model(10, 10, 1, _lstm_fn, 'lstm')
self.assertLess(gru_model.count_params(), lstm_model.count_params())
if __name__ == '__main__':
absltest.main()
# Lint as: python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Baseline experiment on centralized data."""
import collections
import os
from absl import app
from absl import flags
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.research.baselines.stackoverflow import models
flags.DEFINE_string(
'exp_name', 'centralized_keras_stackoverflow',
'Unique name for the experiment, suitable for '
'use in filenames.')
flags.DEFINE_integer('batch_size', 128, 'Batch size used.')
flags.DEFINE_integer(
'vocab_size', 30000,
'Size of the vocab to use; results in most `vocab_size` number of most '
'common words used as vocabulary.')
flags.DEFINE_integer('embedding_size', 256,
'Dimension of word embedding to use.')
flags.DEFINE_integer('latent_size', 512,
'Dimension of latent size to use in recurrent cell')
flags.DEFINE_integer('num_layers', 1,
'Number of stacked recurrent layers to use.')
flags.DEFINE_float('learning_rate', 0.01,
'Learning rate to use for centralized SGD optimizer.')
flags.DEFINE_float(
'momentum', 0.0, 'Momentum value to use fo SGD optimizer. A value of 0.0 '
'corresponds to no momentum.')
# TODO(b/141867576): TFF currently needs a concrete maximum sequence length.
# Follow up when this restriction is lifted.
flags.DEFINE_integer('sequence_length', 100, 'Max sequence length to use')
# There are over 100 million sentences in this dataset; this flag caps the
# epoch size for speed. For comparison: EMNIST contains roughly 300,000
# examples, so we set that as default here.
flags.DEFINE_integer('num_training_examples', 300 * 1000,
'Number of training examples to process per epoch.')
flags.DEFINE_integer('num_val_examples', 1000,
'Number of examples to take for validation set.')
flags.DEFINE_string('root_output_dir', '/tmp/non_federated_stackoverflow/',
'Root directory for writing experiment output.')
FLAGS = flags.FLAGS
def _create_vocab():
vocab_dict = tff.simulation.datasets.stackoverflow.load_word_counts()
sorted_pairs = sorted(
vocab_dict.items(), key=lambda x: -x[1])[:FLAGS.vocab_size]
return list(x[0] for x in sorted_pairs)
def construct_word_level_datasets(vocab):
"""Preprocesses train and test datasets for stackoverflow."""
(stackoverflow_train, _,
stackoverflow_test) = tff.simulation.datasets.stackoverflow.load_data()
# Mix all clients for training and testing in the centralized setting.
raw_test_dataset = stackoverflow_test.create_tf_dataset_from_all_clients()
raw_train_dataset = stackoverflow_train.create_tf_dataset_from_all_clients()
BatchType = collections.namedtuple('BatchType', ['x', 'y']) # pylint: disable=invalid-name
table_values = tf.constant(list(range(FLAGS.vocab_size)), dtype=tf.int64)
table = tf.lookup.StaticVocabularyTable(
tf.lookup.KeyValueTensorInitializer(vocab, table_values),
num_oov_buckets=1)
def to_ids(x):
"""Splits a string into word IDs."""
s = tf.reshape(x['tokens'], shape=[1])
words = tf.string_split(s, sep=' ').values
truncated_words = words[:FLAGS.sequence_length]
ids = table.lookup(truncated_words)
return ids
def split_input_target(chunk):
"""Generate input and target data.
The task of language model is to predict the next word.
Args:
chunk: A Tensor of text data.
Returns:
A namedtuple of input and target data.
"""
input_text = tf.map_fn(lambda x: x[:-1], chunk)
target_text = tf.map_fn(lambda x: x[1:], chunk)
return BatchType(input_text, target_text)
def preprocess(dataset, epochs=1):
return (dataset.map(to_ids).padded_batch(
FLAGS.batch_size,
padded_shapes=[FLAGS.sequence_length
]).map(split_input_target).repeat(epochs))
stackoverflow_train = preprocess(raw_train_dataset)
stackoverflow_val = preprocess(raw_test_dataset).take(1000)
stackoverflow_test = preprocess(raw_test_dataset)
return stackoverflow_train, stackoverflow_val, stackoverflow_test
def run_experiment():
"""Runs the training experiment."""
vocab = _create_vocab()
(stackoverflow_train, stackoverflow_val,
stackoverflow_test) = construct_word_level_datasets(vocab)
num_training_steps = FLAGS.num_training_examples / FLAGS.batch_size
def _lstm_fn():
return tf.keras.layers.LSTM(FLAGS.latent_size, return_sequences=True)
model = models.create_recurrent_model(FLAGS.vocab_size, FLAGS.embedding_size,
FLAGS.num_layers, _lstm_fn,
'stackoverflow-lstm')
model.compile(
loss=tf.keras.losses.sparse_categorical_crossentropy,
optimizer=tf.keras.optimizers.SGD(
learning_rate=FLAGS.learning_rate, use_momentum=FLAGS.use_momentum),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
train_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
'train_results.csv')
test_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
'test_results.csv')
train_csv_logger = tf.keras.callbacks.CSVLogger(train_results_path)
test_csv_logger = tf.keras.callbacks.CSVLogger(test_results_path)
model.fit(
stackoverflow_train,
steps_per_epoch=num_training_steps,
epochs=25,
verbose=1,
validation_data=stackoverflow_val,
callbacks=[train_csv_logger])
score = model.evaluate_generator(
stackoverflow_test, verbose=1, callbacks=[test_csv_logger])
print('Final test loss: %.4f' % score[0])
print('Final test accuracy: %.4f' % score[1])
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
tf.compat.v1.enable_v2_behavior()
try:
tf.io.gfile.makedirs(os.path.join(FLAGS.root_output_dir, FLAGS.exp_name))
except tf.errors.OpError:
pass
run_experiment()
if __name__ == '__main__':
app.run(main)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment