"Keras is the [recommended high-level model API for TensorFlow](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a), and we encourage using Keras models (via \n",
"`tff.learning.from_keras_model` or\n",
"`tff.learning.from_compiled_keras_model`) in TFF whenever possible.\n",
"`tff.learning.from_keras_model`) in TFF whenever possible.\n",
"\n",
"However, `tff.learning` provides a lower-level model interface, `tff.learning.Model`, that exposes the minimal functionality necessary for using a model for federated learning. Directly implementing this interface (possibly still using building blocks like `tf.keras.layers`) allows for maximum customization without modifying the internals of the federated learning algorithms.\n",
"\n",
...
...
%% Cell type:markdown id: tags:
##### Copyright 2019 The TensorFlow Authors.
%% Cell type:code id: tags:
```
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
%% Cell type:markdown id: tags:
# Federated Learning for Image Classification
%% Cell type:markdown id: tags:
<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/federated/blob/v0.14.0/docs/tutorials/federated_learning_for_image_classification.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/federated/blob/v0.14.0/docs/tutorials/federated_learning_for_image_classification.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>
%% Cell type:markdown id: tags:
**NOTE**: This colab has been verified to work with the [latest released version](https://github.com/tensorflow/federated#compatibility) of the `tensorflow_federated` pip package, but the TensorFlow Federated project is still in pre-release development and may not work on `master`.
In this tutorial, we use the classic MNIST training example to introduce the
Federated Learning (FL) API layer of TFF, `tff.learning` - a set of
higher-level interfaces that can be used to perform common types of federated
learning tasks, such as federated training, against user-supplied models
implemented in TensorFlow.
This tutorial, and the Federated Learning API, are intended primarily for users
who want to plug their own TensorFlow models into TFF, treating the latter
mostly as a black box. For a more in-depth understanding of TFF and how to
implement your own federated learning algorithms, see the tutorials on the FC Core API - [Custom Federated Algorithms Part 1](custom_federated_algorithms_1.ipynb) and [Part 2](custom_federated_algorithms_2.ipynb).
For more on `tff.learning`, continue with the
[Federated Learning for Text Generation](federated_learning_for_text_generation.ipynb)
tutorial, which, in addition to covering recurrent models, demonstrates loading a
pre-trained serialized Keras model for refinement with federated learning,
combined with evaluation using Keras.
%% Cell type:markdown id: tags:
## Before we start
Before we start, please run the following to make sure that your environment is
correctly set up. If you don't see a greeting, please refer to the
[Installation](../install.md) guide for instructions.
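A minimal version of that setup looks like the cell below (the released notebook may pin different package versions):
%% Cell type:code id: tags:
```
#@test {"skip": true}
!pip install --quiet --upgrade tensorflow_federated

import collections

import tensorflow as tf
import tensorflow_federated as tff

# If the environment is set up correctly, this prints a greeting.
tff.federated_computation(lambda: 'Hello, World!')()
```
%% Cell type:markdown id: tags: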
In order to facilitate experimentation, we seeded the TFF repository with a few
datasets, including a federated version of MNIST that contains a version of the [original NIST dataset](https://www.nist.gov/srd/nist-special-database-19) that has been re-processed using [Leaf](https://github.com/TalwalkarLab/leaf) so that the data is keyed by the original writer of the digits. Since each writer has a unique style, this dataset exhibits the kind of non-i.i.d. behavior expected of federated datasets.
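To get a feel for this data, here's a sketch that loads it with the `tff.simulation.datasets.emnist` loader and inspects a single writer's examples:
%% Cell type:code id: tags:
```
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Each client id corresponds to one writer from the original NIST data.
print(len(emnist_train.client_ids))

# A tf.data.Dataset holding a single writer's examples.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
print(next(iter(example_dataset)))
```
%% Cell type:markdown id: tags: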
In particular, one should think about `next()` not as being a function that runs on a server, but rather being a declarative functional representation of the entire decentralized computation - some of the inputs are provided by the server (`SERVER_STATE`), but each participating device contributes its own local dataset.
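One way to make this concrete is to print the type signature of `next`. A sketch, assuming a `model_fn` defined as earlier in the notebook (the optimizer choices here are illustrative, not prescriptive):
%% Cell type:code id: tags:
```
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

# The signature spells out placements (SERVER/CLIENTS) for every input and
# output of the round, rather than describing code that runs in one place.
print(str(iterative_process.next.type_signature))
```
%% Cell type:markdown id: tags: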
Let's run a single round of training and visualize the results. We can use the
federated data we've already generated above for a sample of users.
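A sketch of a single round, assuming `federated_train_data` holds the client datasets sampled above:
%% Cell type:code id: tags:
```
# Initialize the server state, then run one round over the sampled users.
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
```
%% Cell type:markdown id: tags: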
Start TensorBoard with the root log directory specified above. It can take a few seconds for the data to load.
%% Cell type:code id: tags:
```
#@test {"skip": true}
%tensorboard --logdir /tmp/logs/scalars/ --port=0
```
%% Cell type:code id: tags:
```
#@test {"skip": true}
# Run this cell to clean the directory of old output before generating new graphs.
!rm -R /tmp/logs/scalars/*
```
%% Cell type:markdown id: tags:
To view evaluation metrics the same way, you can write them to a separate eval folder, like "logs/scalars/eval", and point TensorBoard at it, as sketched below.
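A sketch, assuming an `evaluation` computation built with `tff.learning.build_federated_evaluation`, plus `state`, `federated_test_data`, and `round_num` from earlier cells:
%% Cell type:code id: tags:
```
# All names here (`evaluation`, `state`, `federated_test_data`, `round_num`)
# are assumed to exist from earlier cells.
eval_writer = tf.summary.create_file_writer('/tmp/logs/scalars/eval')
eval_metrics = evaluation(state.model, federated_test_data)
with eval_writer.as_default():
  for name, value in eval_metrics._asdict().items():
    # `_asdict()` assumes a named-tuple-like metrics structure; adjust if
    # your TFF version returns a plain OrderedDict instead.
    tf.summary.scalar(name, value, step=round_num)
```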
%% Cell type:markdown id: tags:
## Customizing the model implementation
Keras is the [recommended high-level model API for TensorFlow](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a), and we encourage using Keras models (via
`tff.learning.from_keras_model` or
`tff.learning.from_compiled_keras_model`) in TFF whenever possible.
However, `tff.learning` provides a lower-level model interface, `tff.learning.Model`, that exposes the minimal functionality necessary for using a model for federated learning. Directly implementing this interface (possibly still using building blocks like `tf.keras.layers`) allows for maximum customization without modifying the internals of the federated learning algorithms.
So let's do it all over again from scratch.
### Defining model variables, forward pass, and metrics
The first step is to identify the TensorFlow variables we're going to work with.
In order to make the following code more legible, let's define a data structure
to represent the entire set. This will include variables such as `weights` and
`bias` that we will train, as well as variables that will hold various
cumulative statistics and counters we will update during training, such as `loss_sum`, `accuracy_sum`, and `num_examples`.
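A sketch of that structure, close to the one in the full notebook (the zero initializers and the 784x10 linear-model shapes are concrete choices made here):
%% Cell type:code id: tags:
```
MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

def create_mnist_variables():
  return MnistVariables(
      # Trainable parameters of a simple linear classifier.
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      # Local counters updated during training, never trained directly.
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))
```
%% Cell type:markdown id: tags: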
Next, we define a function that returns a set of local metrics, again using TensorFlow. These are the values (in addition to model updates, which are handled automatically) that are eligible to be aggregated to the server in a federated learning or evaluation process.
Here, we simply return the average `loss` and `accuracy`, as well as the
`num_examples`, which we'll need to correctly weight the contributions from
different users when computing federated aggregates.
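A sketch of that function, assuming the `MnistVariables` structure above:
%% Cell type:code id: tags:
```
def get_local_mnist_metrics(variables):
  # Averages are computed locally; `num_examples` lets the server weight
  # each client's contribution during aggregation.
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)
```
%% Cell type:markdown id: tags: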
Finally, we need to determine how to aggregate the local metrics emitted by each
device via `get_local_mnist_metrics`. This is the only part of the code that isn't written in TensorFlow - it's a *federated computation* expressed in TFF. If you'd like to
dig deeper, skim over the [custom algorithms](custom_federated_algorithms_1.ipynb)
tutorial, but in most applications, you won't really need to; variants of the
pattern shown below should suffice. Here's a sketch of what it looks like:
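%% Cell type:code id: tags:
```
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  # A sketch of the aggregation described above: sum the example counts and
  # compute example-weighted means of loss and accuracy across clients.
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
```
%% Cell type:markdown id: tags: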
The input `metrics` argument corresponds to the `OrderedDict` returned by `get_local_mnist_metrics` above, but critically the values are no longer `tf.Tensors` - they are "boxed" as `tff.Value`s, to make it clear you can no longer manipulate them using TensorFlow, but only using TFF's federated operators like `tff.federated_mean` and `tff.federated_sum`. The returned
dictionary of global aggregates defines the set of metrics which will be available on the server.
%% Cell type:markdown id: tags:
### Constructing an instance of `tff.learning.Model`
With all of the above in place, we are ready to construct a model representation
for use with TFF, similar to the one that's generated for you when you let TFF ingest a Keras model.
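A condensed sketch of such a class, assuming the pieces defined above plus a `mnist_forward_pass(variables, batch)` helper (defined in the full notebook) that computes the loss and predictions and updates the local counters:
%% Cell type:code id: tags:
```
class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training  # The forward pass is the same for training and evaluation.
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_examples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_examples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients
```
%% Cell type:markdown id: tags:
This class can then be passed wherever a `model_fn` is expected, e.g. `tff.learning.build_federated_averaging_process(MnistModel, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))`.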
This concludes the tutorial. We encourage you to play with the
parameters (e.g., batch sizes, number of users, epochs, learning rates, etc.), to modify the code above to simulate training on random samples of users in
each round, and to explore the other tutorials we've developed.