{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkdnLiKk71g-"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "0asMuNro71hA"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MnUwFbCAKB2r"
      },
      "source": [
        "## Before we start\n",
        "\n",
        "Before we start, please run the following to make sure that your environment is\n",
        "correctly set up. If you don't see a greeting, please refer to the\n",
        "[Installation](../install.md) guide for instructions. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZrGitA_KnRO0"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow_federated\n",
        "!pip install --quiet --upgrade nest_asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HGTM6tWOLo8M"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "import attr\n",
        "import functools\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "np.random.seed(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yr3ztf28fa1F"
      },
      "source": [
        "**NOTE**: This colab has been verified to work with the [latest released version](https://github.com/tensorflow/federated#compatibility) of the `tensorflow_federated` pip package, but the TensorFlow Federated project is still in pre-release development, so this colab may not work on `master`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Zv28F7QLo8O"
      },
      "source": [
        "# Overview\n",
        "\n",
        "In the [image classification](federated_learning_for_image_classification.ipynb) and\n",
        "[text generation](federated_learning_for_text_generation.ipynb) tutorials, we learned how to set up model and data pipelines for Federated Learning (FL), and performed federated training via the `tff.learning` API layer of TFF.\n",
        "\n",
        "This is only the tip of the iceberg when it comes to FL research. In this tutorial, we discuss how to implement federated learning algorithms *without* deferring to the `tff.learning` API. We aim to accomplish the following:\n",
        "\n",
        "**Goals:**\n",
        "\n",
        "\n",
        "*   Understand the general structure of federated learning algorithms.\n",
        "*   Explore the *Federated Core* of TFF.\n",
        "*   Use the Federated Core to implement Federated Averaging directly.\n",
        "\n",
        "While this tutorial is self-contained, we recommend first reading the [image classification](federated_learning_for_image_classification.ipynb) and\n",
        "[text generation](federated_learning_for_text_generation.ipynb) tutorials.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQ_N9XbULo8P"
      },
      "source": [
        "## Preparing the input data\n",
        "We first load and preprocess the EMNIST dataset included in TFF. For more details, see the [image classification](federated_learning_for_image_classification.ipynb) tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-WdnFluLLo8P"
      },
      "outputs": [],
      "source": [
        "emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kq8893GogB8E"
      },
      "source": [
        "In order to feed the dataset into our model, we flatten the data, and convert each example into a tuple of the form `(flattened_image_vector, label)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Blrh8zJgLo8R"
      },
      "outputs": [],
      "source": [
        "NUM_CLIENTS = 10\n",
        "BATCH_SIZE = 20\n",
        "\n",
        "def preprocess(dataset):\n",
        "\n",
        "  def batch_format_fn(element):\n",
        "    \"\"\"Flatten a batch of EMNIST data and return a (features, label) tuple.\"\"\"\n",
        "    return (tf.reshape(element['pixels'], [-1, 784]), \n",
        "            tf.reshape(element['label'], [-1, 1]))\n",
        "\n",
        "  return dataset.batch(BATCH_SIZE).map(batch_format_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Piy8EzqmgNqV"
      },
      "source": [
        "We now sample a small number of clients, and apply the preprocessing above to their datasets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-vYM_IT7Lo8W"
      },
      "outputs": [],
      "source": [
        "client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)\n",
        "\n",
        "federated_train_data = [\n",
        "    preprocess(emnist_train.create_tf_dataset_for_client(x))\n",
        "    for x in client_ids\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gNO_Y9j_Lo8X"
      },
      "source": [
        "## Preparing the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LJ0I89ixz8yV"
      },
      "source": [
        "We use the same model as in the [image classification](federated_learning_for_image_classification.ipynb) tutorial. This model (implemented via `tf.keras`) has a single dense layer, followed by a softmax layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yfld4oFNLo8Y"
      },
      "outputs": [],
      "source": [
        "def create_keras_model():\n",
        "  return tf.keras.models.Sequential([\n",
        "      tf.keras.layers.Input(shape=(784,)),\n",
        "      tf.keras.layers.Dense(10, kernel_initializer='zeros'),\n",
        "      tf.keras.layers.Softmax(),\n",
        "  ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vLln0Q8G0Bky"
      },
      "source": [
        "In order to use this model in TFF, we wrap the Keras model as a [`tff.learning.Model`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model). This allows us to perform the model's [forward pass](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model#forward_pass) within TFF, and [extract model outputs](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model#report_local_outputs). For more details, also see the [image classification](federated_learning_for_image_classification.ipynb) tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SPwbipTNLo8a"
      },
      "outputs": [],
      "source": [
        "def model_fn():\n",
        "  keras_model = create_keras_model()\n",
        "  return tff.learning.from_keras_model(\n",
        "      keras_model,\n",
        "      input_spec=federated_train_data[0].element_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pCxa44rFiere"
      },
      "source": [
        "While we used `tf.keras` to create a `tff.learning.Model`, TFF supports much more general models. These models have the following relevant attributes capturing the model weights:\n",
        "\n",
        "*   `trainable_variables`: An iterable of the tensors corresponding to trainable layers.\n",
        "*   `non_trainable_variables`: An iterable of the tensors corresponding to non-trainable layers.\n",
        "\n",
        "For our purposes, we will only use the `trainable_variables` (as our model only has those!)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fPOWP2JjsfTk"
      },
      "source": [
        "# Building your own Federated Learning algorithm\n",
        "\n",
        "While the `tff.learning` API allows one to create many variants of Federated Averaging, there are other federated algorithms that do not fit neatly into this framework. For example, you may want to add regularization, clipping, or more complicated algorithms such as [federated GAN training](https://github.com/tensorflow/federated/tree/master/tensorflow_federated/python/research/gans). You may instead be interested in [federated analytics](https://ai.googleblog.com/2020/05/federated-analytics-collaborative-data.html)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "50N36Zz8qyY-"
      },
      "source": [
        "For these more advanced algorithms, we'll have to write our own custom algorithm using TFF. In many cases, federated algorithms have 4 main components:\n",
        "\n",
        "1. A server-to-client broadcast step.\n",
        "2. A local client update step.\n",
        "3. A client-to-server upload step.\n",
        "4. A server update step."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jH8s0GRdQt3b"
      },
      "source": [
        "In TFF, we generally represent federated algorithms as a [`tff.templates.IterativeProcess`](https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess) (which we refer to as just an `IterativeProcess` throughout). This is a class that contains `initialize` and `next` functions. Here, `initialize` is used to initialize the server, and `next` will perform one communication round of the federated algorithm. Let's write a skeleton of what our iterative process for FedAvg should look like.\n",
        "\n",
        "First, we have an initialize function that simply creates a `tff.learning.Model`, and returns its trainable weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ylLpRa7T5DDh"
      },
      "outputs": [],
      "source": [
        "def initialize_fn():\n",
        "  model = model_fn()\n",
        "  return model.trainable_variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nb1-XAK8fB2A"
      },
      "source": [
        "This function looks good, but as we will see later, we will need to make a small modification to make it a \"TFF computation\".\n",
        "\n",
        "We also want to sketch the `next_fn`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IeHN-XLZfMso"
      },
      "outputs": [],
      "source": [
        "def next_fn(server_weights, federated_dataset):\n",
        "  # Broadcast the server weights to the clients.\n",
        "  server_weights_at_client = broadcast(server_weights)\n",
        "\n",
        "  # Each client computes their updated weights.\n",
        "  client_weights = client_update(federated_dataset, server_weights_at_client)\n",
        "\n",
        "  # The server averages these updates.\n",
        "  mean_client_weights = mean(client_weights)\n",
        "\n",
        "  # The server updates its model.\n",
        "  server_weights = server_update(mean_client_weights)\n",
        "\n",
        "  return server_weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uWXvjXPWeujU"
      },
      "source": [
        "We'll focus on implementing these four components separately. We first focus on the parts that can be implemented in pure TensorFlow, namely the client and server update steps.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3yKS4VkALo8g"
      },
      "source": [
        "## TensorFlow Blocks "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bxpNYucgLo8g"
      },
      "source": [
        "### Client update\n",
        "\n",
        "We will use our `tff.learning.Model` to do client training in essentially the same way you would train a TensorFlow model. In particular, we will use `tf.GradientTape` to compute the gradient on batches of data, then apply these gradients using a `client_optimizer`. We focus only on the trainable weights.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c5rHPKreLo8g"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def client_update(model, dataset, server_weights, client_optimizer):\n",
        "  \"\"\"Performs training (using the server model weights) on the client's dataset.\"\"\"\n",
        "  # Initialize the client model with the current server weights.\n",
        "  client_weights = model.trainable_variables\n",
        "  # Assign the server weights to the client model.\n",
        "  tf.nest.map_structure(lambda x, y: x.assign(y),\n",
        "                        client_weights, server_weights)\n",
        "\n",
        "  # Use the client_optimizer to update the local model.\n",
        "  for batch in dataset:\n",
        "    with tf.GradientTape() as tape:\n",
        "      # Compute a forward pass on the batch of data\n",
        "      outputs = model.forward_pass(batch)\n",
        "\n",
        "    # Compute the corresponding gradient\n",
        "    grads = tape.gradient(outputs.loss, client_weights)\n",
        "    grads_and_vars = zip(grads, client_weights)\n",
        "\n",
        "    # Apply the gradient using a client optimizer.\n",
        "    client_optimizer.apply_gradients(grads_and_vars)\n",
        "\n",
        "  return client_weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pP0D9XtoLo8i"
      },
      "source": [
        "### Server Update\n",
        "\n",
        "The server update for FedAvg is simpler than the client update. We will implement \"vanilla\" federated averaging, in which we simply replace the server model weights by the average of the client model weights. Again, we only focus on the trainable weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rYxErLvHLo8i"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def server_update(model, mean_client_weights):\n",
        "  \"\"\"Updates the server model weights as the average of the client model weights.\"\"\"\n",
        "  model_weights = model.trainable_variables\n",
        "  # Assign the mean client weights to the server model.\n",
        "  tf.nest.map_structure(lambda x, y: x.assign(y),\n",
        "                        model_weights, mean_client_weights)\n",
        "  return model_weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ddCklfWlVr1U"
      },
      "source": [
        "The snippet could be simplified by simply returning the `mean_client_weights`. However, more advanced implementations of Federated Averaging use `mean_client_weights` with more sophisticated techniques, such as momentum or adaptivity.\n",
        "\n",
        "**Challenge**: Implement a version of `server_update` that updates the server weights to be the midpoint of `model_weights` and `mean_client_weights`. (Note: This kind of \"midpoint\" approach is analogous to recent work on the [Lookahead optimizer](https://arxiv.org/abs/1907.08610)!)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KuP9g6RFLo8k"
      },
      "source": [
        "So far, we've only written pure TensorFlow code. This is by design, as TFF allows you to use much of the TensorFlow code you're already familiar with. However, now we have to specify the **orchestration logic**, that is, the logic that dictates what the server broadcasts to the client, and what the client uploads to the server.\n",
        "\n",
        "This will require the *Federated Core* of TFF."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0CgFLVPgLo8l"
      },
      "source": [
        "# Introduction to the Federated Core\n",
        "\n",
        "The Federated Core (FC) is a set of lower-level interfaces that serve as the foundation for the `tff.learning` API. However, these interfaces are not limited to learning. In fact, they can be used for analytics and many other computations over distributed data.\n",
        "\n",
        "At a high level, the Federated Core is a development environment that enables compactly expressed program logic to combine TensorFlow code with distributed communication operators (such as distributed sums and broadcasts). The goal is to give researchers and practitioners explicit control over the distributed communication in their systems, without requiring system implementation details (such as specifying point-to-point network message exchanges).\n",
        "\n",
        "One key point is that TFF is designed for privacy-preservation. Therefore, it allows explicit control over where data resides, to prevent unwanted accumulation of data at the centralized server location."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EYinjNqZLo8l"
      },
      "source": [
        "## Federated data\n",
        "\n",
        "A key concept in TFF is \"federated data\", which refers to a collection of data items hosted across a group of devices in a distributed system (e.g. client datasets, or the server model weights). We model the entire collection of data items across all devices as a single *federated value*.\n",
        "\n",
        "For example, suppose we have client devices that each have a float representing the temperature of a sensor. We could represent it as a *federated float* by"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7EJY0MHpLo8l"
      },
      "outputs": [],
      "source": [
        "federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JSQAXD0FLo8n"
      },
      "source": [
        "A federated type is specified by a type `T` of its member constituents (e.g. `tf.float32`) and a group `G` of devices. We will focus on the cases where `G` is either `tff.CLIENTS` or `tff.SERVER`. Such a federated type is represented as `{T}@G`, as shown below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "6mlPgubJLo8n",
        "outputId": "ea9e3297-59d5-492e-b684-fc32ce3aaa06"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'{float32}@CLIENTS'"
            ]
          },
          "execution_count": 0,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(federated_float_on_clients)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pjAQytkeLo8o"
      },
      "source": [
        "Why do we care so much about placements? A key goal of TFF is to enable writing code that could be deployed on a real distributed system. This means that it is vital to reason about which subsets of devices execute which code, and where different pieces of data reside.\n",
        "\n",
        "TFF focuses on three things: *data*, where the data is *placed*, and how the data is being *transformed*. The first two are encapsulated in federated types, while the last is encapsulated in *federated computations*."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZLT2FmVMLo8p"
      },
      "source": [
        "## Federated computations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-XwDC1vTLo8p"
      },
      "source": [
        "TFF is a strongly-typed functional programming environment whose basic units are *federated computations*. These are pieces of logic that accept federated values as input, and return federated values as output.\n",
        "\n",
        "For example, suppose we wanted to average the temperatures on our client sensors. We could define the following (using our federated float):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IfwXDNR1Lo8p"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))\n",
        "def get_average_temperature(client_temperatures):\n",
        "  return tff.federated_mean(client_temperatures)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iSgs6Te5Lo8r"
      },
      "source": [
        "You might ask, how is this different from the `tf.function` decorator in TensorFlow? The key answer is that the code generated by `tff.federated_computation` is neither TensorFlow nor Python code; it is a specification of a distributed system in an internal platform-independent *glue language*.\n",
        "\n",
        "While this may sound complicated, you can think of TFF computations as functions with well-defined type signatures. These type signatures can be directly queried."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "wAG1eDlULo8r",
        "outputId": "b5c11982-c125-4d93-89fb-19d89bf63b81"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'({float32}@CLIENTS -\u003e float32@SERVER)'"
            ]
          },
          "execution_count": 0,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(get_average_temperature.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TveOYFfuLo8s"
      },
      "source": [
        "This `tff.federated_computation` accepts arguments of federated type `{float32}@CLIENTS`, and returns values of federated type `float32@SERVER`. Federated computations may also go from server to client, from client to client, or from server to server. Federated computations can also be composed like normal functions, as long as their type signatures match up.\n",
        "\n",
        "To support development, TFF allows you to invoke a `tff.federated_computation` as a Python function. For example, we can call"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "id": "eFoqtuOTLo8t",
        "outputId": "7857100c-73b2-46b6-9226-a6a9a5196e19"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "69.53334"
            ]
          },
          "execution_count": 0,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "get_average_temperature([68.5, 70.3, 69.8])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZXn-yje9RJ6H"
      },
      "source": [
        "## Non-eager computations and TensorFlow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nwyj8f3HLo8w"
      },
      "source": [
        "There are two key restrictions to be aware of. First, when the Python interpreter encounters a `tff.federated_computation` decorator, the function is traced once and serialized for future use. Due to the decentralized nature of Federated Learning, this future usage may occur elsewhere, such as a remote execution environment. Therefore, TFF computations are fundamentally *non-eager*. This behavior is somewhat analogous to that of the [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) decorator in TensorFlow.\n",
        "\n",
        "Second, a federated computation can only consist of federated operators (such as `tff.federated_mean`); it cannot contain TensorFlow operations. TensorFlow code must be confined to blocks decorated with `tff.tf_computation`. Most ordinary TensorFlow code can be directly decorated, such as the following function that takes a number and adds `0.5` to it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "huz3mNmMLo8w"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(tf.float32)\n",
        "def add_half(x):\n",
        "  return tf.add(x, 0.5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5ptjWALDLo8y"
      },
      "source": [
        "These also have type signatures, but *without placements*. For example, we can call"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "xfAb0vG2Lo8y",
        "outputId": "604ec830-7777-4d37-a19f-4bca90013ab4"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'(float32 -\u003e float32)'"
            ]
          },
          "execution_count": 0,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(add_half.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WNjwrNMjLo8z"
      },
      "source": [
        "Here we see an important difference between `tff.federated_computation` and `tff.tf_computation`. The former has explicit placements, while the latter does not.\n",
        "\n",
        "We can use `tff.tf_computation` blocks in federated computations by specifying placements. Let's create a function that adds half, but only to federated floats at the clients. We can do this by using `tff.federated_map`, which applies a given `tff.tf_computation`, while preserving the placement."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pG6nw3wiLo80"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))\n",
        "def add_half_on_clients(x):\n",
        "  return tff.federated_map(add_half, x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h4msKRKJLo81"
      },
      "source": [
        "This function is almost identical to `add_half`, except that it only accepts values with placement at `tff.CLIENTS`, and returns values with the same placement. We can see this in its type signature:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "_C2hDiz0Lo82",
        "outputId": "1233ab4f-dbf8-43bd-c1ea-3aeda3c00735"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'({float32}@CLIENTS -\u003e {float32}@CLIENTS)'"
            ]
          },
          "execution_count": 0,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(add_half_on_clients.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3JxQ0DeiLo83"
      },
      "source": [
        "In summary:\n",
        "\n",
        "*   TFF operates on federated values.\n",
        "*   Each federated value has a *federated type*, with a *type* (e.g. `tf.float32`) and a *placement* (e.g. `tff.CLIENTS`).\n",
        "*   Federated values can be transformed using *federated computations*, which must be decorated with `tff.federated_computation` and a federated type signature.\n",
        "*   TensorFlow code must be contained in blocks with `tff.tf_computation` decorators.\n",
        "*   These blocks can then be incorporated into federated computations.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PvyFWox3Lo83"
      },
      "source": [
        "# Building your own Federated Learning algorithm, revisited\n",
        "\n",
        "Now that we've gotten a glimpse of the Federated Core, we can build our own federated learning algorithm. Remember that above, we defined an `initialize_fn` and `next_fn` for our algorithm. The `next_fn` will make use of the `client_update` and `server_update` we defined using pure TensorFlow code.\n",
        "\n",
        "However, to make our algorithm a federated computation, both the `next_fn` and the `initialize_fn` must each be a `tff.federated_computation`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CvY8fh1cLo84"
      },
      "source": [
        "## TensorFlow Federated blocks "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g0zNTO7LLo84"
      },
      "source": [
        "### Creating the initialization computation\n",
        "\n",
        "The initialize function will be quite simple: we will create a model using `model_fn` and return its trainable variables. However, remember that we must separate out our TensorFlow code using `tff.tf_computation`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jJY9xUBZLo84"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation\n",
        "def server_init():\n",
        "  model = model_fn()\n",
        "  return model.trainable_variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SGlv8LLgLo85"
      },
      "source": [
        "We can then pass this directly into a federated computation using `tff.federated_value`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m2hinzuRLo86"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation\n",
        "def initialize_fn():\n",
        "  return tff.federated_value(server_init(), tff.SERVER)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NFBghOgxLo88"
      },
      "source": [
        "### Creating the `next_fn`\n",
        "\n",
        "We now use our client and server update code to write the actual algorithm. We will first turn our `client_update` into a `tff.tf_computation` that accepts a client dataset and server weights, and outputs updated client weights.\n",
        "\n",
        "We will need the corresponding types to properly decorate our function. Luckily, both the dataset type and the server weights type can be extracted directly from our model. We start with the dataset type, which comes from the model's `input_spec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ph_noHN2Lo88"
      },
      "outputs": [],
      "source": [
        "dummy_model = model_fn()\n",
        "tf_dataset_type = tff.SequenceType(dummy_model.input_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WMPgpTaW66qx"
      },
      "source": [
        "Let's look at the dataset type signature. Remember that we took 28 by 28 images (with integer labels) and flattened them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "Ju91izuz64wD",
        "outputId": "b0006b79-5394-4ce1-b812-493f69b83bbc"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\u003cfloat32[?,784],int32[?,1]\u003e*'"
            ]
          },
          "execution_count": 0,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(tf_dataset_type)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kuS8d0BHLo8-"
      },
      "source": [
        "We can also extract the model weights type by using our `server_init` function above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4yx6CExMLo8-"
      },
      "outputs": [],
      "source": [
        "model_weights_type = server_init.type_signature.result"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mS-Eh6Xj7J15"
      },
      "source": [
        "Examining the type signature, we'll be able to see the architecture of our model!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "nPdhhM1O7IIL",
        "outputId": "de14770d-8e37-4737-8595-34e01d8dbf91"
      },
      "outputs": [