{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tqrD7Yzlmlsk"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "executionInfo": {
          "elapsed": 5,
          "status": "ok",
          "timestamp": 1638423023724,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "2k8X1C1nmpKv"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "32xflLc4NTx-"
      },
      "source": [
        "# Custom Federated Algorithms, Part 2: Implementing Federated Averaging"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jtATV6DlqPs0"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/custom_federated_algorithms_2\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/main/docs/tutorials/custom_federated_algorithms_2.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/main/docs/tutorials/custom_federated_algorithms_2.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/custom_federated_algorithms_2.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_igJ2sfaNWS8"
      },
      "source": [
        "This tutorial is the second part of a two-part series that demonstrates how to\n",
        "implement custom types of federated algorithms in TFF using the\n",
        "[Federated Core (FC)](../federated_core.md), which serves as a foundation for\n",
        "the [Federated Learning (FL)](../federated_learning.md) layer (`tff.learning`).\n",
        "\n",
        "We encourage you to first read the\n",
        "[first part of this series](custom_federated_algorithms_1.ipynb), which\n",
        "introduce some of the key concepts and programming abstractions used here.\n",
        "\n",
        "This second part of the series uses the mechanisms introduced in the first part\n",
        "to implement a simple version of federated training and evaluation algorithms.\n",
        "\n",
        "We encourage you to review the\n",
        "[image classification](federated_learning_for_image_classification.ipynb) and\n",
        "[text generation](federated_learning_for_text_generation.ipynb) tutorials for a\n",
        "higher-level and more gentle introduction to TFF's Federated Learning APIs, as\n",
        "they will help you put the concepts we describe here in context."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cuJuLEh2TfZG"
      },
      "source": [
        "## Before we start\n",
        "\n",
        "Before we start, try to run the following \"Hello World\" example to make sure\n",
        "your environment is correctly setup. If it doesn't work, please refer to the\n",
        "[Installation](../install.md) guide for instructions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "executionInfo": {
          "elapsed": 5,
          "status": "ok",
          "timestamp": 1638423023873,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "rB1ovcX1mBxQ"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow-federated-nightly\n",
        "!pip install --quiet --upgrade nest-asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "executionInfo": {
          "elapsed": 9212,
          "status": "ok",
          "timestamp": 1638423033241,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "-skNC6aovM46"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "# Must use the Python context because it\n",
        "# supports tff.sequence_* intrinsics.\n",
        "executor_factory = tff.framework.local_executor_factory(\n",
        "    support_sequence_ops=True)\n",
        "execution_context = tff.framework.ExecutionContext(\n",
        "    executor_fn=executor_factory)\n",
        "tff.framework.set_default_context(execution_context)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "executionInfo": {
          "elapsed": 197,
          "status": "ok",
          "timestamp": 1638423033604,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "zzXwGnZamIMM",
        "outputId": "d08b735a-eafa-4263-bc8f-4922207ffc16"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "b'Hello, World!'"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "@tff.federated_computation\n",
        "def hello_world():\n",
        "  return 'Hello, World!'\n",
        "\n",
        "hello_world()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iu5Gd8D6W33s"
      },
      "source": [
        "## Implementing Federated Averaging\n",
        "\n",
        "As in\n",
        "[Federated Learning for Image Classification](federated_learning_for_image_classification.ipynb),\n",
        "we are going to use the MNIST example, but since this is intended as a low-level\n",
        "tutorial, we are going to bypass the Keras API and `tff.simulation`, write raw\n",
        "model code, and construct a federated data set from scratch.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b6qCjef350c_"
      },
      "source": [
        "\n",
        "### Preparing federated data sets\n",
        "\n",
        "For the sake of a demonstration, we're going to simulate a scenario in which we\n",
        "have data from 10 users, and each of the users contributes knowledge how to\n",
        "recognize a different digit. This is about as\n",
        "non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables)\n",
        "as it gets.\n",
        "\n",
        "First, let's load the standard MNIST data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "executionInfo": {
          "elapsed": 362,
          "status": "ok",
          "timestamp": 1638423034141,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "uThZM4Ds-KDQ"
      },
      "outputs": [],
      "source": [
        "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "executionInfo": {
          "elapsed": 59,
          "status": "ok",
          "timestamp": 1638423034374,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "PkJc5rHA2no_",
        "outputId": "c48dbc92-c682-498a-a1bf-3b779d03c6ab"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "[(x.dtype, x.shape) for x in mnist_train]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mFET4BKJFbkP"
      },
      "source": [
        "The data comes as Numpy arrays, one with images and another with digit labels, both\n",
        "with the first dimension going over the individual examples. Let's write a\n",
        "helper function that formats it in a way compatible with how we feed federated\n",
        "sequences into TFF computations, i.e., as a list of lists - the outer list\n",
        "ranging over the users (digits), the inner ones ranging over batches of data in\n",
        "each client's sequence. As is customary, we will structure each batch as a pair\n",
        "of tensors named `x` and `y`, each with the leading batch dimension. While at\n",
        "it, we'll also flatten each image into a 784-element vector and rescale the\n",
        "pixels in it into the `0..1` range, so that we don't have to clutter the model\n",
        "logic with data conversions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "executionInfo": {
          "elapsed": 1460,
          "status": "ok",
          "timestamp": 1638423035978,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "XTaTLiq5GNqy"
      },
      "outputs": [],
      "source": [
        "NUM_EXAMPLES_PER_USER = 1000\n",
        "BATCH_SIZE = 100\n",
        "\n",
        "\n",
        "def get_data_for_digit(source, digit):\n",
        "  output_sequence = []\n",
        "  all_samples = [i for i, d in enumerate(source[1]) if d == digit]\n",
        "  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):\n",
        "    batch_samples = all_samples[i:i + BATCH_SIZE]\n",
        "    output_sequence.append({\n",
        "        'x':\n",
        "            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],\n",
        "                     dtype=np.float32),\n",
        "        'y':\n",
        "            np.array([source[1][i] for i in batch_samples], dtype=np.int32)\n",
        "    })\n",
        "  return output_sequence\n",
        "\n",
        "\n",
        "federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]\n",
        "\n",
        "federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xpNdBimWaMHD"
      },
      "source": [
        "As a quick sanity check, let's look at the `Y` tensor in the last batch of data\n",
        "contributed by the fifth client (the one corresponding to the digit `5`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "executionInfo": {
          "elapsed": 12,
          "status": "ok",
          "timestamp": 1638423036150,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "bTNuL1W4bcuc",
        "outputId": "9ca5ed7f-6569-42ca-896f-4076cabb0dc4"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "federated_train_data[5][-1]['y']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xgvcwv7Obhat"
      },
      "source": [
        "Just to be sure, let's also look at the image corresponding to the last element of that batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "height": 265
        },
        "executionInfo": {
          "elapsed": 259,
          "status": "ok",
          "timestamp": 1638423036580,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "cI4aat1za525",
        "outputId": "e670d012-301a-4dcf-cea0-9aa715e0b877"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAN6klEQVR4nO3dfaxU9Z3H8c/Ha4sijQGMhlB20canjXGtEt2EZtHU1od/pBJI\nMTbqNqEJmlSzyS52/9Bk3WhcuutfPlAfYNdqUyNWggutAbN2MWm8GlaxbCurbotcQReM+BQVvvvH\nPWyueOc3l5kzcwa+71dyMzPne8853wz3wzkzvzPzc0QIwJHvqKYbANAfhB1IgrADSRB2IAnCDiRx\ndD93Zpu3/oEeiwiPt7yrI7vtS23/zvY228u62RaA3nKn4+y2hyT9XtK3JG2X9LykxRHx28I6HNmB\nHuvFkf18Sdsi4rWI+ETSzyRd0cX2APRQN2GfKemPYx5vr5Z9ju0ltodtD3exLwBd6uYNuvFOFb5w\nmh4RKyStkDiNB5rUzZF9u6RZYx5/VdKO7toB0CvdhP15SafaPtn2lyV9V9KaetoCULeOT+Mj4jPb\nN0j6paQhSQ9GxCu1dQagVh0PvXW0M16zAz3Xk4tqABw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIO\nJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1Ioq9TNmPwXH/9\n9cX6Bx98UKyvXLmyxm4+b/bs2cX6UUeVj1WLFi1qWZs58wszlX3O0qVLi/WLL764WH/mmWeK9SZw\nZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnT27+/PnF+kUXXVSsT58+vVjfvHlzy9pVV11VXPfq\nq68u1oeGhor1brz//vvF+p49e3q2717pKuy235C0V9I+SZ9FxJw6mgJQvzqO7BdFxDs1bAdAD/Ga\nHUii27CHpF/ZfsH2kvF+wfYS28O2h7vcF4AudHsaPzcidtg+UdLTtv8rIp4d+wsRsULSCkmyHV3u\nD0CHujqyR8SO6naXpCcknV9HUwDq13HYbR9n+ysH7kv6tqQtdTUGoF7dnMafJOkJ2we280hErK+l\nKxw27rzzzmI9YjBfud10003F+rp164r1bdu21dlOX3Qc9oh4TdKf19gLgB5i6A1IgrADSRB2IAnC\nDiRB2IEk+IjrEaAa/hzX3Llzi+vOmzev7nYm7KOPPirW9+7dW6yvX18e6b3tttta1l5//fXiuoM6\nZNgNjuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kIT7OZ7IN9X0xpQpU1rW3n333Z7u+5NPPinW16xZ\n07K2fPny4rrDw3yTWSciYtwLLziyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASfJ79CLBw4cLG9r10\n6dJifeXKlf1pBG1xZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwwsWrSoWL/rrrt6tu+77767\nWGcc/fDR9shu+0Hbu2xvGbNsmu2nbb9a3U7tbZsAujWR0/iVki49aNkySRsi4lRJG6rHAAZY27BH\nxLOSdh+0+ApJq6r7qyTNr7ctAHXr9DX7SRExIkkRMWL7xFa/aHuJpCUd7gdATXr+Bl1ErJC0QuIL\nJ4EmdTr0ttP2DEmqbnfV1xKAXug07GskXVPdv0bSk/W0A6BX2n5vvO1HJV0o6QRJOyXdIukXkn4u\n6U8k/UHSwog4+E288bbFafw4Jk+eXKw/99xzxfpZZ53V8b43btxYrC9YsKBYbzeHOvqv1ffGt33N\nHhGLW5S+2VVHAPqKy2WBJAg7kARhB5Ig7EAShB1Igo+49sGkSZOK9fvuu69Y72ZorZ3bb7+9WGdo\n7cjBkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcvQ8uvPDCYn3x4lYfLOy9K6+8slg/++yzi/X3\n3nuvWH/ooYcOuSf0Bkd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii7VdJ17qzpF8l/dRTTxXrl156\n8LyZh4+jjiofL558svWUAu2elwceeKBY379/f7GeVauvkubIDiRB2IEkCDuQBGEHkiDsQBKEHUiC\nsANJMM7eB+eee26xfs899xTr5513Xsf73rp1a7E+MjJSrM+aNatYP+2004r1bv6+li1bVqwvX768\n420fyToeZ7f9oO1dtreMWXar7Tdtb65+Lq+zWQD1m8hp/EpJ413i9c8RcU7182/1tgWgbm3DHhHP\nStrdh14A9FA3b9DdYPul6jR/aqtfsr3E9rDt4S72BaBLnYb9Hklfk3SOpBFJP271ixGxIiLmRMSc\nDvcFoAYdhT0idkbEvojYL+knks6vty0Adeso7LZnjHn4HUlbWv0ugMHQdpzd9qOSLpR0gqSdkm6p\nHp8jKSS9IekHEVEesFXecfZ2Jk+eXKyfcsopHW/7zTffLNb37NlTrE+fPr1YP/3004v1m2++uWXt\nsssuK667b9++Yn3+/PnF+rp164r1I1Wrcfa2k0RExHgzGJS/VQDAwOFyWSAJwg4kQdiBJAg7kARh\nB5LgI641OPbYY4v1jz/+uFjv579Bvw0NDbWsbd68ubjumWeeWaxv2rSpWJ83b16xfqTiq6SB5Ag7\nkARhB5Ig7EAShB1IgrADSRB2IIm2n3rDqOOPP75l7ZFHHimuu3DhwmL9ww8/7Kinw8GUKVNa1o45\n5piutn300fz5HgqO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBAOVEzRnTusJbS655JLiuu2mNW73\nue5BVhpHl6SHH364Ze3kk0+uux0UcGQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZ++D9evXF+ul\naY0l6bHHHquznUNy7bXXFuu33HJLsT516tSO9/3pp58W6/fee2/H286o7ZHd9izbz9jeavsV2z+s\nlk+z/bTtV6vbzv9VAfTcRE7jP5P01xFxpqS/kHS97T+TtEzShog4VdKG6jGAAdU27BExEhEvVvf3\nStoqaaakKyStqn5tlaT5PeoRQA0O6TW77dmSvi7pN5JOiogRafQ/BNsntlhniaQlXfYJoEsTDrvt\nKZIel3RjRLxnjzt33BdExApJK6ptHLkzGAIDbkJDb7a/pNGg/zQiVleLd9qeUdVnSNrVmxYB1KHt\nlM0ePYSvkrQ7Im4cs/wfJf1vRNxhe5mkaRHxN222ddge2S+44IKWtY0bNxbXnTRpUt3tDIx2Z3il\nv689e/YU1203JHn//fcX61m1mrJ5IqfxcyV9T9LLtjdXy34k6Q5JP7f9fUl/kFT+cnQAjWob9oj4\nD0mt/vv+Zr3tAOgVLpcFkiDsQBKEHUiCsANJEHYgibbj7LXu7DAeZy+57rrrivV2H8UcGhqqs52+\najfO/vbbb7esLViwoLj
upk2bOuopu1bj7BzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtn74Iwz\nzijWV69eXay3m/K5l9pNJ7127dpivXSNwVtvvdVJS2iDcXYgOcIOJEHYgSQIO5AEYQeSIOxAEoQd\nSIJxduAIwzg7kBxhB5Ig7EAShB1IgrADSRB2IAnCDiTRNuy2Z9l+xvZW26/Y/mG1/Fbbb9reXP1c\n3vt2AXSq7UU1tmdImhERL9r+iqQXJM2XtEjS+xGxfMI746IaoOdaXVQzkfnZRySNVPf32t4qaWa9\n7QHotUN6zW57tqSvS/pNtegG2y/ZftD21BbrLLE9bHu4u1YBdGPC18bbniLp3yX9Q0Sstn2SpHck\nhaS/1+ip/l+12Qan8UCPtTqNn1DYbX9J0lpJv4yIfxqnPlvS2og4q812CDvQYx1/EMaj03Q+IGnr\n2KBXb9wd8B1JW7ptEkDvTOTd+G9I+rWklyXtrxb/SNJiSedo9DT+DUk/qN7MK22LIzvQY12dxteF\nsAO9x+fZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSbT9\nwsmavSPpf8Y8PqFaNogGtbdB7Uuit07V2duftir09fPsX9i5PRwRcxproGBQexvUviR661S/euM0\nHkiCsANJNB32FQ3vv2RQexvUviR661Rfemv0NTuA/mn6yA6gTwg7kEQjYbd9qe3f2d5me1kTPbRi\n+w3bL1fTUDc6P101h94u21vGLJtm+2nbr1a3486x11BvAzGNd2Ga8Uafu6anP+/7a3bbQ5J+L+lb\nkrZLel7S4oj4bV8bacH2G5LmRETjF2DY/ktJ70v6lwNTa9m+U9LuiLij+o9yakT87YD0dqsOcRrv\nHvXWaprxa9Xgc1fn9OedaOLIfr6kbRHxWkR8Iulnkq5ooI+BFxHPStp90OIrJK2q7q/S6B9L37Xo\nbSBExEhEvFjd3yvpwDTjjT53hb76oomwz5T0xzGPt2uw5nsPSb+y/YLtJU03M46TDkyzVd2e2HA/\nB2s7jXc/HTTN+MA8d51Mf96tJsI+3tQ0gzT+NzcizpV0maTrq9NVTMw9kr6m0TkARyT9uMlmqmnG\nH5d0Y0S812QvY43TV1+etybCvl3SrDGPvyppRwN9jCsidlS3uyQ9odGXHYNk54EZdKvbXQ338/8i\nYmdE7IuI/ZJ+ogafu2qa8ccl/TQiVleLG3/uxuurX89bE2F/XtKptk+2/WVJ35W0poE+vsD2cdUb\nJ7J9nKRva/Cmol4j6Zrq/jWSnmywl88ZlGm8W00zroafu8anP4+Ivv9Iulyj78j/t6S/a6KHFn2d\nIuk/q59Xmu5N0qMaPa37VKNnRN+XNF3SBkmvVrfTBqi3f9Xo1N4vaTRYMxrq7RsafWn4kqTN1c/l\nTT93hb768rxxuSyQBFfQAUkQdiAJwg4kQdiBJAg7kARhB5Ig7EAS/wfgQlrpjsiFUAAAAABJRU5E\nrkJggg==\n",
            "text/plain": [
              "\u003cFigure size 600x400 with 1 Axes\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from matplotlib import pyplot as plt\n",
        "\n",
        "plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J-ox58PA56f8"
      },
      "source": [
        "### On combining TensorFlow and TFF\n",
        "\n",
        "In this tutorial, for compactness we immediately decorate functions that\n",
        "introduce TensorFlow logic with `tff.tf_computation`. However, for more complex\n",
        "logic, this is not the pattern we recommend. Debugging TensorFlow can already be\n",
        "a challenge, and debugging TensorFlow after it has been fully serialized and\n",
        "then re-imported necessarily loses some metadata and limits interactivity,\n",
        "making debugging even more of a challenge.\n",
        "\n",
        "Therefore, **we strongly recommend writing complex TF logic as stand-alone\n",
        "Python functions** (that is, without `tff.tf_computation` decoration). This way\n",
        "the TensorFlow logic can be developed and tested using TF best practices and\n",
        "tools (like eager mode), before serializing the computation for TFF (e.g., by invoking `tff.tf_computation` with a Python function as the argument)."
      ]
    },
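    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an illustrative sketch of this recommendation (not part of the original flow, and using a\n",
        "hypothetical `scale_and_sum` helper), the cell below first writes and tests a plain TensorFlow\n",
        "function eagerly, and only then serializes it by passing it to `tff.tf_computation` together\n",
        "with a parameter type."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# A sketch of the recommended workflow: develop TF logic as a plain function,\n",
        "# test it eagerly, then serialize it for TFF.\n",
        "def scale_and_sum(x):  # hypothetical helper, used only for illustration\n",
        "  return tf.reduce_sum(x * 2.0)\n",
        "\n",
        "# Test eagerly with ordinary TF tools before any serialization.\n",
        "assert scale_and_sum(tf.constant([1.0, 2.0])).numpy() == 6.0\n",
        "\n",
        "# Once it behaves as intended, wrap it as a TFF computation.\n",
        "scale_and_sum_tff = tff.tf_computation(scale_and_sum, tff.TensorType(tf.float32, [2]))\n",
        "str(scale_and_sum_tff.type_signature)"
      ]
    },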
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RSd6UatXbzw-"
      },
      "source": [
        "### Defining a loss function\n",
        "\n",
        "Now that we have the data, let's define a loss function that we can use for\n",
        "training. First, let's define the type of input as a TFF named tuple. Since the\n",
        "size of data batches may vary, we set the batch dimension to `None` to indicate\n",
        "that the size of this dimension is unknown."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 14,
          "status": "ok",
          "timestamp": 1638423036890,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "653xv5NXd4fy",
        "outputId": "c38bfd5c-fad3-460f-8696-cde20a1db68b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\u003cx=float32[?,784],y=int32[?]\u003e'"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "BATCH_SPEC = collections.OrderedDict(\n",
        "    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),\n",
        "    y=tf.TensorSpec(shape=[None], dtype=tf.int32))\n",
        "BATCH_TYPE = tff.to_type(BATCH_SPEC)\n",
        "\n",
        "str(BATCH_TYPE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pb6qPUvyh5A1"
      },
      "source": [
        "You may be wondering why we can't just define an ordinary Python type. Recall\n",
        "the discussion in [part 1](custom_federated_algorithms_1.ipynb), where we\n",
        "explained that while we can express the logic of TFF computations using Python,\n",
        "under the hood TFF computations *are not* Python. The symbol `BATCH_TYPE`\n",
        "defined above represents an abstract TFF type specification. It is important to\n",
        "distinguish this *abstract* TFF type from concrete Python *representation*\n",
        "types, e.g., containers such as `dict` or `collections.namedtuple` that may be\n",
        "used to represent the TFF type in the body of a Python function. Unlike Python,\n",
        "TFF has a single abstract type constructor `tff.StructType` for tuple-like\n",
        "containers, with elements that can be individually named or left unnamed. This\n",
        "type is also used to model formal parameters of computations, as TFF\n",
        "computations can formally only declare one parameter and one result - you will\n",
        "see examples of this shortly.\n",
        "\n",
        "Let's now define the TFF type of model parameters, again as a TFF named tuple of\n",
        "*weights* and *bias*."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "executionInfo": {
          "elapsed": 55,
          "status": "ok",
          "timestamp": 1638423037338,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "Og7VViafh-30",
        "outputId": "ea062043-a778-4fde-b961-572b6a220897"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u003cweights=float32[784,10],bias=float32[10]\u003e\n"
          ]
        }
      ],
      "source": [
        "MODEL_SPEC = collections.OrderedDict(\n",
        "    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),\n",
        "    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))\n",
        "MODEL_TYPE = tff.to_type(MODEL_SPEC)\n",
        "print(MODEL_TYPE)"
      ]
    },
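    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a side note added for illustration, a structurally equivalent type (up to the Python\n",
        "container it tracks) can also be built by hand from the `tff.StructType` and `tff.TensorType`\n",
        "constructors mentioned above; `tff.to_type` is a convenience for deriving such a type from\n",
        "Python containers of `tf.TensorSpec`s."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: construct a structurally equivalent type directly with\n",
        "# TFF's type constructors instead of `tff.to_type`.\n",
        "explicit_model_type = tff.StructType([\n",
        "    ('weights', tff.TensorType(tf.float32, [784, 10])),\n",
        "    ('bias', tff.TensorType(tf.float32, [10])),\n",
        "])\n",
        "print(explicit_model_type)"
      ]
    },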
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iHhdaWSpfQxo"
      },
      "source": [
        "With those definitions in place, now we can define the loss for a given model, over a single batch. Note the usage of `@tf.function` decorator inside the `@tff.tf_computation` decorator. This allows us to write TF using Python like semantics even though were inside a `tf.Graph` context created by the `tff.tf_computation` decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "executionInfo": {
          "elapsed": 90,
          "status": "ok",
          "timestamp": 1638423037559,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "4EObiz_Ke0uK"
      },
      "outputs": [],
      "source": [
        "# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can \n",
        "# be later called from within another tf.function. Necessary because a\n",
        "# @tf.function  decorated method cannot invoke a @tff.tf_computation.\n",
        "\n",
        "@tf.function\n",
        "def forward_pass(model, batch):\n",
        "  predicted_y = tf.nn.softmax(\n",
        "      tf.matmul(batch['x'], model['weights']) + model['bias'])\n",
        "  return -tf.reduce_mean(\n",
        "      tf.reduce_sum(\n",
        "          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))\n",
        "\n",
        "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)\n",
        "def batch_loss(model, batch):\n",
        "  return forward_pass(model, batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8K0UZHGnr8SB"
      },
      "source": [
        "As expected, computation `batch_loss` returns `float32` loss given the model and\n",
        "a single data batch. Note how the `MODEL_TYPE` and `BATCH_TYPE` have been lumped\n",
        "together into a 2-tuple of formal parameters; you can recognize the type of\n",
        "`batch_loss` as `(\u003cMODEL_TYPE,BATCH_TYPE\u003e -\u003e float32)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 24,
          "status": "ok",
          "timestamp": 1638423037759,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "4WXEAY8Nr89V",
        "outputId": "6e893496-83d9-44de-c1e2-7bb6972e4844"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'(\u003cmodel=\u003cweights=float32[784,10],bias=float32[10]\u003e,batch=\u003cx=float32[?,784],y=int32[?]\u003e\u003e -\u003e float32)'"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(batch_loss.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pAnt_UcdnvGa"
      },
      "source": [
        "As a sanity check, let's construct an initial model filled with zeros and\n",
        "compute the loss over the batch of data we visualized above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "executionInfo": {
          "elapsed": 168,
          "status": "ok",
          "timestamp": 1638423038220,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "U8Ne8igan3os",
        "outputId": "2be9bcc5-ed6a-44f6-fbbb-28bc4da24381"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "2.3025851"
            ]
          },
          "execution_count": 14,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "initial_model = collections.OrderedDict(\n",
        "    weights=np.zeros([784, 10], dtype=np.float32),\n",
        "    bias=np.zeros([10], dtype=np.float32))\n",
        "\n",
        "sample_batch = federated_train_data[5][-1]\n",
        "\n",
        "batch_loss(initial_model, sample_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ckigEAyDAWFV"
      },
      "source": [
        "Note that we feed the TFF computation with the initial model defined as a\n",
        "`dict`, even though the body of the Python function that defines it consumes\n",
        "model parameters as `model['weight']` and `model['bias']`. The arguments of the call\n",
        "to `batch_loss` aren't simply passed to the body of that function.\n",
        "\n",
        "\n",
        "What happens when we invoke `batch_loss`?\n",
        "The Python body of `batch_loss` has already been traced and serialized  in the above cell where it was defined.  TFF acts as the caller to `batch_loss`\n",
        "at the computation definition time, and as the target of invocation at the time\n",
        "`batch_loss` is invoked. In both roles, TFF serves as the bridge between TFF's\n",
        "abstract type system and Python representation types. At the invocation time,\n",
        "TFF will accept most standard Python container types (`dict`, `list`, `tuple`,\n",
        "`collections.namedtuple`, etc.) as concrete representations of abstract TFF\n",
        "tuples. Also, although as noted above, TFF computations formally only accept a\n",
        "single parameter, you can use the familiar Python call syntax with positional\n",
        "and/or keyword arguments in case where the type of the parameter is a tuple - it\n",
        "works as expected."
      ]
    },
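    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sketch of that last point (added here for illustration, not part of the original\n",
        "flow), the call below feeds the same model as a plain `dict` and uses keyword arguments\n",
        "named after the elements of the parameter tuple; it is expected to return the same loss as above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: pass a standard Python `dict` and use keyword arguments\n",
        "# matching the names in the computation's parameter tuple.\n",
        "plain_dict_model = dict(initial_model)\n",
        "batch_loss(model=plain_dict_model, batch=sample_batch)"
      ]
    },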
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eB510nILYbId"
      },
      "source": [
        "### Gradient descent on a single batch\n",
        "\n",
        "Now, let's define a computation that uses this loss function to perform a single\n",
        "step of gradient descent. Note how in defining this function, we use\n",
        "`batch_loss` as a subcomponent. You can invoke a computation constructed with\n",
        "`tff.tf_computation` inside the body of another computation, though typically\n",
        "this is not necessary - as noted above, because serialization looses some\n",
        "debugging information, it is often preferable for more complex computations to\n",
        "write and test all the TensorFlow without the `tff.tf_computation` decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "executionInfo": {
          "elapsed": 189,
          "status": "ok",
          "timestamp": 1638423038575,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "O4uaVxw3AyYS"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)\n",
        "def batch_train(initial_model, batch, learning_rate):\n",
        "  # Define a group of model variables and set them to `initial_model`. Must\n",
        "  # be defined outside the @tf.function.\n",
        "  model_vars = collections.OrderedDict([\n",
        "      (name, tf.Variable(name=name, initial_value=value))\n",
        "      for name, value in initial_model.items()\n",
        "  ])\n",
        "  optimizer = tf.keras.optimizers.SGD(learning_rate)\n",
        "\n",
        "  @tf.function\n",
        "  def _train_on_batch(model_vars, batch):\n",
        "    # Perform one step of gradient descent using loss from `batch_loss`.\n",
        "    with tf.GradientTape() as tape:\n",
        "      loss = forward_pass(model_vars, batch)\n",
        "    grads = tape.gradient(loss, model_vars)\n",
        "    optimizer.apply_gradients(\n",
        "        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))\n",
        "    return model_vars\n",
        "\n",
        "  return _train_on_batch(model_vars, batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "colab": {
          "height": 53
        },
        "executionInfo": {
          "elapsed": 67,
          "status": "ok",
          "timestamp": 1638423038889,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "Y84gQsaohC38",
        "outputId": "518a7e89-8097-4f98-9e38-3024d520d246"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'(\u003cinitial_model=\u003cweights=float32[784,10],bias=float32[10]\u003e,batch=\u003cx=float32[?,784],y=int32[?]\u003e,learning_rate=float32\u003e -\u003e \u003cweights=float32[784,10],bias=float32[10]\u003e)'"
            ]
          },
          "execution_count": 16,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(batch_train.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ID8xg9FCUL2A"
      },
      "source": [
        "When you invoke a Python function decorated with `tff.tf_computation` within the\n",
        "body of another such function, the logic of the inner TFF computation is\n",
        "embedded (essentially, inlined) in the logic of the outer one. As noted above,\n",
        "if you are writing both computations, it is likely preferable to make the inner\n",
        "function (`batch_loss` in this case) a regular Python or `tf.function` rather\n",
        "than a `tff.tf_computation`. However, here we illustrate that calling one\n",
        "`tff.tf_computation` inside another basically works as expected. This may be\n",
        "necessary if, for example, you do not have the Python code defining\n",
        "`batch_loss`, but only its serialized TFF representation.\n",
        "\n",
        "Now, let's apply this function a few times to the initial model to see whether\n",
        "the loss decreases."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "executionInfo": {
          "elapsed": 340,
          "status": "ok",
          "timestamp": 1638423039494,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "8edcJTlXUULm"
      },
      "outputs": [],
      "source": [
        "model = initial_model\n",
        "losses = []\n",
        "for _ in range(5):\n",
        "  model = batch_train(model, sample_batch, 0.1)\n",
        "  losses.append(batch_loss(model, sample_batch))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "executionInfo": {
          "elapsed": 11,
          "status": "ok",
          "timestamp": 1638423039687,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "3n1onojT1zHv",
        "outputId": "9c8587ce-c294-49a5-8c0b-6f78914e3786"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[0.19690023, 0.13176313, 0.10113225, 0.08273812, 0.070301384]"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "losses"
      ]
    },
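    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small added check (illustrative only, not part of the original flow), the assertion\n",
        "below confirms that the loss on this batch decreases from step to step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative check: each training step should reduce the loss on this batch.\n",
        "assert all(earlier \u003e later for earlier, later in zip(losses, losses[1:]))"
      ]
    },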
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EQk4Ha8PU-3P"
      },
      "source": [
        "### Gradient descent on a sequence of local data\n",
        "\n",
        "Now, since `batch_train` appears to work, let's write a similar training\n",
        "function `local_train` that consumes the entire sequence of all batches from one\n",
        "user instead of just a single batch. The new computation will need to now\n",
        "consume `tff.SequenceType(BATCH_TYPE)` instead of `BATCH_TYPE`."
      ]
    },
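    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick aside added for illustration, you can inspect what this sequence type looks\n",
        "like before using it; the trailing `*` in the printed form denotes a TFF sequence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: peek at the sequence type we are about to use.\n",
        "str(tff.SequenceType(BATCH_TYPE))"
      ]
    },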
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "executionInfo": {
          "elapsed": 72,
          "status": "ok",
          "timestamp": 1638423039885,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 480
        },
        "id": "EfPD5a6QVNXM"
      },
      "outputs": [],
      "source": [
        "LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)\n",
        "\n",
        "@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)\n",
        "def local_train(initial_model, learning_rate, all_batches):\n",
        "\n",
        "  @tff.tf_computation(LOCAL_DATA_TYPE, tf.float32)\n",
        "  def _insert_learning_rate_to_sequence(dataset, learning_rate):\n",
        "    return dataset.map(lambda x: (x, learning_rate))\n",
        "\n",
        "  batches_with_learning_rate = _insert_learning_rate_to_sequence(all_batches, learning_rate)\n",
        "\n",
        "  # Mapping function to apply to each batch.\n",
        "  @tff.federated_computation(MODEL_TYPE, batches_with_learning_rate.type_signature.element)\n",
        "  def batch_fn(model, batch_with_lr):\n",
        "    batch, lr = batch_with_lr\n",
        "    return batch_train(model, batch, lr)\n",
        "\n",