Commit e7ae539e by Ronald Seoh: typo fix in tutorials/federated_learning_for_text_generation.ipynb (parent c6e4cdbb)
%% Cell type:markdown id: tags:
##### Copyright 2019 The TensorFlow Authors.
%% Cell type:code id: tags:
```
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
%% Cell type:markdown id: tags:
# Federated Learning for Text Generation
%% Cell type:markdown id: tags:
<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/federated_learning_for_text_generation.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/federated/blob/master/docs/tutorials/federated_learning_for_text_generation.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>
%% Cell type:markdown id: tags:
**NOTE**: This colab has been verified to work with the [latest released version](https://github.com/tensorflow/federated#compatibility) of the `tensorflow_federated` pip package, but the TensorFlow Federated project is still in pre-release development and may not work on `master`.
This tutorial builds on the concepts in the [Federated Learning for Image Classification](federated_learning_for_image_classification.ipynb) tutorial, and demonstrates several other useful approaches for federated learning.
In particular, we load a previously trained Keras model, and refine it using federated training on a (simulated) decentralized dataset. This is practically important for several reasons. The ability to use serialized models makes it easy to mix federated learning with other ML approaches. Further, it allows the use of an increasing range of pre-trained models --- for example, training language models from scratch is rarely necessary, as numerous pre-trained models are now widely available (see, e.g., [TF Hub](https://www.tensorflow.org/hub)). Instead, it makes more sense to start from a pre-trained model and refine it using federated learning, adapting it to the particular characteristics of the decentralized data for a particular application.
For this tutorial, we start with an RNN that generates ASCII characters, and refine it via federated learning. We also show how the final weights can be fed back to the original Keras model, allowing easy evaluation and text generation using standard tools.
%% Cell type:code id: tags:
```
#@test {"skip": true}
!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()
```
%% Cell type:code id: tags:
```
import collections
import functools
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

# Test that TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()
```
%% Output
b'Hello, World!'
%% Cell type:markdown id: tags:
## Load a pre-trained model
We load a model that was pre-trained following the TensorFlow tutorial
[Text generation using a RNN with eager execution](https://www.tensorflow.org/tutorials/sequences/text_generation). However,
rather than training on [The Complete Works of Shakespeare](http://www.gutenberg.org/files/100/100-0.txt), we pre-trained the model on the text from Charles Dickens'
[A Tale of Two Cities](http://www.ibiblio.org/pub/docs/books/gutenberg/9/98/98.txt)
and
[A Christmas Carol](http://www.ibiblio.org/pub/docs/books/gutenberg/4/46/46.txt).
Other than expanding the vocabulary, we didn't modify the original tutorial, so this initial model isn't state-of-the-art, but it produces reasonable predictions and is sufficient for our tutorial purposes. The final model was saved with `tf.keras.models.save_model(include_optimizer=False)`.
We will use federated learning to fine-tune this model for Shakespeare in this tutorial, using a federated version of the data provided by TFF.
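For reference, here is a minimal sketch of that save call (the model variable and path are illustrative, not from this notebook):
%% Cell type:code id: tags:
```
# Hypothetical sketch: save a Keras model without its optimizer state, as was
# done for the pre-trained model above. The path is illustrative only.
tf.keras.models.save_model(
    keras_model, '/tmp/dickens_rnn.kerasmodel', include_optimizer=False)
```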
%% Cell type:markdown id: tags:
### Generate the vocab lookup tables
%% Cell type:code id: tags:
```
# A fixed vocabulary of ASCII chars that occur in the works of Shakespeare and Dickens:
vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')

# Creating a mapping from unique characters to indices
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
```
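%% Cell type:markdown id: tags:
As a quick sanity check (not in the original notebook), the two mappings invert each other:
%% Cell type:code id: tags:
```
# Round-trip a character through the lookup tables.
assert idx2char[char2idx['T']] == 'T'
```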
%% Cell type:markdown id: tags:
### Load the pre-trained model and generate some text
%% Cell type:code id: tags:
```
def load_model(batch_size):
  urls = {
      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',
      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}
  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())
  url = urls[batch_size]
  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)
  return tf.keras.models.load_model(local_file, compile=False)
```
%% Cell type:code id: tags:
```
def generate_text(model, start_string):
  # From https://www.tensorflow.org/tutorials/sequences/text_generation
  num_generate = 200
  input_eval = [char2idx[s] for s in start_string]
  input_eval = tf.expand_dims(input_eval, 0)
  text_generated = []
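  # Lower temperature yields more predictable text; higher yields more surprising text.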
  temperature = 1.0

  model.reset_states()
  for i in range(num_generate):
    predictions = model(input_eval)
    predictions = tf.squeeze(predictions, 0)
    predictions = predictions / temperature
    predicted_id = tf.random.categorical(
        predictions, num_samples=1)[-1, 0].numpy()
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])

  return (start_string + ''.join(text_generated))
```
%% Cell type:code id: tags:
```
# Text generation requires a batch_size=1 model.
keras_model_batch1 = load_model(batch_size=1)
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
```
%% Output
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
What of TensorFlow Federated, you ask? Sall
yesterday. Received the Bailey."
"Mr. Lorry, grimmering himself, or low varked thends the winter, and the eyes of Monsieur
Defarge. "Let his mind, hon in his
life and message; four declare
%% Cell type:markdown id: tags:
## Load and Preprocess the Federated Shakespeare Data
The `tff.simulation.datasets` package provides a variety of datasets that are split into "clients", where each client corresponds to a dataset on a particular device that might participate in federated learning.
These datasets provide realistic non-IID data distributions that replicate in simulation the challenges of training on real decentralized data. Some of the pre-processing of this data was done using tools from the [Leaf project](https://arxiv.org/abs/1812.01097) ([github](https://github.com/TalwalkarLab/leaf)).
%% Cell type:code id: tags:
```
train_data, test_data = tff.simulation.datasets.shakespeare.load_data()
```
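%% Cell type:markdown id: tags:
As a quick inspection (not part of the original tutorial), each client id names a (play, character) pair, and you can count them:
%% Cell type:code id: tags:
```
# Inspect the simulated client population.
print(len(train_data.client_ids))
print(train_data.client_ids[:3])
```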
%% Cell type:markdown id: tags:
The datasets provided by `shakespeare.load_data()` consist of a sequence of
string `Tensors`, one for each line spoken by a particular character in a
Shakespeare play. The client keys consist of the name of the play joined with
the name of the character, so for example `MUCH_ADO_ABOUT_NOTHING_OTHELLO` corresponds to the lines for the character Othello in the play *Much Ado About Nothing*. Note that in a real federated learning scenario
clients are never identified or tracked by ids, but for simulation it is useful
to work with keyed datasets.
Here, for example, we can look at some data from King Lear:
%% Cell type:code id: tags:
```
# Here the play is "The Tragedy of King Lear" and the character is "King".
raw_example_dataset = train_data.create_tf_dataset_for_client(
    'THE_TRAGEDY_OF_KING_LEAR_KING')

# To allow for future extensions, each entry x
# is an OrderedDict with a single key 'snippets' which contains the text.
for x in raw_example_dataset.take(2):
  print(x['snippets'])
```
%% Output
tf.Tensor(b'', shape=(), dtype=string)
tf.Tensor(b'What?', shape=(), dtype=string)
%% Cell type:markdown id: tags:
We now use `tf.data.Dataset` transformations to prepare this data for training the char RNN loaded above.
%% Cell type:code id: tags:
```
# Input pre-processing parameters
SEQ_LENGTH = 100
BATCH_SIZE = 8
BUFFER_SIZE = 100  # For dataset shuffling
```
%% Cell type:code id: tags:
```
# Construct a lookup table to map string chars to indexes,
# using the vocab loaded above:
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=vocab,
        values=tf.constant(list(range(len(vocab))), dtype=tf.int64)),
    default_value=0)


def to_ids(x):
  s = tf.reshape(x['snippets'], shape=[1])
  chars = tf.strings.bytes_split(s).values
  ids = table.lookup(chars)
  return ids
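

# E.g. for a chunk holding the ids of "Hello!", input_text is the ids of
# "Hello" and target_text the ids of "ello!", so the model learns to
# predict the next character at every position.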
def split_input_target(chunk):
  input_text = tf.map_fn(lambda x: x[:-1], chunk)
  target_text = tf.map_fn(lambda x: x[1:], chunk)
  return (input_text, target_text)


def preprocess(dataset):
  return (
      # Map ASCII chars to int64 indexes using the vocab
      dataset.map(to_ids)
      # Split into individual chars
      .unbatch()
      # Form example sequences of SEQ_LENGTH + 1
      .batch(SEQ_LENGTH + 1, drop_remainder=True)
      # Shuffle and form minibatches
      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
      # And finally split into (input, target) tuples,
      # each of length SEQ_LENGTH.
      .map(split_input_target))
```
%% Cell type:markdown id: tags:
Note that in the formation of the original sequences and in the formation of
batches above, we use `drop_remainder=True` for simplicity. This means that any
characters (clients) that don't have at least `(SEQ_LENGTH + 1) * BATCH_SIZE`
chars of text will have empty datasets. A typical approach to address this would
be to pad the batches with a special token, and then mask the loss to not take
the padding tokens into account.
This would complicate the example somewhat, so for this tutorial we only use full batches, as in the
[standard tutorial](https://www.tensorflow.org/tutorials/sequences/text_generation).
However, in the federated setting this issue is more significant, because many
users might have small datasets.
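For the curious, a minimal sketch of that padding-and-masking approach follows (not part of this tutorial; `PAD_ID`, `masked_loss`, and `preprocess_with_padding` are hypothetical names, and a real implementation would reserve a dedicated pad token rather than reusing id 0):
%% Cell type:code id: tags:
```
# Hypothetical sketch: keep short final chunks by padding them, and mask the
# padded positions out of the loss. Note vocab index 0 is a real character,
# so a real implementation would extend the vocab with a dedicated pad token.
PAD_ID = 0


def masked_loss(y_true, y_pred):
  # Standard sparse cross-entropy, with padded positions zeroed out.
  loss = tf.keras.losses.sparse_categorical_crossentropy(
      y_true, y_pred, from_logits=True)
  mask = tf.cast(tf.not_equal(y_true, PAD_ID), loss.dtype)
  return tf.reduce_sum(loss * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)


def preprocess_with_padding(dataset):
  return (dataset.map(to_ids)
          .unbatch()
          # Keep the short final chunk instead of dropping it ...
          .batch(SEQ_LENGTH + 1)
          .shuffle(BUFFER_SIZE)
          # ... and pad it to full length when forming minibatches. The
          # pre-trained model expects a fixed batch size, so we still drop
          # partial batches here.
          .padded_batch(BATCH_SIZE, padded_shapes=[SEQ_LENGTH + 1],
                        drop_remainder=True)
          .map(split_input_target))
```
%% Cell type:markdown id: tags: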
Now we can preprocess our `raw_example_dataset`, and check the types:
%% Cell type:code id: tags:
```
example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)
```
%% Output
(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))
%% Cell type:markdown id: tags:
## Compile the model and test on the preprocessed data
%% Cell type:markdown id: tags:
We loaded an uncompiled Keras model, but in order to run `keras_model.evaluate`, we need to compile it with a loss and metrics. We will also compile in an optimizer, which will be used as the on-device optimizer in federated learning.
%% Cell type:markdown id: tags:
The original tutorial didn't have char-level accuracy (the fraction
of predictions where the highest probability was put on the correct
next char). This is a useful metric, so we add it.
However, we need to define a new metric class for this because
our predictions have rank 3 (a vector of logits for each of the
`BATCH_SIZE * SEQ_LENGTH` predictions), and `SparseCategoricalAccuracy`
expects only rank 2 predictions.
%% Cell type:code id: tags:
```
class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):

  def __init__(self, name='accuracy', dtype=tf.float32):
    super().__init__(name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.reshape(y_true, [-1, 1])
    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])
    return super().update_state(y_true, y_pred, sample_weight)
```
%% Cell type:markdown id: tags:
Now we can compile a model, and evaluate it on our `example_dataset`.
%% Cell type:code id: tags:
```
BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.
keras_model = load_model(batch_size=BATCH_SIZE)
keras_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[FlattenedCategoricalAccuracy()])

# Confirm that loss is much lower on Shakespeare than on random data
loss, accuracy = keras_model.evaluate(example_dataset.take(5), verbose=0)
print(
    'Evaluating on an example Shakespeare character: {a:3f}'.format(a=accuracy))

# As a sanity check, we can construct some completely random data, where we expect
# the accuracy to be essentially random:
random_guessed_accuracy = 1.0 / len(vocab)
print('Expected accuracy for random guessing: {a:.3f}'.format(
    a=random_guessed_accuracy))
random_indexes = np.random.randint(
    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))
data = collections.OrderedDict(
    snippets=tf.constant(
        ''.join(np.array(vocab)[random_indexes]), shape=[1, 1]))
random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))
loss, accuracy = keras_model.evaluate(random_dataset, steps=10, verbose=0)
print('Evaluating on completely random data: {a:.3f}'.format(a=accuracy))
```
%% Output
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
Evaluating on an example Shakespeare character: 0.402000
Expected accuracy for random guessing: 0.012
Evaluating on completely random data: 0.011
%% Cell type:markdown id: tags:
## Fine-tune the model with Federated Learning
%% Cell type:markdown id: tags:
TFF serializes all TensorFlow computations so they can potentially be run in a
non-Python environment (even though at the moment, only a simulation runtime implemented in Python is available). Even though we are running in eager mode (TF 2.0), currently TFF serializes TensorFlow computations by constructing the
necessary ops inside the context of a "`with tf.Graph().as_default()`" statement.
Thus, we need to provide a function that TFF can use to introduce our model into
a graph it controls. We do this as follows:
%% Cell type:code id: tags:
```
# Clone the keras_model inside `create_tff_model()`, which TFF will
# call to produce a new copy of the model inside the graph that it will
# serialize. Note: we want to construct all the necessary objects we'll need
# _inside_ this method.
def create_tff_model():
  # TFF uses an `input_spec` so it knows the types and shapes
  # that your model expects.
  input_spec = example_dataset.element_spec
  keras_model_clone = tf.keras.models.clone_model(keras_model)
  return tff.learning.from_keras_model(
      keras_model_clone,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])
```
%% Cell type:markdown id: tags:
Now we are ready to construct a Federated Averaging iterative process, which we will use to improve the model (for details on the Federated Averaging algorithm, see the paper [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629)).
We use a compiled Keras model to perform standard (non-federated) evaluation after each round of federated training. This is useful for research purposes when doing simulated federated learning and there is a standard test dataset.
In a realistic production setting this same technique might be used to take models trained with federated learning and evaluate them on a centralized benchmark dataset for testing or quality assurance purposes.
%% Cell type:code id: tags:
```
# This command builds all the TensorFlow graphs and serializes them:
fed_avg = tff.learning.build_federated_averaging_process(
    model_fn=create_tff_model,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(lr=0.5))
```
%% Cell type:markdown id: tags:
Here is the simplest possible loop, where we run federated averaging for one round on a single client, using a handful of its batches:
%% Cell type:code id: tags:
```
state = fed_avg.initialize()
state, metrics = fed_avg.next(state, [example_dataset.take(5)])
train_metrics = metrics['train']
print('loss={l:.3f}, accuracy={a:.3f}'.format(
    l=train_metrics['loss'], a=train_metrics['accuracy']))
```
%% Output
loss=4.403, accuracy=0.132
%% Cell type:markdown id: tags:
Now let's write a slightly more interesting training and evaluation loop.
So that this simulation still runs relatively quickly, we train on the same two clients each round, considering only five minibatches for each.
%% Cell type:code id: tags:
```
def data(client, source=train_data):
  return preprocess(source.create_tf_dataset_for_client(client)).take(5)


clients = [
    'ALL_S_WELL_THAT_ENDS_WELL_CELIA', 'MUCH_ADO_ABOUT_NOTHING_OTHELLO',
]

train_datasets = [data(client) for client in clients]

# We concatenate the test datasets for evaluation with Keras by creating a
# Dataset of Datasets, and then identity flat mapping across all the examples.
test_dataset = tf.data.Dataset.from_tensor_slices(
    [data(client, test_data) for client in clients]).flat_map(lambda x: x)
```
%% Cell type:markdown id: tags:
The initial state of the model produced by `fed_avg.initialize()` is based
on the random initializers for the Keras model, not the weights that were loaded,
since `clone_model()` does not clone the weights. To start training
from a pre-trained model, we set the model weights in the server state
directly from the loaded model.
%% Cell type:code id: tags:
```
NUM_ROUNDS = 5

# The state of the FL server, containing the model and optimization state.
state = fed_avg.initialize()

# Load our pre-trained Keras model weights into the global model state.
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in keras_model.non_trainable_weights
    ])


def keras_evaluate(state, round_num):
  # Take our global model weights and push them back into a Keras model to
  # use its standard `.evaluate()` method.
  keras_model = load_model(batch_size=BATCH_SIZE)
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])
  state.model.assign_weights_to(keras_model)
  loss, accuracy = keras_model.evaluate(example_dataset, steps=2, verbose=0)
  print('\tEval: loss={l:.3f}, accuracy={a:.3f}'.format(l=loss, a=accuracy))


for round_num in range(NUM_ROUNDS):
  print('Round {r}'.format(r=round_num))
  keras_evaluate(state, round_num)
  state, metrics = fed_avg.next(state, train_datasets)
  train_metrics = metrics['train']
  print('\tTrain: loss={l:.3f}, accuracy={a:.3f}'.format(
      l=train_metrics['loss'], a=train_metrics['accuracy']))

print('Final evaluation')
keras_evaluate(state, NUM_ROUNDS + 1)
```
%% Output
Round 0
Eval: loss=3.324, accuracy=0.401
Train: loss=4.360, accuracy=0.155
Round 1
Eval: loss=4.361, accuracy=0.049
Train: loss=4.235, accuracy=0.164
Round 2
Eval: loss=4.219, accuracy=0.177
Train: loss=4.081, accuracy=0.221
Round 3
Eval: loss=4.080, accuracy=0.174
Train: loss=3.940, accuracy=0.226
Round 4
Eval: loss=3.991, accuracy=0.176
Train: loss=3.840, accuracy=0.226
Final evaluation
Eval: loss=3.909, accuracy=0.171
%% Cell type:markdown id: tags:
With the default changes, we haven't done enough training to make a big difference, but if you train longer on more Shakespeare data, you should see a difference in the style of the text generated with the updated model:
%% Cell type:code id: tags:
```
# Set our newly trained weights back in the originally created model.
keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])
# Text generation requires batch_size=1
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
```
%% Output
What of TensorFlow Federated, you ask? Shalways, I will call your
compet with any city brought their faces uncompany," besumed him. "When he
sticked Madame Defarge pushed the lamps.
"Have I often but no unison. She had probably come,
%% Cell type:markdown id: tags:
## Suggested extensions
This tutorial is just the first step! Here are some ideas for how you might try extending this notebook:
* Write a more realistic training loop where you randomly sample clients to train on (see the sketch after this list).
* Use "`.repeat(NUM_EPOCHS)`" on the client datasets to try multiple epochs of local training (e.g., as in [McMahan et al.](https://arxiv.org/abs/1602.05629)). See also [Federated Learning for Image Classification](federated_learning_for_image_classification.ipynb), which does this.
* Change the `compile()` command to experiment with using different optimization algorithms on the client.
* Try the `server_optimizer` argument to `build_federated_averaging_process` to try different algorithms for applying the model updates on the server.
* Try the `client_weight_fn` argument to `build_federated_averaging_process` to try different weightings of the clients. The default weighs client updates by the number of examples on the client, but you can do e.g. `client_weight_fn=lambda _: tf.constant(1.0)`.
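As a starting point for the first suggestion, here is a rough sketch (not from the tutorial) of sampling client ids at random each round; `CLIENTS_PER_ROUND` is an illustrative choice, not a recommended value:
%% Cell type:code id: tags:
```
# Hypothetical sketch: train on a fresh random subset of clients each round,
# reusing the `data()` helper and `fed_avg` process defined above.
CLIENTS_PER_ROUND = 3

for round_num in range(NUM_ROUNDS):
  sampled_clients = np.random.choice(
      train_data.client_ids, size=CLIENTS_PER_ROUND, replace=False)
  sampled_datasets = [data(client) for client in sampled_clients]
  state, metrics = fed_avg.next(state, sampled_datasets)
  print('Round {r}: loss={l:.3f}'.format(
      r=round_num, l=metrics['train']['loss']))
```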