提交 dcffe1f2 编辑于 作者: Jakub Konecny's avatar Jakub Konecny 提交者: tensorflow-copybara
浏览文件

Updated random noise generation tutorial to remove logarithm from recommentations.

Logarithm of timestamp will return 35 for next many years thus not suitable for random seed generation :|

PiperOrigin-RevId: 391689896
上级 b272adc4
%% Cell type:markdown id: tags:
##### Copyright 2021 The TensorFlow Federated Authors.
%% Cell type:code id: tags:
```
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
%% Cell type:markdown id: tags:
# Random noise generation in TFF
This tutorial will discuss the recommended best practices for random noise generation in TFF. Random noise generation is an important component of many privacy protection techniques in federated learning algorithms, e.g., differential privacy.
%% Cell type:markdown id: tags:
<table class="tfo-notebook-buttons" align="left">
<td>
<a target="_blank" href="https://www.tensorflow.org/federated/tutorials/random_noise_generation"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
</td>
<td>
<a target="_blank" href="https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/random_noise_generation.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
</td>
<td>
<a target="_blank" href="https://github.com/tensorflow/federated/blob/master/docs/tutorials/random_noise_generation.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
</td>
<td>
<a href="https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/random_noise_generation.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
</td>
</table>
%% Cell type:markdown id: tags:
## Before we begin
First, let us make sure the notebook is connected to a backend that has the relevant components compiled.
%% Cell type:code id: tags:
```
#@test {"skip": true}
!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
import nest_asyncio
nest_asyncio.apply()
```
%% Cell type:code id: tags:
```
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
```
%% Cell type:markdown id: tags:
Run the following "Hello World"
example to make sure the TFF environment is correctly setup. If it doesn't work,
please refer to the [Installation](../install.md) guide for instructions.
%% Cell type:code id: tags:
```
@tff.federated_computation
def hello_world():
return 'Hello, World!'
hello_world()
```
%% Output
b'Hello, World!'
%% Cell type:markdown id: tags:
## Random noise on clients
The need for noise on clients generally falls into two cases: identical noise and i.i.d. noise.
* For identical noise, the recommended pattern is to maintain a seed on the server, broadcast it to clients, and use the `tf.random.stateless`
functions to generate noise.
* For i.i.d. noise, use a tf.random.Generator initialized on the client with from_non_deterministic_state, in keeping with TF's recommendation to avoid the tf.random.\<distribution\> functions.
Client behavior is different from server (doesn't suffer from the pitfalls discussed later) because each client will build their own computation graph and initialize their own default seed.
%% Cell type:markdown id: tags:
### Identical noise on clients
%% Cell type:code id: tags:
```
# Set to use 10 clients.
tff.backends.native.set_local_execution_context(num_clients=10)
@tff.tf_computation
def noise_from_seed(seed):
return tf.random.stateless_normal((), seed=seed)
seed_type_at_server = tff.type_at_server(tff.to_type((tf.int64, [2])))
@tff.federated_computation(seed_type_at_server)
def get_random_min_and_max_deterministic(seed):
# Broadcast seed to all clients.
seed_on_clients = tff.federated_broadcast(seed)
# Clients generate noise from seed deterministicly.
noise_on_clients = tff.federated_map(noise_from_seed, seed_on_clients)
# Aggregate and return the min and max of the values generated on clients.
min = tff.aggregators.federated_min(noise_on_clients)
max = tff.aggregators.federated_max(noise_on_clients)
return min, max
seed = tf.constant([1, 1], dtype=tf.int64)
min, max = get_random_min_and_max_deterministic(seed)
assert min == max
print(f'Seed: {seed.numpy()}. All clients sampled value {min:8.3f}.')
seed += 1
min, max = get_random_min_and_max_deterministic(seed)
assert min == max
print(f'Seed: {seed.numpy()}. All clients sampled value {min:8.3f}.')
```
%% Output
Seed: [1 1]. All clients sampled value 1.665.
Seed: [2 2]. All clients sampled value -0.219.
%% Cell type:markdown id: tags:
### Independent noise on clients
%% Cell type:code id: tags:
```
@tff.tf_computation
def nondeterministic_noise():
gen = tf.random.Generator.from_non_deterministic_state()
return gen.normal(())
@tff.federated_computation(seed_type_at_server)
def get_random_min_and_max_nondeterministic(seed):
noise_on_clients = tff.federated_eval(nondeterministic_noise, tff.CLIENTS)
min = tff.aggregators.federated_min(noise_on_clients)
max = tff.aggregators.federated_max(noise_on_clients)
return min, max
min, max = get_random_min_and_max_nondeterministic(seed)
assert min != max
print(f'Values differ across clients. {min:8.3f},{max:8.3f}.')
new_min, new_max = get_random_min_and_max_nondeterministic(seed)
assert new_min != new_max
assert new_min != min and new_max != max
print(f'Values differ across rounds. {new_min:8.3f},{new_max:8.3f}.')
```
%% Output
Values differ across clients. -1.810, 1.079.
Values differ across rounds. -1.205, 0.851.
%% Cell type:markdown id: tags:
## Random noise on the server
%% Cell type:markdown id: tags:
### Discouraged usage: directly using `tf.random.normal`
TF1.x like APIs `tf.random.normal` for random noise generation are strongly discouraged in TF2 according to the [random noise generation tutorial in TF](https://www.tensorflow.org/guide/random_numbers). Surprising behavior may happen when these APIs are used together with `tf.function` and `tf.random.set_seed`. For example, the following code will generate the same value with each call. This surprising behavior is expected for TF, and explanation can be found in the [documentation of `tf.random.set_seed`](https://www.tensorflow.org/api_docs/python/tf/random/set_seed).
%% Cell type:code id: tags:
```
tf.random.set_seed(1)
@tf.function
def return_one_noise(_):
return tf.random.normal([])
n1=return_one_noise(1)
n2=return_one_noise(2)
assert n1 == n2
print(n1.numpy(), n2.numpy())
```
%% Output
0.3052047 0.3052047
%% Cell type:markdown id: tags:
In TFF, things are slightly different. If we wrap the noise generation as `tff.tf_computation` instead of `tf.function`, non-deterministic random noise will be generated. However, if we run this code snippet multiple times, different set of `(n1, n2)` will be generated each time. There is no easy way to set a global random seed for TFF.
%% Cell type:code id: tags:
```
tf.random.set_seed(1)
@tff.tf_computation
def return_one_noise(_):
return tf.random.normal([])
n1=return_one_noise(1)
n2=return_one_noise(2)
assert n1 != n2
print(n1, n2)
```
%% Output
1.3283143 0.45740178
%% Cell type:markdown id: tags:
Moreover, deterministic noise can be generated in TFF without explicitly setting a seed. The function `return_two_noise` in the following code snippet returns two identical noise values. This is expected behavior because TFF will build computation graph in advance before execution. However, this suggests users have to pay attention on the usage of `tf.random.normal` in TFF.
%% Cell type:code id: tags:
```
@tff.tf_computation
def tff_return_one_noise():
return tf.random.normal([])
@tff.federated_computation
def return_two_noise():
return (tff_return_one_noise(), tff_return_one_noise())
n1, n2=return_two_noise()
assert n1 == n2
print(n1, n2)
```
%% Output
-0.15665223 -0.15665223
%% Cell type:markdown id: tags:
### Usage with care: `tf.random.Generator`
We can use `tf.random.Generator` as suggested in the [TF tutorial](https://www.tensorflow.org/guide/random_numbers).
%% Cell type:code id: tags:
```
@tff.tf_computation
def tff_return_one_noise(i):
g=tf.random.Generator.from_seed(i)
@tf.function
def tf_return_one_noise():
return g.normal([])
return tf_return_one_noise()
@tff.federated_computation
def return_two_noise():
return (tff_return_one_noise(1), tff_return_one_noise(2))
n1, n2 = return_two_noise()
assert n1 != n2
print(n1, n2)
```
%% Output
0.3052047 -0.38260338
%% Cell type:markdown id: tags:
However, users may have to be careful on its usage
* `tf.random.Generator` uses `tf.Variable` to maintain the states for RNG algorithms. In TFF, it is recommended to contruct the generator inside a `tff.tf_computation`; and it is difficult to pass the generator and its state between `tff.tf_computation` functions.
* the previous code snippet also relies on carefully setting seeds in generators. We will get expected but surprising results (deterministic `n1==n2`) if we use `tf.random.Generator.from_non_deterministic_state()` instead.
In general, TFF prefers functional operations and we will showcase the usage of `tf.random.stateless_*` functions in the following sections.
%% Cell type:markdown id: tags:
In TFF for federated learning, we often work with nested structures instead of scalars and the previous code snippet can be naturally extended to nested structures.
%% Cell type:code id: tags:
```
@tff.tf_computation
def tff_return_one_noise(i):
g=tf.random.Generator.from_seed(i)
weights = [
tf.ones([2, 2], dtype=tf.float32),
tf.constant([2], dtype=tf.float32)
]
@tf.function
def tf_return_one_noise():
return tf.nest.map_structure(lambda x: g.normal(tf.shape(x)), weights)
return tf_return_one_noise()
@tff.federated_computation
def return_two_noise():
return (tff_return_one_noise(1), tff_return_one_noise(2))
n1, n2 = return_two_noise()
assert n1[1] != n2[1]
print('n1', n1)
print('n2', n2)
```
%% Output
n1 [array([[0.3052047 , 0.5671378 ],
[0.41852272, 0.2326421 ]], dtype=float32), array([1.1675092], dtype=float32)]
n2 [array([[-0.38260338, -0.47804865],
[-0.5187485 , -1.8471988 ]], dtype=float32), array([-0.77835274], dtype=float32)]
%% Cell type:markdown id: tags:
### Recommended usage: `tf.random.stateless_*` with a helper
%% Cell type:markdown id: tags:
A general recommendation in TFF is to use the functional `tf.random.stateless_*` functions for random noise generation. These functions take `seed` as an explicit input argument to generate random noise. We first define a helper class to maintain the seed as pseudo state. The helper `RandomSeedGenerator` has functional operators in a state-in-state-out fashion. It is reasonable to use a counter as pseudo state for `tf.random.stateless_*` as these functions scramble the seed before using it to make noises generated by correlated seeds statistically uncorrelated.
A general recommendation in TFF is to use the functional `tf.random.stateless_*` functions for random noise generation. These functions take `seed` (a Tensor with shape `[2]` or a `tuple` of two scalar tensors) as an explicit input argument to generate random noise. We first define a helper class to maintain the seed as pseudo state. The helper `RandomSeedGenerator` has functional operators in a state-in-state-out fashion. It is reasonable to use a counter as pseudo state for `tf.random.stateless_*` as these functions [scramble](https://github.com/tensorflow/tensorflow/blob/919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/core/kernels/stateless_random_ops.cc#L50-L64) the seed before using it to make noises generated by correlated seeds statistically uncorrelated.
%% Cell type:code id: tags:
```
def timestamp_seed():
# tf.timestamp returns microseconds as decimal places, thus scaling by 1e6.
return tf.math.cast(tf.timestamp() * 1e6, tf.int64)
class RandomSeedGenerator():
def initialize(self, seed=None):
if seed is None:
return tf.cast(tf.stack(
[tf.math.floor(tf.timestamp()*1e6),
tf.math.floor(tf.math.log(tf.timestamp()*1e6))]), dtype=tf.int64)
return tf.stack([timestamp_seed(), 0])
else:
return tf.constant(self.seed, dtype=tf.int64, shape=(2,))
def next(self, state):
return state + 1
return state + tf.constant([0, 1], tf.int64)
def structure_next(self, state, nest_structure):
"Returns seed in nexted structure and the next state seed."
"Returns seed in nested structure and the next state seed."
flat_structure = tf.nest.flatten(nest_structure)
flat_seeds = [state+i for i in range(len(flat_structure))]
flat_seeds = [state + tf.constant([0, i], tf.int64) for
i in range(len(flat_structure))]
nest_seeds = tf.nest.pack_sequence_as(nest_structure, flat_seeds)
return nest_seeds, flat_seeds[-1] + 1
return nest_seeds, flat_seeds[-1] + tf.constant([0, 1], tf.int64)
```
%% Cell type:markdown id: tags:
Now let us use the helper class and `tf.random.stateless_normal` to generate (nested structure of) random noise in TFF. The following code snippet looks a lot like a TFF iterative process, see [simple_fedavg](https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/examples/simple_fedavg/simple_fedavg_tff.py) as an example of expressing federated learning algorithm as TFF iterative process. The pseudo seed state here for random noise generation is `tf.Tensor` that can be easily transported in TFF and TF functions.
%% Cell type:code id: tags:
```
@tff.tf_computation
def tff_return_one_noise(seed_state):
g=RandomSeedGenerator()
weights = [
tf.ones([2, 2], dtype=tf.float32),
tf.constant([2], dtype=tf.float32)
]
@tf.function
def tf_return_one_noise():
nest_seeds, updated_state = g.structure_next(seed_state, weights)
nest_noise = tf.nest.map_structure(lambda x,s: tf.random.stateless_normal(
shape=tf.shape(x), seed=s), weights, nest_seeds)
return nest_noise, updated_state
return tf_return_one_noise()
@tff.tf_computation
def tff_init_state():
g=RandomSeedGenerator()
return g.initialize()
@tff.federated_computation
def return_two_noise():
seed_state = tff_init_state()
n1, seed_state = tff_return_one_noise(seed_state)
n2, seed_state = tff_return_one_noise(seed_state)
return (n1, n2)
n1, n2 = return_two_noise()
assert n1[1] != n2[1]
print('n1', n1)
print('n2', n2)
```
%% Output
n1 [array([[-0.21598858, -0.30700883],
[ 0.7562299 , -0.21218438]], dtype=float32), array([-1.0359321], dtype=float32)]
n2 [array([[ 1.0722181 , 0.81287116],
[-0.7140338 , 0.5896157 ]], dtype=float32), array([0.44190162], dtype=float32)]
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册