"Since the data is already a `tf.data.Dataset`, preprocessing can be accomplished using Dataset transformations. Here, we flatten the `28x28` images\n",
"into `784`-element arrays, shuffle the individual examples, organize them into batches, and renames the features\n",
"into `784`-element arrays, shuffle the individual examples, organize them into batches, and rename the features\n",
"from `pixels` and `label` to `x` and `y` for use with Keras. We also throw in a\n",
"`repeat` over the data set to run several epochs."
]
...
...
%% Cell type:markdown id: tags:
##### Copyright 2019 The TensorFlow Authors.
%% Cell type:code id: tags:
```
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
%% Cell type:markdown id: tags:
# Federated Learning for Image Classification
%% Cell type:markdown id: tags:
<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/federated_learning_for_image_classification.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/federated/blob/master/docs/tutorials/federated_learning_for_image_classification.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>
%% Cell type:markdown id: tags:
**NOTE**: This colab has been verified to work with the [latest released version](https://github.com/tensorflow/federated#compatibility) of the `tensorflow_federated` pip package, but the TensorFlow Federated project is still in pre-release development and may not work on `master`.
In this tutorial, we use the classic MNIST training example to introduce the
Federated Learning (FL) API layer of TFF, `tff.learning` - a set of
higher-level interfaces that can be used to perform common types of federated
learning tasks, such as federated training, against user-supplied models
implemented in TensorFlow.
This tutorial, and the Federated Learning API, are intended primarily for users
who want to plug their own TensorFlow models into TFF, treating the latter
mostly as a black box. For a more in-depth understanding of TFF and how to
implement your own federated learning algorithms, see the tutorials on the FC Core API - [Custom Federated Algorithms Part 1](custom_federated_algorithms_1.ipynb) and [Part 2](custom_federated_algorithms_2.ipynb).
For more on `tff.learning`, continue with the
[Federated Learning for Text Generation](federated_learning_for_text_generation.ipynb)
tutorial which, in addition to covering recurrent models, also demonstrates loading a
pre-trained serialized Keras model for refinement with federated learning
combined with evaluation using Keras.
%% Cell type:markdown id: tags:
## Before we start
Before we start, please run the following to make sure that your environment is
correctly set up. If you don't see a greeting, please refer to the
[Installation](../install.md) guide for instructions.
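The setup cell is not shown here; a minimal sketch along these lines, assuming the `tensorflow_federated` pip package named in the note above plus the standard imports used by the cells below, should produce the expected greeting:
%% Cell type:code id: tags:
```
#@test {"skip": true}
!pip install --quiet --upgrade tensorflow-federated

import collections

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

# If the environment is set up correctly, this prints a greeting.
tff.federated_computation(lambda: 'Hello, World!')()
```
%% Cell type:markdown id: tags: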
In order to facilitate experimentation, we seeded the TFF repository with a few
datasets, including a federated version of MNIST that contains a version of the [original NIST dataset](https://www.nist.gov/srd/nist-special-database-19) that has been re-processed using [Leaf](https://github.com/TalwalkarLab/leaf) so that the data is keyed by the original writer of the digits. Since each writer has a unique style, this dataset exhibits the kind of non-i.i.d. behavior expected of federated datasets.
Federated data is typically non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables): users typically have different distributions of data depending on usage patterns. Some clients may have fewer training examples on device, suffering from data paucity locally, while other clients will have more than enough training examples. Let's explore this concept of data heterogeneity typical of a federated system with the EMNIST data we have available. It's important to note that this deep analysis of a client's data is only available to us because this is a simulation environment where all the data is available to us locally. In a real production federated environment, you would not be able to inspect a single client's data.
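A minimal sketch of loading this data, assuming the `tff.simulation.datasets.emnist` module that ships with TFF:
%% Cell type:code id: tags:
```
# Load the federated EMNIST data; each client id corresponds to one writer.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# How many simulated clients are available in the training split.
len(emnist_train.client_ids)
```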
%% Cell type:markdown id: tags:
First, let's grab a sampling of one client's data to get a feel for the examples on one simulated device. Because the dataset we're using has been keyed by unique writer, the data of one client represents the handwriting of one person for a sample of the digits 0 through 9, simulating the unique "usage pattern" of one user.
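As a sketch (assuming `emnist_train` from the loading cell above), pulling out one client's `tf.data.Dataset` and displaying a single example might look like:
%% Cell type:code id: tags:
```
# Build a tf.data.Dataset for a single simulated client (one writer).
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

# Each element has a 28x28 `pixels` image and an integer `label`.
example_element = next(iter(example_dataset))
print(example_element['label'].numpy())

plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
plt.show()
```
%% Cell type:markdown id: tags: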
Now let's visualize the number of examples on each client for each MNIST digit label. In the federated environment, the number of examples on each client can vary quite a bit, depending on user behavior.
%% Cell type:code id: tags:
```
# Number of examples per label for a sample of clients.
# `emnist_train` is the federated EMNIST training split loaded above.
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
```
%% Output
%% Cell type:markdown id: tags:
Now let's visualize the mean image per client for each MNIST label. This code will produce the mean of each pixel value for all of the user's examples for one label. We'll see that one client's mean image for a digit will look different from another client's mean image for the same digit, due to each person's unique handwriting style. We can muse about how each local training round will nudge the model in a different direction on each client, as we're learning from that user's own unique data in that local round. Later in the tutorial we'll see how we can take each update to the model from all the clients and aggregate them together into our new global model, which has learned from each client's own unique data.
%% Cell type:code id: tags:
```
# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.
for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')
```
%% Output
%% Cell type:markdown id: tags:
User data can be noisy and unreliably labeled. For example, looking at Client #2's data above for label 2, it appears there may have been some mislabeled examples, creating a noisier mean image.
%% Cell type:markdown id: tags:
### Preprocessing the input data
%% Cell type:markdown id: tags:
Since the data is already a `tf.data.Dataset`, preprocessing can be accomplished using Dataset transformations. Here, we flatten the `28x28` images
into `784`-element arrays, shuffle the individual examples, organize them into batches, and rename the features
from `pixels` and `label` to `x` and `y` for use with Keras. We also throw in a
`repeat` over the data set to run several epochs.
%% Cell type:code id: tags:
```
NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10
def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)
```
%% Cell type:markdown id: tags:
In particular, one should think about `next()` not as a function that runs on a server, but rather as a declarative functional representation of the entire decentralized computation - some of the inputs are provided by the server (`SERVER_STATE`), but each participating device contributes its own local dataset.
Let's run a single round of training and visualize the results. We can use the
federated data we've already generated above for a sample of users.
%% Cell type:code id: tags:
```
# Inside the training loop (see the sketch below), write each training metric for TensorBoard.
for name, value in metrics.train._asdict().items():
  tf.summary.scalar(name, value, step=round_num)
```
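%% Cell type:markdown id: tags:
For context, here is a sketch of how that logging snippet might sit inside a full training loop with a TensorBoard summary writer. It assumes an `iterative_process` (e.g. built with `tff.learning.build_federated_averaging_process`) and a list `federated_train_data` of preprocessed client datasets constructed in cells not shown here; `NUM_ROUNDS` is an illustrative value.
%% Cell type:code id: tags:
```
#@test {"skip": true}
NUM_ROUNDS = 11  # Illustrative value.
logdir = "/tmp/logs/scalars/training/"  # Lives under the --logdir used below.
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics.train._asdict().items():
      tf.summary.scalar(name, value, step=round_num)
```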
%% Cell type:markdown id: tags:
Start TensorBoard with the root log directory specified above. It can take a few seconds for the data to load.
%% Cell type:code id: tags:
```
#@test {"skip": true}
%tensorboard --logdir /tmp/logs/scalars/ --port=0
```
%% Cell type:code id: tags:
```
#@test {"skip": true}
# Run this cell to clean your directory of old output before writing new graphs to this directory.
!rm -R /tmp/logs/scalars/*
```
%% Cell type:markdown id: tags:
In order to view evaluation metrics the same way, you can create a separate eval folder, like "logs/scalars/eval", to write to TensorBoard.
%% Cell type:markdown id: tags:
## Customizing the model implementation
Keras is the [recommended high-level model API for TensorFlow](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a), and we encourage using Keras models (via
`tff.learning.from_keras_model`) in TFF whenever possible.
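For reference, a Keras-based wrapper might look roughly like the following sketch; `create_keras_model` is a hypothetical helper returning an ordinary, uncompiled `tf.keras.Model`, and the input spec assumes a dataset produced by the `preprocess` function defined earlier.
%% Cell type:code id: tags:
```
# Assumes `example_dataset` (one client's data) and `preprocess` from earlier cells.
preprocessed_example_dataset = preprocess(example_dataset)

def model_fn():
  # `create_keras_model` is a hypothetical helper returning an uncompiled
  # tf.keras.Model, e.g. a small dense network over the 784 flattened pixels.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
```
%% Cell type:markdown id: tags: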
However, `tff.learning` provides a lower-level model interface, `tff.learning.Model`, that exposes the minimal functionality necessary for using a model for federated learning. Directly implementing this interface (possibly still using building blocks like `tf.keras.layers`) allows for maximum customization without modifying the internals of the federated learning algorithms.
So let's do it all over again from scratch.
### Defining model variables, forward pass, and metrics
The first step is to identify the TensorFlow variables we're going to work with.
In order to make the following code more legible, let's define a data structure
to represent the entire set. This will include variables such as `weights` and
`bias` that we will train, as well as variables that will hold various
cumulative statistics and counters we will update during training, such as `loss_sum`, `accuracy_sum`, and `num_examples`.
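A sketch of such a data structure, consistent with the metric names used in the rest of this section, might look like:
%% Cell type:code id: tags:
```
MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights', trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias', trainable=True),
      # Non-trainable counters for cumulative local statistics.
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))
```
%% Cell type:markdown id: tags: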
Next, we define a function that returns a set of local metrics, again using TensorFlow. These are the values (in addition to model updates, which are handled automatically) that are eligible to be aggregated to the server in a federated learning or evaluation process.
Here, we simply return the average `loss` and `accuracy`, as well as the
`num_examples`, which we'll need to correctly weight the contributions from
different users when computing federated aggregates.
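A sketch of what such a function could look like, given the `MnistVariables` structure above:
%% Cell type:code id: tags:
```
def get_local_mnist_metrics(variables):
  # Average the accumulated sums over the number of examples seen locally.
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)
```
%% Cell type:markdown id: tags: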
Finally, we need to determine how to aggregate the local metrics emitted by each
device via `get_local_mnist_metrics`. This is the only part of the code that isn't written in TensorFlow - it's a *federated computation* expressed in TFF. If you'd like to
dig deeper, skim over the [custom algorithms](custom_federated_algorithms_1.ipynb)
tutorial, but in most applications, you won't really need to; variants of the
pattern shown below should suffice. Here's what it looks like:
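The original cell is not included here; below is a sketch consistent with the description that follows, using `tff.federated_mean` weighted by `num_examples` and `tff.federated_sum`.
%% Cell type:code id: tags:
```
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  # `metrics` is the OrderedDict emitted by get_local_mnist_metrics on each client.
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
```
%% Cell type:markdown id: tags: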
The input `metrics` argument corresponds to the `OrderedDict` returned by `get_local_mnist_metrics` above, but critically the values are no longer `tf.Tensors` - they are "boxed" as `tff.Value`s, to make it clear you can no longer manipulate them using TensorFlow, but only using TFF's federated operators like `tff.federated_mean` and `tff.federated_sum`. The returned
dictionary of global aggregates defines the set of metrics which will be available on the server.
%% Cell type:markdown id: tags:
### Constructing an instance of `tff.learning.Model`
With all of the above in place, we are ready to construct a model representation
for use with TFF, similar to the one that's generated for you when you let TFF ingest a Keras model.
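The full implementation is not reproduced in this notebook; the following is only a sketch of what a `tff.learning.Model` subclass could look like, assuming the interface of the TFF release this tutorial targets (with `forward_pass`, `report_local_outputs`, and a `federated_output_computation` property) and reusing the helpers sketched above.
%% Cell type:code id: tags:
```
class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [self._variables.num_examples, self._variables.loss_sum,
            self._variables.accuracy_sum]

  @property
  def input_spec(self):
    # Matches the `x`/`y` structure produced by `preprocess` above.
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training  # The same computation is used for training and evaluation.
    # Simple softmax regression over the 784 flattened pixels.
    logits = tf.matmul(batch['x'], self._variables.weights) + self._variables.bias
    predictions = tf.cast(tf.argmax(logits, 1), tf.int32)
    flat_labels = tf.reshape(batch['y'], [-1])
    loss = -tf.reduce_mean(
        tf.reduce_sum(
            tf.one_hot(flat_labels, 10) * tf.nn.log_softmax(logits), axis=1))
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(predictions, flat_labels), tf.float32))
    num_examples = tf.cast(tf.size(batch['y']), tf.float32)
    # Update the cumulative statistics consumed by get_local_mnist_metrics.
    self._variables.num_examples.assign_add(num_examples)
    self._variables.loss_sum.assign_add(loss * num_examples)
    self._variables.accuracy_sum.assign_add(accuracy * num_examples)
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions,
        num_examples=tf.cast(num_examples, tf.int32))

  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients
```
%% Cell type:markdown id: tags: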
This concludes the tutorial. We encourage you to play with the
parameters (e.g., batch sizes, number of users, epochs, learning rates, etc.), to modify the code above to simulate training on random samples of users in
each round, and to explore the other tutorials we've developed.