Skip to content
Snippets Groups Projects
Commit e7376566 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by tensorflow-copybara
Browse files

Factor out fixed sampling function so it can be used for arbitrary sequences

PiperOrigin-RevId: 318882349
parent 3200e7ff
No related branches found
No related tags found
No related merge requests found
......@@ -107,6 +107,7 @@ py_library(
name = "training_utils",
srcs = ["training_utils.py"],
srcs_version = "PY3",
deps = ["//tensorflow_federated"],
)
py_test(
......
......@@ -15,10 +15,12 @@
import collections
import functools
from typing import Any, Callable, Optional, Sequence, Union
from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
# Settings for a multiplicative linear congruential generator (aka Lehmer
......@@ -119,28 +121,31 @@ def build_evaluate_fn(eval_dataset, model_builder, loss_builder,
return evaluate_fn
def build_client_datasets_fn(train_dataset,
train_clients_per_round,
random_seed=None):
"""Builds the function for generating client datasets at each round.
The function samples a number of clients and returns their datasets.
def build_sample_fn(
a: Union[Sequence[Any], int],
size: int,
replace: bool = False,
random_seed: Optional[int] = None) -> Callable[[int], np.ndarray]:
"""Builds the function for sampling from the input iterator at each round.
Args:
train_dataset: A `tff.simulation.ClientData` object.
train_clients_per_round: The number of client participants in each round.
random_seed: If random_seed is set as an integer, then we use it as a
random seed for which clients are sampled at each round. In this case,
we set a random seed before sampling clients according
to a multiplicative linear congruential generator (aka Lehmer generator,
see 'The Art of Computer Programming, Vol. 3' by Donald Knuth for
reference). This does not affect model initialization,
shuffling, or other such aspects of the federated training process.
a: A 1-D array-like sequence or int that satisfies np.random.choice.
size: The number of samples to return each round.
replace: A boolean indicating whether the sampling is done with replacement
(True) or without replacement (False).
random_seed: If random_seed is set as an integer, then we use it as a random
seed for which clients are sampled at each round. In this case, we set a
random seed before sampling clients according to a multiplicative linear
congruential generator (aka Lehmer generator, see 'The Art of Computer
Programming, Vol. 3' by Donald Knuth for reference). This does not affect
model initialization, shuffling, or other such aspects of the federated
training process. Note that this will alter the global numpy random seed.
Returns:
A function which returns a list of `tff.simulation.ClientData` objects at a
A function which returns a list of elements from the input iterator at a
given round round_num.
"""
if isinstance(random_seed, int):
np.random.seed(random_seed)
mlcg_start = np.random.randint(1, MLCG_MODULUS - 1)
......@@ -149,14 +154,48 @@ def build_client_datasets_fn(train_dataset,
return pow(MLCG_MULTIPLIER, round_num,
MLCG_MODULUS) * mlcg_start % MLCG_MODULUS
def client_datasets(round_num, random_seed):
def sample(round_num, random_seed):
if isinstance(random_seed, int):
np.random.seed(get_pseudo_random_int(round_num))
sampled_clients = np.random.choice(
train_dataset.client_ids, size=train_clients_per_round, replace=False)
return np.random.choice(a, size=size, replace=replace)
return functools.partial(sample, random_seed=random_seed)
def build_client_datasets_fn(train_dataset: tff.simulation.ClientData,
train_clients_per_round: int,
random_seed: Optional[int] = None):
"""Builds the function for generating client datasets at each round.
The function samples a number of clients (without replacement within a given
round, but with replacement across rounds) and returns their datasets.
Args:
train_dataset: A `tff.simulation.ClientData` object.
train_clients_per_round: The number of client participants in each round.
random_seed: If random_seed is set as an integer, then we use it as a random
seed for which clients are sampled at each round. In this case, we set a
random seed before sampling clients according to a multiplicative linear
congruential generator (aka Lehmer generator, see 'The Art of Computer
Programming, Vol. 3' by Donald Knuth for reference). This does not affect
model initialization, shuffling, or other such aspects of the federated
training process. Note that this will alter the global numpy random seed.
Returns:
A function which returns a list of `tff.simulation.ClientData` objects at a
given round round_num.
"""
sample_clients_fn = build_sample_fn(
train_dataset.client_ids,
size=train_clients_per_round,
replace=False,
random_seed=random_seed)
def client_datasets(round_num):
sampled_clients = sample_clients_fn(round_num)
return [
train_dataset.create_tf_dataset_for_client(client)
for client in sampled_clients
]
return functools.partial(client_datasets, random_seed=random_seed)
return client_datasets
......@@ -15,6 +15,7 @@
import collections
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
......@@ -50,7 +51,70 @@ def create_tf_dataset_for_client(client_id, batch_data=True):
return dataset
class TrainingUtilsTest(tf.test.TestCase):
class TrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
{
'testcase_name': '_int_no_replace',
'a': 100,
'replace': False
}, {
'testcase_name': '_int_replace',
'a': 5,
'replace': True
}, {
'testcase_name': '_sequence_no_replace',
'a': [str(i) for i in range(100)],
'replace': False
}, {
'testcase_name': '_sequence_replace',
'a': [str(i) for i in range(5)],
'replace': True
})
def test_build_sample_fn_with_random_seed(self, a, replace):
size = 10
random_seed = 1
round_num = 5
sample_fn_1 = training_utils.build_sample_fn(
a, size, replace=replace, random_seed=random_seed)
sample_1 = sample_fn_1(round_num)
sample_fn_2 = training_utils.build_sample_fn(
a, size, replace=replace, random_seed=random_seed)
sample_2 = sample_fn_2(round_num)
self.assertAllEqual(sample_1, sample_2)
@parameterized.named_parameters(
{
'testcase_name': '_int_no_replace',
'a': 100,
'replace': False
}, {
'testcase_name': '_int_replace',
'a': 5,
'replace': True
}, {
'testcase_name': '_sequence_no_replace',
'a': [str(i) for i in range(100)],
'replace': False
}, {
'testcase_name': '_sequence_replace',
'a': [str(i) for i in range(5)],
'replace': True
})
def test_build_sample_fn_without_random_seed(self, a, replace):
size = 10
round_num = 5
sample_fn_1 = training_utils.build_sample_fn(a, size, replace=replace)
sample_1 = sample_fn_1(round_num)
sample_fn_2 = training_utils.build_sample_fn(a, size, replace=replace)
sample_2 = sample_fn_2(round_num)
self.assertNotAllEqual(sample_1, sample_2)
def test_build_client_datasets_fn(self):
tff_dataset = tff.simulation.client_data.ConcreteClientData(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment