Commit 9139a86a authored by Jakub Konecny, committed by tensorflow-copybara

Updates non-federated EMNIST-10 and EMNIST-62 baselines.

This also updates the default model to one that is smaller and less prone to overfitting.

PiperOrigin-RevId: 270204256
parent d722e380
@@ -20,3 +20,25 @@ py_binary(
"//tensorflow_federated/python/research/utils:utils_impl",
],
)
py_binary(
name = "non_federated_emnist_10",
srcs = ["non_federated_emnist_10.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":models",
"//tensorflow_federated",
],
)
py_binary(
name = "non_federated_emnist_62",
srcs = ["non_federated_emnist_62.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":models",
"//tensorflow_federated",
],
)
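These targets make the new scripts runnable with Bazel; given the package path implied by the imports in the scripts below, the invocation would be `bazel run //tensorflow_federated/python/research/baselines/emnist:non_federated_emnist_10` (and likewise for `non_federated_emnist_62`).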
@@ -4,14 +4,17 @@ Note: This directory is a work-in-progress.
## Overview
This directory contains multiple model architectures and experiment scripts for
training baseline federated and non-federated models on the
[Federated EMNIST](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/emnist/load_data)
dataset.
NOTE: The model architecture from the paper
[Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629)
is exactly reproduced (`models.create_original_fedavg_cnn_model`). However, since we
use the Federated EMNIST dataset (with its natural user partitioning), rather
than the MNIST partitioning (either a synthetic non-IID or shuffled IID) from
the original paper, the results are not directly comparable.
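
As a quick illustration of that natural partitioning, the following minimal
sketch uses the same loader as the experiment scripts in this directory
(each client id corresponds to one writer):

```
import tensorflow_federated as tff

# Federated EMNIST ships pre-partitioned by writer.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

# Number of distinct writers in the training split.
print(len(emnist_train.client_ids))

# Materialize a single writer's examples as a tf.data.Dataset.
dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
for element in dataset.take(1):
  print(element['pixels'].shape, element['label'])  # (28, 28) and a scalar
```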
## Citation
@@ -19,10 +19,10 @@ If you use these baselines and need to cite this work, please use:
```
@misc{mcmahan19emnist_baseline,
author = {H. Brendan McMahan and Jakub Kone{\v{c}}n{\'y}},
title = {{Federated EMNIST Baseline Training}},
year = 2019,
url = {https://github.com/tensorflow/federated/tree/master/tensorflow_federated/python/research/baselines/emnist}
}
```
......
# Lint as: python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -19,16 +19,109 @@ import functools
import tensorflow as tf
def create_conv_dropout_model(only_digits=True):
"""Recommended model to use for EMNIST experiments.
When `only_digits=True`, the summary of the returned model is
```
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
reshape (Reshape) (None, 28, 28, 1) 0
_________________________________________________________________
conv2d (Conv2D) (None, 26, 26, 32) 320
_________________________________________________________________
conv2d_1 (Conv2D) (None, 24, 24, 64) 18496
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 12, 12, 64) 0
_________________________________________________________________
dropout (Dropout) (None, 12, 12, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 9216) 0
_________________________________________________________________
dense (Dense) (None, 128) 1179776
_________________________________________________________________
dropout_1 (Dropout) (None, 128) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 1290
=================================================================
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
```
For `only_digits=False`, the last dense layer is slightly larger.
Args:
only_digits: If True, uses a final layer with 10 outputs, for use with the
digits only EMNIST dataset. If False, uses 62 outputs for the larger
dataset.
Returns:
A `tf.keras.Model`.
"""
data_format = 'channels_last'
input_shape = [28, 28, 1]
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(
32,
kernel_size=(3, 3),
activation='relu',
input_shape=input_shape,
data_format=data_format),
tf.keras.layers.Conv2D(
64, kernel_size=(3, 3), activation='relu', data_format=data_format),
tf.keras.layers.MaxPool2D(pool_size=(2, 2), data_format=data_format),
tf.keras.layers.Dropout(0.25),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(
10 if only_digits else 62, activation=tf.nn.softmax),
])
return model
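# A minimal usage sketch: both constructors in this module return an
# uncompiled `tf.keras.Model`, so attach a loss and optimizer before training,
# e.g.
#
#   model = create_conv_dropout_model(only_digits=True)
#   model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
#   model.summary()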
def create_original_fedavg_cnn_model(only_digits=True):
"""The CNN model used in https://arxiv.org/abs/1602.05629.
The number of parameters when `only_digits=True` is 1,663,370, which matches
what is reported in the paper.
When `only_digits=True`, the summary of the returned model is
```
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
reshape (Reshape) (None, 28, 28, 1) 0
_________________________________________________________________
conv2d (Conv2D) (None, 28, 28, 32) 832
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 14, 14, 64) 51264
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 3136) 0
_________________________________________________________________
dense (Dense) (None, 512) 1606144
_________________________________________________________________
dense_1 (Dense) (None, 10) 5130
=================================================================
Total params: 1,663,370
Trainable params: 1,663,370
Non-trainable params: 0
```
For `only_digits=False`, the last dense layer is slightly larger.
Args:
only_digits: If True, uses a final layer with 10 outputs, for use with the
digits only EMNIST dataset. If False, uses 62 outputs for the larger
dataset.
Returns:
A `tf.keras.Model`.
@@ -49,8 +142,7 @@ def create_keras_model(only_digits=True):
activation=tf.nn.relu)
model = tf.keras.models.Sequential([
conv2d(filters=32, input_shape=input_shape),
max_pool(),
conv2d(filters=64),
max_pool(),
......
# Lint as: python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Baseline experiment on centralized data.
The objective is to demonstrate what is possible if the EMNIST-10 data were
available in a central location, where all common optimization techniques apply.
In practice, federated learning is typically applied in cases where the
on-device data cannot be centralized, and so this ideal goal need not be
realized for FL to be effective. However, when studying optimization
algorithms, this centralized baseline is a useful point of comparison.
In an example run, the model achieved:
* 98.93% accuracy after 2 passes through data.
* 99.31% accuracy after 10 passes through data.
* 99.45% accuracy after 25 passes through data.
"""
import collections
from absl import app
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.research.baselines.emnist import models
BATCH_SIZE = 100
# The total number of examples in
# emnist_train.create_tf_dataset_from_all_clients()
TOTAL_EXAMPLES = 341873
def run_experiment():
"""Runs the training experiment."""
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
only_digits=True)
example_tuple = collections.namedtuple('Example', ['x', 'y'])
def element_fn(element):
return example_tuple(
# The expand_dims adds a channel dimension.
x=tf.expand_dims(element['pixels'], -1),
y=element['label'])
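# A shuffle buffer as large as the full dataset gives a uniform shuffle each
# epoch, at the cost of holding all examples in memory.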
all_train = emnist_train.create_tf_dataset_from_all_clients().map(element_fn)
all_train = all_train.shuffle(TOTAL_EXAMPLES).repeat().batch(BATCH_SIZE)
all_test = emnist_test.create_tf_dataset_from_all_clients().map(element_fn)
all_test = all_test.batch(BATCH_SIZE)
train_data_elements = int(TOTAL_EXAMPLES / BATCH_SIZE)
model = models.create_conv_dropout_model(only_digits=True)
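# Keras SGD's `decay` implements a time-based schedule,
#   lr_t = lr / (1 + decay * t),
# with t counted in optimizer steps. Here decay = 0.2 / steps_per_epoch, so
# the learning rate falls to lr / 1.2 after one epoch and lr / 6 after 25.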
model.compile(
loss=tf.keras.losses.sparse_categorical_crossentropy,
optimizer=tf.keras.optimizers.SGD(
learning_rate=0.01,
momentum=0.9,
decay=0.2 / train_data_elements,
nesterov=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
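# steps_per_epoch equals examples / batch_size, so each Keras epoch below is
# one full pass over the (shuffled, repeated) training set.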
model.fit_generator(
all_train,
steps_per_epoch=train_data_elements,
epochs=25,
verbose=1,
validation_data=all_test)
score = model.evaluate_generator(all_test, verbose=0)
print('Final test loss: %.4f' % score[0])
print('Final test accuracy: %.4f' % score[1])
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
tf.compat.v1.enable_v2_behavior()
run_experiment()
if __name__ == '__main__':
app.run(main)
# Lint as: python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Baseline experiment on centralized data.
The objective is to demonstrate what is possible if the EMNIST-62 data were
available in a central location, where all common optimization techniques apply.
In practice, federated learning is typically applied in cases where the
on-device data cannot be centralized, and so this ideal goal need not be
realized for FL to be effective. However, when studying optimization
algorithms, this centralized baseline is a useful point of comparison.
In an example run, the model achieved:
* 85.33% accuracy after 2 passes through data.
* 87.20% accuracy after 10 passes through data.
* 87.83% accuracy after 25 passes through data.
* 88.22% accuracy after 75 passes through data.
"""
import collections
from absl import app
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.research.baselines.emnist import models
BATCH_SIZE = 100
# The total number of examples in
# emnist_train.create_tf_dataset_from_all_clients()
TOTAL_EXAMPLES = 671585
def run_experiment():
"""Runs the training experiment."""
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
only_digits=False)
example_tuple = collections.namedtuple('Example', ['x', 'y'])
def element_fn(element):
return example_tuple(
# The expand_dims adds a channel dimension.
x=tf.expand_dims(element['pixels'], -1),
y=element['label'])
all_train = emnist_train.create_tf_dataset_from_all_clients().map(element_fn)
all_train = all_train.shuffle(TOTAL_EXAMPLES).repeat().batch(BATCH_SIZE)
all_test = emnist_test.create_tf_dataset_from_all_clients().map(element_fn)
all_test = all_test.batch(BATCH_SIZE)
train_data_elements = int(TOTAL_EXAMPLES / BATCH_SIZE)
model = models.create_conv_dropout_model(only_digits=False)
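# Same time-based SGD decay schedule as in the EMNIST-10 script: the learning
# rate decays by a factor of (1 + 0.2 * epochs_elapsed).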
model.compile(
loss=tf.keras.losses.sparse_categorical_crossentropy,
optimizer=tf.keras.optimizers.SGD(
learning_rate=0.01,
momentum=0.9,
decay=0.2 / train_data_elements,
nesterov=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit_generator(
all_train,
steps_per_epoch=train_data_elements,
epochs=75,
verbose=1,
validation_data=all_test)
score = model.evaluate_generator(all_test, verbose=0)
print('Final test loss: %.4f' % score[0])
print('Final test accuracy: %.4f' % score[1])
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
tf.compat.v1.enable_v2_behavior()
run_experiment()
if __name__ == '__main__':
app.run(main)