Skip to content
Snippets Groups Projects
Commit 7f6cd06e authored by Zachary Garrett's avatar Zachary Garrett Committed by tensorflow-copybara
Browse files

Update tutorial notebooks to TF 2.0

PiperOrigin-RevId: 273296155
parent 8cff0bf8
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
##### Copyright 2019 The TensorFlow Authors.
%% Cell type:code id: tags:
```
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
%% Cell type:markdown id: tags:
# Custom Federated Algorithms, Part 2: Implementing Federated Averaging
%% Cell type:markdown id: tags:
<table class="tfo-notebook-buttons" align="left">
<td>
<a target="_blank" href="https://www.tensorflow.org/federated/tutorials/custom_federated_algorithms_2"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
</td>
<td>
<a target="_blank" href="https://colab.research.google.com/github/tensorflow/federated/blob/v0.8.0/docs/tutorials/custom_federated_algorithms_2.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
</td>
<td>
<a target="_blank" href="https://github.com/tensorflow/federated/blob/v0.8.0/docs/tutorials/custom_federated_algorithms_2.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
</td>
</table>
%% Cell type:markdown id: tags:
This tutorial is the second part of a two-part series that demonstrates how to
implement custom types of federated algorithms in TFF using the
[Federated Core (FC)](../federated_core.md), which serves as a foundation for
the [Federated Learning (FL)](../federated_learning.md) layer (`tff.learning`).
We encourage you to first read the
[first part of this series](custom_federated_algorithms_1.ipynb), which
introduce some of the key concepts and programming abstractions used here.
This second part of the series uses the mechanisms introduced in the first part
to implement a simple version of federated training and evaluation algorithms.
We encourage you to review the
[image classification](federated_learning_for_image_classification.ipynb) and
[text generation](federated_learning_for_text_generation.ipynb) tutorials for a
higher-level and more gentle introduction to TFF's Federated Learning APIs, as
they will help you put the concepts we describe here in context.
%% Cell type:markdown id: tags:
## Before we start
Before we start, try to run the following "Hello World" example to make sure
your environment is correctly setup. If it doesn't work, please refer to the
[Installation](../install.md) guide for instructions.
%% Cell type:code id: tags:
```
# NOTE: If you are running a Jupyter notebook, and installing a locally built
# pip package, you may need to edit the following to point to the '.whl' file
# on your local filesystem.
!pip install --quiet --upgrade tensorflow_federated
!pip install --quiet --upgrade tf-nightly
```
%% Cell type:code id: tags:
```
from __future__ import absolute_import, division, print_function
import collections
import numpy as np
from six.moves import range
import tensorflow as tf
import tensorflow_federated as tff
tf.compat.v1.enable_v2_behavior()
```
%% Cell type:code id: tags:
```
@tff.federated_computation
def hello_world():
return 'Hello, World!'
hello_world()
```
%% Output
'Hello, World!'
%% Cell type:markdown id: tags:
## Implementing Federated Averaging
As in
[Federated Learning for Image Classification](federated_learning_for_image_classification.md),
we are going to use the MNIST example, but since this is intended as a low-level
tutorial, we are going to bypass the Keras API and `tff.simulation`, write raw
model code, and construct a federated data set from scratch.
%% Cell type:markdown id: tags:
### Preparing federated data sets
For the sake of a demonstration, we're going to simulate a scenario in which we
have data from 10 users, and each of the users contributes knowledge how to
recognize a different digit. This is about as
non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables)
as it gets.
First, let's load the standard MNIST data:
%% Cell type:code id: tags:
```
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
```
%% Cell type:code id: tags:
```
[(x.dtype, x.shape) for x in mnist_train]
```
%% Output
[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]
%% Cell type:markdown id: tags:
The data comes as Numpy arrays, one with images and another with digit labels, both
with the first dimension going over the individual examples. Let's write a
helper function that formats it in a way compatible with how we feed federated
sequences into TFF computations, i.e., as a list of lists - the outer list
ranging over the users (digits), the inner ones ranging over batches of data in
each client's sequence. As is customary, we will structure each batch as a pair
of tensors named `x` and `y`, each with the leading batch dimension. While at
it, we'll also flatten each image into a 784-element vector and rescale the
pixels in it into the `0..1` range, so that we don't have to clutter the model
logic with data conversions.
%% Cell type:code id: tags:
```
NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100
def get_data_for_digit(source, digit):
output_sequence = []
all_samples = [i for i, d in enumerate(source[1]) if d == digit]
for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):
batch_samples = all_samples[i:i + BATCH_SIZE]
output_sequence.append({
'x': np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
dtype=np.float32),
'y': np.array([source[1][i] for i in batch_samples], dtype=np.int32)})
return output_sequence
federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]
federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]
```
%% Cell type:markdown id: tags:
As a quick sanity check, let's look at the `Y` tensor in the last batch of data
contributed by the fifth client (the one corresponding to the digit `5`).
%% Cell type:code id: tags:
```
federated_train_data[5][-1]['y']
```
%% Output
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)
%% Cell type:markdown id: tags:
Just to be sure, let's also look at the image corresponding to the last element of that batch.
%% Cell type:code id: tags:
```
from matplotlib import pyplot as plt
plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()
```
%% Output
%% Cell type:markdown id: tags:
### On combining TensorFlow and TFF
In this tutorial, for compactness we immediately decorate functions that
introduce TensorFlow logic with `tff.tf_computation`. However, for more complex
logic, this is not the pattern we recommend. Debugging TensorFlow can already be
a challenge, and debugging TensorFlow after it has been fully serialized and
then re-imported necessarily loses some metadata and limits interactivity,
making debugging even more of a challenge.
Therefore, **we strongly recommend writing complex TF logic as stand-alone
Python functions** (that is, without `tff.tf_computation` decoration). This way
the TensorFlow logic can be developed and tested using TF best practices and
tools (like eager mode), before serializing the computation for TFF (e.g., by invoking `tff.tf_computation` with a Python function as the argument).
%% Cell type:markdown id: tags:
### Defining a loss function
Now that we have the data, let's define a loss function that we can use for
training. First, let's define the type of input as a TFF named tuple. Since the
size of data batches may vary, we set the batch dimension to `None` to indicate
that the size of this dimension is unknown.
%% Cell type:code id: tags:
```
BATCH_SPEC = collections.OrderedDict([
('x', tf.TensorSpec(shape=[None, 784], dtype=tf.float32)),
('y', tf.TensorSpec(shape=[None], dtype=tf.int32)),
])
BATCH_TYPE = tff.to_type(BATCH_SPEC)
str(BATCH_TYPE)
```
%% Output
'<x=float32[?,784],y=int32[?]>'
%% Cell type:markdown id: tags:
You may be wondering why we can't just define an ordinary Python type. Recall
the discussion in [part 1](custom_federated_algorithms_1.ipynb), where we
explained that while we can express the logic of TFF computations using Python,
under the hood TFF computations *are not* Python. The symbol `BATCH_TYPE`
defined above represents an abstract TFF type specification. It is important to
distinguish this *abstract* TFF type from concrete Python *representation*
types, e.g., containers such as `dict` or `collections.namedtuple` that may be
used to represent the TFF type in the body of a Python function. Unlike Python,
TFF has a single abstract type constructor `tff.NamedTupleType` for tuple-like
containers, with elements that can be individually named or left unnamed. This
type is also used to model formal parameters of computations, as TFF
computations can formally only declare one parameter and one result - you will
see examples of this shortly.
Let's now define the TFF type of model parameters, again as a TFF named tuple of
*weights* and *bias*.
%% Cell type:code id: tags:
```
MODEL_SPEC = collections.OrderedDict([
('weights', tf.TensorSpec(shape=[784, 10], dtype=tf.float32)),
('bias', tf.TensorSpec(shape=[10], dtype=tf.float32)),
])
MODEL_TYPE = tff.to_type(MODEL_SPEC)
print(MODEL_TYPE)
```
%% Output
<weights=float32[784,10],bias=float32[10]>
%% Cell type:markdown id: tags:
With those definitions in place, now we can define the loss for a given model, over
a single batch. Note how in the body of `batch_loss`, we access named tuple
elements using the dot (`X.Y`) notation, as is standard for TFF.
%% Cell type:code id: tags:
```
# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can
# be later called from within another tf.function. Necessary because a
# @tf.function decorated method cannot invoke a @tff.tf_computation.
@tf.function
def forward_pass(model, batch):
predicted_y = tf.nn.softmax(
tf.matmul(batch['x'], model['weights']) + model['bias'])
return -tf.reduce_mean(
tf.reduce_sum(
tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))
@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
return batch_loss(model, batch)
return forward_pass(model, batch)
```
%% Cell type:markdown id: tags:
As expected, computation `batch_loss` returns `float32` loss given the model and
a single data batch. Note how the `MODEL_TYPE` and `BATCH_TYPE` have been lumped
together into a 2-tuple of formal parameters; you can recognize the type of
`batch_loss` as `(<MODEL_TYPE,BATCH_TYPE> -> float32)`.
%% Cell type:code id: tags:
```
str(batch_loss.type_signature)
```
%% Output
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'
%% Cell type:markdown id: tags:
As a sanity check, let's construct an initial model filled with zeros and
compute the loss over the batch of data we visualized above.
%% Cell type:code id: tags:
```
initial_model = {
'weights': np.zeros([784, 10], dtype=np.float32),
'bias': np.zeros([10], dtype=np.float32)
}
sample_batch = federated_train_data[5][-1]
batch_loss(initial_model, sample_batch)
```
%% Output
2.3025854
%% Cell type:markdown id: tags:
Note that we feed the TFF computation with the initial model defined as a
`dict`, even though the body of the Python function that defines it consumes
model parameters as `model.weight` and `model.bias`. The arguments of the call
to `batch_loss` aren't simply passed to the body of that function.
What happens when we invoke `batch_loss`?
The Python body of `batch_loss` has already been traced and serialized in the above cell where it was defined. TFF acts as the caller to `batch_loss`
at the computation definition time, and as the target of invocation at the time
`batch_loss` is invoked. In both roles, TFF serves as the bridge between TFF's
abstract type system and Python representation types. At the invocation time,
TFF will accept most standard Python container types (`dict`, `list`, `tuple`,
`collections.namedtuple`, etc.) as concrete representations of abstract TFF
tuples. Also, although as noted above, TFF computations formally only accept a
single parameter, you can use the familiar Python call syntax with positional
and/or keyword arguments in case where the type of the parameter is a tuple - it
works as expected.
%% Cell type:markdown id: tags:
### Gradient descent on a single batch
Now, let's define a computation that uses this loss function to perform a single
step of gradient descent. Note how in defining this function, we use
`batch_loss` as a subcomponent. You can invoke a computation constructed with
`tff.tf_computation` inside the body of another computation, though typically
this is not necessary - as noted above, because serialization looses some
debugging information, it is often preferable for more complex computations to
write and test all the TensorFlow without the `tff.tf_computation` decorator.
%% Cell type:code id: tags:
```
@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
# Define a group of model variables and set them to `initial_model`.
model_vars = collections.OrderedDict([
(name, tf.Variable(name=name, initial_value=value))
for name, value in initial_model.items()
])
@tf.function
def _train_on_batch(model_vars, batch):
# Perform one step of gradient descent using loss from `batch_loss`.
optimizer = tf.keras.optimizers.SGD(learning_rate)
with tf.GradientTape() as tape:
loss = forward_pass(model_vars, batch)
grads = tape.gradient(loss, model_vars)
optimizer.apply_gradients(
zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
return model_vars
return _train_on_batch(model_vars, batch)
```
%% Cell type:code id: tags:
```
str(batch_train.type_signature)
```
%% Output
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'
%% Cell type:markdown id: tags:
When you invoke a Python function decorated with `tff.tf_computation` within the
body of another such function, the logic of the inner TFF computation is
embedded (essentially, inlined) in the logic of the outer one. As noted above,
if you are writing both computations, it is likely preferable to make the inner
function (`batch_loss` in this case) a regular Python or `tf.function` rather
than a `tff.tf_computation`. However, here we illustrate that calling one
`tff.tf_computation` inside another basically works as expected. This may be
necessary if, for example, you do not have the Python code defining
`batch_loss`, but only its serialized TFF representation.
Now, let's apply this function a few times to the initial model to see whether
the loss decreases.
%% Cell type:code id: tags:
```
model = initial_model
losses = []
for _ in range(5):
model = batch_train(model, sample_batch, 0.1)
losses.append(batch_loss(model, sample_batch))
```
%% Cell type:code id: tags:
```
losses
```
%% Output
[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]
%% Cell type:markdown id: tags:
### Gradient descent on a sequence of local data
Now, since `batch_train` appears to work, let's write a similar training
function `local_train` that consumes the entire sequence of all batches from one
user instead of just a single batch. The new computation will need to now
consume `tff.SequenceType(BATCH_TYPE)` instead of `BATCH_TYPE`.
%% Cell type:code id: tags:
```
LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)
@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):
# Mapping function to apply to each batch.
@tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
def batch_fn(model, batch):
return batch_train(model, batch, learning_rate)
return tff.sequence_reduce(all_batches, initial_model, batch_fn)
```
%% Cell type:code id: tags:
```
str(local_train.type_signature)
```
%% Output
'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'
%% Cell type:markdown id: tags:
There are quite a few details buried in this short section of code, let's go
over them one by one.
First, while we could have implemented this logic entirely in TensorFlow,
relying on `tf.data.Dataset.reduce` to process the sequence similarly to how
we've done it earlier, we've opted this time to express the logic in the glue
language, as a `tff.federated_computation`. We've used the federated operator
`tff.sequence_reduce` to perform the reduction.
The operator `tff.sequence_reduce` is used similarly to
`tf.data.Dataset.reduce`. You can think of it as essentially the same as
`tf.data.Dataset.reduce`, but for use inside federated computations, which as
you may remember, cannot contain TensorFlow code. It is a template operator with
a formal parameter 3-tuple that consists of a *sequence* of `T`-typed elements,
the initial state of the reduction (we'll refer to it abstractly as *zero*) of
some type `U`, and the *reduction operator* of type `(<U,T> -> U)` that alters the
state of the reduction by processing a single element. The result is the final
state of the reduction, after processing all elements in a sequential order. In
our example, the state of the reduction is the model trained on a prefix of the
data, and the elements are data batches.
Second, note that we have again used one computation (`batch_train`) as a
component within another (`local_train`), but not directly. We can't use it as a
reduction operator because it takes an additional parameter - the learning rate.
To resolve this, we define an embedded federated computation `batch_fn` that
binds to the `local_train`'s parameter `learning_rate` in its body. It is
allowed for a child computation defined this way to capture a formal parameter
of its parent as long as the child computation is not invoked outside the body
of its parent. You can think of this pattern as an equivalent of
`functools.partial` in Python.
The practical implication of capturing `learning_rate` this way is, of course,
that the same learning rate value is used across all batches.
Now, let's try the newly defined local training function on the entire sequence
of data from the same user who contributed the sample batch (digit `5`).
%% Cell type:code id: tags:
```
locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])
```
%% Cell type:markdown id: tags:
Did it work? To answer this question, we need to implement evaluation.
%% Cell type:markdown id: tags:
### Local evaluation
Here's one way to implement local evaluation by adding up the losses across all data
batches (we could have just as well computed the average; we'll leave it as an
exercise for the reader).
%% Cell type:code id: tags:
```
@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
# TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.
return tff.sequence_sum(
tff.sequence_map(
tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),
all_batches))
```
%% Cell type:code id: tags:
```
str(local_eval.type_signature)
```
%% Output
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'
%% Cell type:markdown id: tags:
Again, there are a few new elements illustrated by this code, let's go over them
one by one.
First, we have used two new federated operators for processing sequences:
`tff.sequence_map` that takes a *mapping function* `T->U` and a *sequence* of
`T`, and emits a sequence of `U` obtained by applying the mapping function
pointwise, and `tff.sequence_sum` that just adds all the elements. Here, we map
each data batch to a loss value, and then add the resulting loss values to
compute the total loss.
Note that we could have again used `tff.sequence_reduce`, but this wouldn't be
the best choice - the reduction process is, by definition, sequential, whereas
the mapping and sum can be computed in parallel. When given a choice, it's best
to stick with operators that don't constrain implementation choices, so that
when our TFF computation is compiled in the future to be deployed to a specific
environment, one can take full advantage of all potential opportunities for a
faster, more scalable, more resource-efficient execution.
Second, note that just as in `local_train`, the component function we need
(`batch_loss`) takes more parameters than what the federated operator
(`tff.sequence_map`) expects, so we again define a partial, this time inline by
directly wrapping a `lambda` as a `tff.federated_computation`. Using wrappers
inline with a function as an argument is the recommended way to use
`tff.tf_computation` to embed TensorFlow logic in TFF.
Now, let's see whether our training worked.
%% Cell type:code id: tags:
```
print('initial_model loss =', local_eval(initial_model, federated_train_data[5]))
print('locally_trained_model loss =', local_eval(locally_trained_model, federated_train_data[5]))
```
%% Output
initial_model loss = 23.025854
locally_trained_model loss = 0.4348469
%% Cell type:markdown id: tags:
Indeed, the loss decreased. But what happens if we evaluated it on another
user's data?
%% Cell type:code id: tags:
```
print('initial_model loss =', local_eval(initial_model, federated_train_data[0]))
print('locally_trained_model loss =', local_eval(locally_trained_model, federated_train_data[0]))
```
%% Output
initial_model loss = 23.025854
locally_trained_model loss = 74.50075
%% Cell type:markdown id: tags:
As expected, things got worse. The model was trained to recognize `5`, and has
never seen a `0`. This brings the question - how did the local training impact
the quality of the model from the global perspective?
%% Cell type:markdown id: tags:
### Federated evaluation
This is the point in our journey where we finally circle back to federated types
and federated computations - the topic that we started with. Here's a pair of
TFF types definitions for the model that originates at the server, and the data
that remains on the clients.
%% Cell type:code id: tags:
```
SERVER_MODEL_TYPE = tff.FederatedType(MODEL_TYPE, tff.SERVER, all_equal=True)
CLIENT_DATA_TYPE = tff.FederatedType(LOCAL_DATA_TYPE, tff.CLIENTS)
```
%% Cell type:markdown id: tags:
With all the definitions introduced so far, expressing federated evaluation in
TFF is a one-liner - we distribute the model to clients, let each client invoke
local evaluation on its local portion of data, and then average out the loss.
Here's one way to write this.
%% Cell type:code id: tags:
```
@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
return tff.federated_mean(
tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))
```
%% Cell type:markdown id: tags:
We've already seen examples of `tff.federated_mean` and `tff.federated_map`
in simpler scenarios, and at the intuitive level, they work as expected, but
there's more in this section of code than meets the eye, so let's go over it
carefully.
First, let's break down the *let each client invoke local evaluation on its
local portion of data* part. As you may recall from the preceding sections,
`local_eval` has a type signature of the form `(<MODEL_TYPE, LOCAL_DATA_TYPE> ->
float32)`.
The federated operator `tff.federated_map` is a template that accepts as a
parameter a 2-tuple that consists of the *mapping function* of some type `T->U`
and a federated value of type `{T}@CLIENTS` (i.e., with member constituents of
the same type as the parameter of the mapping function), and returns a result of
type `{U}@CLIENTS`.
Since we're feeding `local_eval` as a mapping function to apply on a per-client
basis, the second argument should be of a federated type `{<MODEL_TYPE,
LOCAL_DATA_TYPE>}@CLIENTS`, i.e., in the nomenclature of the preceding sections,
it should be a federated tuple. Each client should hold a full set of arguments
for `local_eval` as a member consituent. Instead, we're feeding it a 2-element
Python `list`. What's happening here?
Indeed, this is an example of an *implicit type cast* in TFF, similar to
implicit type casts you may have encountered elsewhere, e.g., when you feed an
`int` to a function that accepts a `float`. Implicit casting is used scarcily at
this point, but we plan to make it more pervasive in TFF as a way to minimize
boilerplate.
The implicit cast that's applied in this case is the equivalence between
federated tuples of the form `{<X,Y>}@Z`, and tuples of federated values
`<{X}@Z,{Y}@Z>`. While formally, these two are different type signatures,
looking at it from the programmers's perspective, each device in `Z` holds two
units of data `X` and `Y`. What happens here is not unlike `zip` in Python, and
indeed, we offer an operator `tff.federated_zip` that allows you to perform such
conversions explicity. When the `tff.federated_map` encounters a tuple as a
second argument, it simply invokes `tff.federated_zip` for you.
Given the above, you should now be able to recognize the expression
`tff.federated_broadcast(model)` as representing a value of TFF type
`{MODEL_TYPE}@CLIENTS`, and `data` as a value of TFF type
`{LOCAL_DATA_TYPE}@CLIENTS` (or simply `CLIENT_DATA_TYPE`), the two getting
filtered together through an implicit `tff.federated_zip` to form the second
argument to `tff.federated_map`.
The operator `tff.federated_broadcast`, as you'd expect, simply transfers data
from the server to the clients.
Now, let's see how our local training affected the average loss in the system.
%% Cell type:code id: tags:
```
print('initial_model loss =', federated_eval(initial_model, federated_train_data))
print('locally_trained_model loss =', federated_eval(locally_trained_model, federated_train_data))
```
%% Output
initial_model loss = 23.025852
locally_trained_model loss = 54.432625
%% Cell type:markdown id: tags:
Indeed, as expected, the loss has increased. In order to improve the model for
all users, we'll need to train in on everyone's data.
%% Cell type:markdown id: tags:
### Federated training
The simplest way to implement federated training is to locally train, and then
average the models. This uses the same building blocks and patters we've already
discussed, as you can see below.
%% Cell type:code id: tags:
```
SERVER_FLOAT_TYPE = tff.FederatedType(tf.float32, tff.SERVER, all_equal=True)
@tff.federated_computation(
SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE, CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
return tff.federated_mean(
tff.federated_map(
local_train,
[tff.federated_broadcast(model),
tff.federated_broadcast(learning_rate),
data]))
```
%% Cell type:markdown id: tags:
Note that in the full-featured implementation of Federated Averaging provided by
`tff.learning`, rather than averaging the models, we prefer to average model
deltas, for a number of reasons, e.g., the ability to clip the update norms,
for compression, etc.
Let's see whether the training works by running a few rounds of training and
comparing the average loss before and after.
%% Cell type:code id: tags:
```
model = initial_model
learning_rate = 0.1
for round_num in range(5):
model = federated_train(model, learning_rate, federated_train_data)
learning_rate = learning_rate * 0.9
loss = federated_eval(model, federated_train_data)
print('round {}, loss={}'.format(round_num, loss))
```
%% Output
round 0, loss=21.6055240631
round 1, loss=20.3656787872
round 2, loss=19.2748012543
round 3, loss=18.3111095428
round 4, loss=17.4572544098
%% Cell type:markdown id: tags:
For completeness, let's now also run on the test data to confirm that our model
generalizes well.
%% Cell type:code id: tags:
```
print('initial_model test loss =', federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))
```
%% Output
initial_model test loss = 22.795593
trained_model test loss = 17.278767
%% Cell type:markdown id: tags:
This concludes our tutorial.
Of course, our simplified example doesn't reflect a number of things you'd need
to do in a more realistic scenario - for example, we haven't computed metrics
other than loss. We encourage you to study
[the implementation](https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/federated_averaging.py)
of federated averaging in `tff.learning` as a more complete example, and as a
way to demonstrate some of the coding practices we'd like to encourage.
......
%% Cell type:markdown id: tags:
##### Copyright 2019 The TensorFlow Authors.
%% Cell type:code id: tags:
```
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
%% Cell type:markdown id: tags:
# Federated Learning for Text Generation
%% Cell type:markdown id: tags:
<table class="tfo-notebook-buttons" align="left">
<td>
<a target="_blank" href="https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
</td>
<td>
<a target="_blank" href="https://colab.research.google.com/github/tensorflow/federated/blob/v0.8.0/docs/tutorials/federated_learning_for_text_generation.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
</td>
<td>
<a target="_blank" href="https://github.com/tensorflow/federated/blob/v0.8.0/docs/tutorials/federated_learning_for_text_generation.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
</td>
</table>
%% Cell type:markdown id: tags:
**NOTE**: This colab has been verified to work with the `0.7.0` version of the `tensorflow_federated` pip package, but the Tensorflow Federated project is still in pre-release development and may not work on `master`.
**NOTE**: This colab has been verified to work with the `0.8.0` version of the `tensorflow_federated` pip package, but the Tensorflow Federated project is still in pre-release development and may not work on `master`.
This tutorial builds on the concepts in the [Federated Learning for Image Classification](federated_learning_for_image_classification.md) tutorial, and demonstrates several other useful approaches for federated learning.
In particular, we load a previously trained Keras model, and refine it using federated training on a (simulated) decentralized dataset. This is practically important for several reasons . The ability to use serialized models makes it easy to mix federated learning with other ML approaches. Further, this allows use of an increasing range of pre-trained models --- for example, training language models from scratch is rarely necessary, as numerous pre-trained models are now widely available (see, e.g., [TF Hub](https://www.tensorflow.org/hub)). Instead, it makes more sense to start from a pre-trained model, and refine it using Federated Learning, adapting to the particular characteristics of the decentralized data for a particular application.
For this tutorial, we start with a RNN that generates ASCII characters, and refine it via federated learning. We also show how the final weights can be fed back to the original Keras model, allowing easy evaluation and text generation using standard tools.
%% Cell type:code id: tags:
```
# NOTE: If you are running a Jupyter notebook, and installing a locally built
# pip package, you may need to edit the following to point to the '.whl' file
# on your local filesystem.
!pip install --quiet tensorflow_federated
!pip install --quiet tf-nightly
# NOTE: Jupyter requires a patch to asyncio.
!pip install --upgrade nest_asyncio
import nest_asyncio
nest_asyncio.apply()
```
%% Cell type:code id: tags:
```
from __future__ import absolute_import, division, print_function
import collections
import functools
import os
import six
import time
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
tf.compat.v1.enable_v2_behavior()
np.random.seed(0)
# Test the TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()
```
%% Output
'Hello, World!'
%% Cell type:markdown id: tags:
# Load a pre-trained model
We load a model that was pre-trained following the TensorFlow tutorial
[Text generation using a RNN with eager execution](https://www.tensorflow.org/tutorials/sequences/text_generation). However,
rather than training on [The Complete Works of Shakespeare](http://www.gutenberg.org/files/100/100-0.txt), we pre-trained the model on the text from the Charles Dickens'
[A Tale of Two Cities](http://www.ibiblio.org/pub/docs/books/gutenberg/9/98/98.txt)
and
[A Christmas Carol](http://www.ibiblio.org/pub/docs/books/gutenberg/4/46/46.txt).
Other than expanding the vocabularly, we didn't modify the original tutorial, so this initial model isn't state-of-the-art, but it produces reasonable predictions and is sufficient for our tutorial purposes. The final model was saved with `tf.keras.models.save_model(include_optimizer=False)`.
We will use federated learning to fine-tune this model for Shakespeare in this tutorial, using a federated version of the data provided by TFF.
%% Cell type:markdown id: tags:
## Generate the vocab lookup tables
%% Cell type:code id: tags:
```
# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:
vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')
# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
```
%% Cell type:markdown id: tags:
## Load the pre-trained model and generate some text
%% Cell type:code id: tags:
```
def load_model(batch_size):
urls = {
1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',
8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}
assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())
url = urls[batch_size]
local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)
return tf.keras.models.load_model(local_file, compile=False)
```
%% Cell type:code id: tags:
```
def generate_text(model, start_string):
# From https://www.tensorflow.org/tutorials/sequences/text_generation
num_generate = 200
input_eval = [char2idx[s] for s in start_string]
input_eval = tf.expand_dims(input_eval, 0)
text_generated = []
temperature = 1.0
model.reset_states()
for i in range(num_generate):
predictions = model(input_eval)
predictions = tf.squeeze(predictions, 0)
predictions = predictions / temperature
predicted_id = tf.multinomial(predictions, num_samples=1)[-1,0].numpy()
predicted_id = tf.random.categorical(
predictions, num_samples=1)[-1, 0].numpy()
input_eval = tf.expand_dims([predicted_id], 0)
text_generated.append(idx2char[predicted_id])
return (start_string + ''.join(text_generated))
```
%% Cell type:code id: tags:
```
# Text generation requires a batch_size=1 model.
keras_model_batch1 = load_model(batch_size=1)
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
```
%% Output
What of TensorFlow Federated, you ask? Say what kind then brought no
confidence in any occupied air; "I was writting out at Paris.
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
What of TensorFlow Federated, you ask? Stryver, seemed, unaternight,
Fruncied eyebrows at his forgery and the rest of its contempt.
"I arrived at Mr. Lorry he had compended till
your hand miss, to such paper to death himself against th
Mr. Cruncher had made his opportunity outsidere Stryver, even had some
fardens--natured impossible wor
%% Cell type:markdown id: tags:
# Load and Preprocess the Federated Shakespeare Data
The `tff.simulation.datasets` package provides a variety of datasets that are split into "clients", where each client corresponds to a dataset on a particular device that might participate in federated learning.
These datasets provide realistic non-IID data distributions that replicate in simulation the challenges of training on real decentralized data. Some of the pre-processing of this data was done using tools from the [Leaf project](https://arxiv.org/abs/1812.01097) ([github](https://github.com/TalwalkarLab/leaf)).
%% Cell type:code id: tags:
```
train_data, test_data = tff.simulation.datasets.shakespeare.load_data()
```
%% Cell type:markdown id: tags:
The datasets provided by `shakespeare.load_data()` consist of a sequence of
string `Tensors`, one for each line spoken by a particular character in a
Shakespeare play. The client keys consist of the name of the play joined with
the name of the character, so for example `MUCH_ADO_ABOUT_NOTHING_OTHELLO` corresponds to the lines for the character Othello in the play *Much Ado About Nothing*. Note that in a real federated learning scenario
clients are never identified or tracked by ids, but for simulation it is useful
to work with keyed datasets.
Here, for example, we can look at some data from King Lear:
%% Cell type:code id: tags:
```
# Here the play is "The Tragedy of King Lear" and the character is "King".
raw_example_dataset = train_data.create_tf_dataset_for_client(
'THE_TRAGEDY_OF_KING_LEAR_KING')
# To allow for future extensions, each entry x
# is an OrderedDict with a single key 'snippets' which contains the text.
for x in raw_example_dataset.take(2):
print(x['snippets'])
```
%% Output
tf.Tensor(b'', shape=(), dtype=string)
tf.Tensor(b'What?', shape=(), dtype=string)
tf.Tensor(, shape=(), dtype=string)
tf.Tensor(What?, shape=(), dtype=string)
%% Cell type:markdown id: tags:
We now use `tf.data.Dataset` transformations to prepare this data for training the char RNN loaded above.
%% Cell type:code id: tags:
```
# Input pre-processing parameters
SEQ_LENGTH = 100
BATCH_SIZE = 8
BUFFER_SIZE = 10000 # For dataset shuffling
```
%% Cell type:code id: tags:
```
# Using a namedtuple with keys x and y as the output type of the
# dataset keeps both TFF and Keras happy:
BatchType = collections.namedtuple('BatchType', ['x', 'y'])
# Construct a lookup table to map string chars to indexes,
# using the vocab loaded above:
table = tf.contrib.lookup.index_table_from_tensor(
mapping=vocab,
num_oov_buckets=0,
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(
keys=vocab, values=tf.constant(list(range(len(vocab))),
dtype=tf.int64)),
default_value=0)
def to_ids(x):
s = tf.reshape(x['snippets'], shape=[1])
chars = tf.string_split(s, delimiter='').values
chars = tf.strings.bytes_split(s).values
ids = table.lookup(chars)
return ids
def split_input_target(chunk):
input_text = tf.map_fn(lambda x: x[:-1], chunk)
target_text = tf.map_fn(lambda x: x[1:], chunk)
return BatchType(input_text, target_text)
def preprocess(dataset):
return (
# Map ASCII chars to int64 indexes using the vocab
dataset.map(to_ids)
# Split into individual chars
.apply(tf.data.experimental.unbatch())
# Form example sequences of SEQ_LENGTH +1
.batch(SEQ_LENGTH + 1, drop_remainder=True)
.batch(SEQ_LENGTH + 1, drop_remainder=True)
# Shuffle and form minibatches
.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
# And finally split into (input, target) tuples,
# each of length SEQ_LENGTH.
.map(split_input_target))
```
%% Cell type:markdown id: tags:
Note that in the formation of the original sequences and in the formation of
batches above, we use `drop_remainder=True` for simplicity. This means that any
characters (clients) that don't have at least `(SEQ_LENGTH + 1) * BATCH_SIZE`
chars of text will have empty datasets. A typical approach to address this would
be to pad the batches with a special token, and then mask the loss to not take
the padding tokens into account.
This would complicate the example somewhat, so for this tutorial we only use full batches, as in the
[standard tutorial](https://www.tensorflow.org/tutorials/sequences/text_generation).
However, in the federated setting this issue is more significant, because many
users might have small datasets.
Now we can preprocess our `raw_example_dataset`, and check the types:
%% Cell type:code id: tags:
```
example_dataset = preprocess(raw_example_dataset)
print(tf.compat.v1.data.get_output_types(example_dataset), tf.compat.v1.data.get_output_shapes(example_dataset))
print(tf.data.experimental.get_structure(example_dataset))
```
%% Output
BatchType(x=tf.int64, y=tf.int64) BatchType(x=TensorShape([8, 100]), y=TensorShape([8, 100]))
BatchType(x=TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), y=TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))
%% Cell type:markdown id: tags:
# Compile the model and test on the preprocessed data
%% Cell type:markdown id: tags:
We loaded an uncompiled keras model, but in order to run `keras_model.evaluate`, we need to compile it with a loss and metrics. We will also compile in an optimizer, which will be used as the on-device optimizer in Federated Learning.
%% Cell type:markdown id: tags:
The original tutorial didn't have char-level accuracy (the fraction
of predictions where the highest probability was put on the correct
next char). This is a useful metric, so we add it.
However, we need to define a new metric class for this because
our predictions have rank 3 (a vector of logits for each of the
`BATCH_SIZE * SEQ_LENGTH` predictions), and `SparseCategoricalAccuracy`
expects only rank 2 predictions.
%% Cell type:code id: tags:
```
class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
def __init__(self, name='accuracy', dtype=None):
super(FlattenedCategoricalAccuracy, self).__init__(name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.reshape(y_true, [-1, 1])
y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])
return super(FlattenedCategoricalAccuracy, self).update_state(
y_true, y_pred, sample_weight)
```
%% Cell type:code id: tags:
```
def compile(keras_model):
keras_model.compile(
optimizer=tf.keras.optimizers.SGD(lr=0.5),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[FlattenedCategoricalAccuracy()])
return keras_model
```
%% Cell type:markdown id: tags:
Now we can compile a model, and evaluate it on our `example_dataset`.
%% Cell type:code id: tags:
```
BATCH_SIZE = 8 # The training and eval batch size for the rest of this tutorial.
keras_model = load_model(batch_size=BATCH_SIZE)
compile(keras_model)
# Confirm that loss is much lower on Shakespeare than on random data
print('Evaluating on an example Shakespeare character:')
keras_model.evaluate(example_dataset.take(1))
# As a sanity check, we can construct some completely random data, where we expect
# the accuracy to be essentially random:
random_indexes = np.random.randint(
low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))
data = {
'snippets':
tf.constant(''.join(np.array(vocab)[random_indexes]), shape=[1, 1])
}
random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))
print('Expected accuracy for random guessing: {:.3f}'.format(1.0 / len(vocab)))
print('\nExpected accuracy for random guessing: {:.3f}'.format(1.0 / len(vocab)))
print('Evaluating on completely random data:')
keras_model.evaluate(random_dataset, steps=1)
```
%% Output
Evaluating on an example Shakespeare character:
1/1 [==============================] - 4s 4s/step - loss: 3.3485 - accuracy: 0.3900
1/Unknown - 1s 1s/step - loss: 2.9296 - accuracy: 0.4412
Expected accuracy for random guessing: 0.012
Evaluating on completely random data:
1/1 [==============================] - 0s 210ms/step - loss: 11.4624 - accuracy: 0.0113
1/1 [==============================] - 0s 426ms/step - loss: 11.4698 - accuracy: 0.0125
[11.462431907653809, 0.01125]
[11.4698486328125, 0.0125]
%% Cell type:markdown id: tags:
# Fine-tune the model with Federated Learning
%% Cell type:markdown id: tags:
TFF serializes all TensorFlow computations so they can potentially be run in a
non-Python environment (even though at the moment, only a simulation runtime implemented in Python is available). Even though we are running in eager mode, (TF 2.0), currently TFF serializes TensorFlow computations by constructing the
necessary ops inside the context of a "`with tf.Graph.as_default()`" statement.
Thus, we need to provide a function that TFF can use to introduce our model into
a graph it controls. We do this as follows:
%% Cell type:code id: tags:
```
# Clone the keras_model inside `create_tff_model()`, which TFF will
# call to produce a new copy of the model inside the graph that it will serialize.
def create_tff_model():
# TFF uses a `dummy_batch` so it knows the types and shapes
# that your model expects.
x = tf.constant(np.random.randint(1, len(vocab), size=[BATCH_SIZE, SEQ_LENGTH]))
dummy_batch = collections.OrderedDict([('x', x), ('y', x)])
keras_model_clone = compile(tf.keras.models.clone_model(keras_model))
return tff.learning.from_compiled_keras_model(
keras_model_clone, dummy_batch=dummy_batch)
```
%% Cell type:markdown id: tags:
Now we are ready to construct a Federated Averaging iterative process, which we will use to improve the model (for details on the Federated Averaging algorithm, see the paper [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629)).
We use a compiled Keras model to perform standard (non-federated) evaluation after each round of federated training. This is useful for research purposes when doing simulated federated learning and there is a standard test dataset.
In a realistic production setting this same technique might be used to take models trained with federated learning and evaluate them on a centralized benchmark dataset for testing or quality assurance purposes.
%% Cell type:code id: tags:
```
# This command builds all the TensorFlow graphs and serializes them:
fed_avg = tff.learning.build_federated_averaging_process(model_fn=create_tff_model)
```
%% Cell type:markdown id: tags:
Here is the simplest possible loop, where we run federated averaging for one round on a single client on a single batch:
%% Cell type:code id: tags:
```
state = fed_avg.initialize()
state, metrics = fed_avg.next(state, [example_dataset.take(1)])
print(metrics)
```
%% Output
<accuracy=0.013749999925494194,loss=4.454826354980469>
<accuracy=0.016249999404,loss=4.45304775238>
%% Cell type:markdown id: tags:
Now let's write a slightly more interesting training and evaluation loop.
So that this simulation still runs relatively quickly, we train on the same three clients each round, only considering two minibatches for each.
%% Cell type:code id: tags:
```
def data(client, source=train_data):
return preprocess(
source.create_tf_dataset_for_client(client)).take(2)
clients = ['ALL_S_WELL_THAT_ENDS_WELL_CELIA',
'MUCH_ADO_ABOUT_NOTHING_OTHELLO',
'THE_TRAGEDY_OF_KING_LEAR_KING']
train_datasets = [data(client) for client in clients]
# We concatenate the test datasets for evaluation with Keras.
test_dataset = functools.reduce(
lambda d1, d2: d1.concatenate(d2),
[data(client, test_data) for client in clients])
# NOTE: If the statement below fails, it means that you are
# using an older version of TFF without the high-performance
# executor stack. Call `tff.framework.set_default_executor()`
# instead to use the default reference runtime.
if six.PY3:
tff.framework.set_default_executor(tff.framework.create_local_executor())
```
%% Cell type:markdown id: tags:
The initial state of the model produced by `fed_avg.initialize()` is based
on the random initializers for the Keras model, not the weights that were loaded,
since `clone_model()` does not clone the weights. To start training
from a pre-trained model, we set the model weights in the server state
directly from the loaded model.
%% Cell type:code id: tags:
```
NUM_ROUNDS = 3
# The state of the FL server, containing the model and optimization state.
state = fed_avg.initialize()
state = tff.learning.state_with_new_model_weights(
state,
trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
non_trainable_weights=[
v.numpy() for v in keras_model.non_trainable_weights
])
def keras_evaluate(state, round_num):
tff.learning.assign_weights_to_keras_model(keras_model, state.model)
print('Evaluating before training round', round_num)
keras_model.evaluate(example_dataset, steps=2)
for round_num in range(NUM_ROUNDS):
keras_evaluate(state, round_num)
# N.B. The TFF runtime is currently fairly slow,
# expect this to get significantly faster in future releases.
state, metrics = fed_avg.next(state, train_datasets)
print('Training metrics: ', metrics)
keras_evaluate(state, NUM_ROUNDS + 1)
```
%% Output
Evaluating before training round 0
2/2 [==============================] - 1s 452ms/step - loss: 3.2266 - accuracy: 0.4081
Training metrics: <accuracy=0.4099999964237213,loss=3.2493884563446045>
2/2 [==============================] - 1s 649ms/step - loss: 3.2126 - accuracy: 0.4288
Training metrics: <accuracy=0.413541674614,loss=3.30218100548>
Evaluating before training round 1
2/2 [==============================] - 1s 442ms/step - loss: 2.9462 - accuracy: 0.4331
Training metrics: <accuracy=0.42916667461395264,loss=2.899351119995117>
2/2 [==============================] - 1s 561ms/step - loss: 2.8153 - accuracy: 0.4588
Training metrics: <accuracy=0.434791654348,loss=2.98462796211>
Evaluating before training round 2
2/2 [==============================] - 1s 449ms/step - loss: 2.7754 - accuracy: 0.4500
Training metrics: <accuracy=0.47083333134651184,loss=2.5915045738220215>
2/2 [==============================] - 1s 601ms/step - loss: 2.7724 - accuracy: 0.4575
Training metrics: <accuracy=0.437916666269,loss=2.87845301628>
Evaluating before training round 4
2/2 [==============================] - 1s 435ms/step - loss: 2.5748 - accuracy: 0.4837
2/2 [==============================] - 1s 645ms/step - loss: 2.7769 - accuracy: 0.4444
%% Cell type:markdown id: tags:
With the default changes, we haven't done enough training to make a big difference, but if you train longer on more Shakespeare data, you should see a difference in the style of the text generated with the updated model:
%% Cell type:code id: tags:
```
keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])
# Text generation requires batch_size=1
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
```
%% Output
What of TensorFlow Federated, you ask? Saying Doctor Manette's house but deny.
What of TensorFlow Federated, you ask? Says it with the knitting.
g of a spy, kept from such hopes to a passenger, whom, one. Piploing And my will be a moment!" said Monsieur Defarge.
There were more than twice, went them, and clieding
%% Cell type:markdown id: tags:
# Suggested extensions
This tutorial is just the first step! Here are some ideas for how you might try extending this notebook:
* Write a more realistic training loop where you sample clients to train on randomly.
* Use "`.repeat(NUM_EPOCHS)`" on the client datasets to try multiple epochs of local training (e.g., as in [McMahan et. al.](https://arxiv.org/abs/1602.05629)). See also [Federated Learning for Image Classification](federated_learning_for_image_classification.md) which does this.
* Change the `compile()` command to experiment with using different optimization algorithms on the client.
* Try the `server_optimizer` argument to `build_federated_averaging_process` to try different algorithms for applying the model updates on the server.
* Try the `client_weight_fn` argument to to `build_federated_averaging_process` to try different weightings of the clients. The default weights client updates by the number of examples on the client, but you can do e.g. `client_weight_fn=lambda _: tf.constant(1.0)`.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment