Commit 594e012e authored by Zachary Garrett, committed by tensorflow-copybara

Move learning integration tests into their own package under python/tests/.

PiperOrigin-RevId: 392785924
Parent 4daa2df3
@@ -81,13 +81,10 @@ py_cpu_gpu_test(
deps = [
":client_weight_lib",
":federated_averaging",
":keras_utils",
":model_examples",
":model_update_aggregator",
":model_utils",
"//tensorflow_federated/python/common_libs:test_utils",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/test:execution_contexts",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
@@ -162,14 +159,12 @@ py_cpu_gpu_test(
deps = [
":client_weight_lib",
":federated_sgd",
":keras_utils",
":model_examples",
":model_update_aggregator",
":model_utils",
"//tensorflow_federated/python/common_libs:test_utils",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
@@ -11,20 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for local client training implemented in ClientFedAvg.
Integration tests that include server averaging and alternative tff.aggregator
factories are in found in
tensorflow_federated/python/tests/federated_averaging_integration_test.py.
"""
import collections
import itertools
from unittest import mock
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.common_libs import test_utils
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.test import execution_contexts
from tensorflow_federated.python.learning import client_weight_lib
from tensorflow_federated.python.learning import federated_averaging
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning import model_utils
@@ -32,26 +36,7 @@ from tensorflow_federated.python.learning.framework import dataset_reduce
from tensorflow_federated.python.learning.optimizers import sgdm
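# The two helpers below parameterize tests over optimizer kinds: the TFF
# optimizer built by `sgdm.build_sgdm` can be passed directly, while the Keras
# optimizer is wrapped in a no-arg constructor so that it (and its variables)
# is created inside the TFF computation rather than at test-collection time.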
def _get_tff_optimizer(learning_rate=0.1):
return sgdm.build_sgdm(learning_rate=learning_rate)
def _get_keras_optimizer_fn(learning_rate=0.1):
return lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate)
class NumExamplesCounter(tf.keras.metrics.Sum):
"""A `tf.keras.metrics.Metric` that counts the number of examples seen."""
def __init__(self, name='num_examples', dtype=tf.int64): # pylint: disable=useless-super-delegation
super().__init__(name, dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
return super().update_state(tf.shape(y_pred)[0], sample_weight)
class FederatedAveragingClientWithModelTest(test_case.TestCase,
parameterized.TestCase):
class FederatedAveragingClientTest(test_case.TestCase, parameterized.TestCase):
"""Tests of ClientFedAvg that use a common model and data."""
def create_dataset(self):
@@ -86,7 +71,6 @@ class FederatedAveragingClientWithModelTest(test_case.TestCase,
'clipvalue': 0.1
}, 0.02),
)
@test_utils.skip_test_for_multi_gpu
def test_client_tf(self, weighted, simulation, optimizer_kwargs,
expected_norm):
model = self.create_model()
@@ -146,7 +130,6 @@ class FederatedAveragingClientWithModelTest(test_case.TestCase,
dataset_reduce,
'_dataset_reduce_fn',
wraps=dataset_reduce._dataset_reduce_fn)
@test_utils.skip_test_for_multi_gpu
def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
model = self.create_model()
dataset = self.create_dataset()
@@ -161,232 +144,36 @@ class FederatedAveragingClientWithModelTest(test_case.TestCase,
mock_method.assert_called()
class FederatedAveragingModelTffTest(test_case.TestCase,
parameterized.TestCase):
def _run_test(self, process, *, datasets, expected_num_examples):
state = process.initialize()
prev_loss = np.inf
aggregation_metrics = collections.OrderedDict(mean_value=(), mean_weight=())
for _ in range(3):
state, metric_outputs = process.next(state, datasets)
self.assertEqual(
list(metric_outputs.keys()),
['broadcast', 'aggregation', 'train', 'stat'])
self.assertEmpty(metric_outputs['broadcast'])
self.assertEqual(aggregation_metrics, metric_outputs['aggregation'])
train_metrics = metric_outputs['train']
self.assertEqual(train_metrics['num_examples'], expected_num_examples)
self.assertLess(train_metrics['loss'], prev_loss)
prev_loss = train_metrics['loss']
@parameterized.named_parameters([
('unweighted_keras_opt', client_weight_lib.ClientWeighting.UNIFORM,
_get_keras_optimizer_fn),
('example_weighted_keras_opt',
client_weight_lib.ClientWeighting.NUM_EXAMPLES, _get_keras_optimizer_fn),
('custom_weighted_keras_opt', lambda _: tf.constant(1.5),
_get_keras_optimizer_fn),
('unweighted_tff_opt', client_weight_lib.ClientWeighting.UNIFORM,
_get_tff_optimizer),
('example_weighted_tff_opt',
client_weight_lib.ClientWeighting.NUM_EXAMPLES, _get_tff_optimizer),
('custom_weighted_tff_opt', lambda _: tf.constant(1.5),
_get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_basic_orchestration_execute(self, client_weighting,
client_optimizer):
iterative_process = federated_averaging.build_federated_averaging_process(
model_fn=model_examples.LinearRegression,
client_optimizer_fn=client_optimizer(),
client_weighting=client_weighting)
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0], [3.0, 4.0]],
y=[[5.0], [6.0]],
)).batch(2)
num_clients = 3
self._run_test(
iterative_process,
datasets=[ds] * num_clients,
expected_num_examples=2 * num_clients)
@parameterized.named_parameters([
('functional_model_keras_opt',
model_examples.build_linear_regression_keras_functional_model,
_get_keras_optimizer_fn),
('sequential_model_keras_opt',
model_examples.build_linear_regression_keras_sequential_model,
_get_keras_optimizer_fn),
('functional_model_tff_opt',
model_examples.build_linear_regression_keras_functional_model,
_get_tff_optimizer),
('sequential_model_tff_opt',
model_examples.build_linear_regression_keras_sequential_model,
_get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_orchestration_execute_from_keras(self, build_keras_model_fn,
client_optimizer):
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0], [3.0, 4.0]],
y=[[5.0], [6.0]],
)).batch(2)
def model_fn():
keras_model = build_keras_model_fn(feature_dims=2)
return keras_utils.from_keras_model(
keras_model,
loss=tf.keras.losses.MeanSquaredError(),
input_spec=ds.element_spec,
metrics=[NumExamplesCounter()])
iterative_process = federated_averaging.build_federated_averaging_process(
model_fn=model_fn,
client_optimizer_fn=client_optimizer(learning_rate=0.01))
num_clients = 3
self._run_test(
iterative_process,
datasets=[ds] * num_clients,
expected_num_examples=2 * num_clients)
@parameterized.named_parameters([
('keras_opt', _get_keras_optimizer_fn),
('tff_opt', _get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_orchestration_execute_from_keras_with_lookup(self, client_optimizer):
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[['R'], ['G'], ['B']], y=[[1.0], [2.0], [3.0]])).batch(2)
def model_fn():
keras_model = model_examples.build_lookup_table_keras_model()
return keras_utils.from_keras_model(
keras_model,
loss=tf.keras.losses.MeanSquaredError(),
input_spec=ds.element_spec,
metrics=[NumExamplesCounter()])
iterative_process = federated_averaging.build_federated_averaging_process(
model_fn=model_fn, client_optimizer_fn=client_optimizer())
num_clients = 3
self._run_test(
iterative_process,
datasets=[ds] * num_clients,
expected_num_examples=3 * num_clients)
@parameterized.named_parameters([
('keras_opt', _get_keras_optimizer_fn),
('tff_opt', _get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_execute_empty_data(self, client_optimizer):
iterative_process = federated_averaging.build_federated_averaging_process(
model_fn=model_examples.LinearRegression,
client_optimizer_fn=client_optimizer())
# Results in empty dataset with correct types and shapes.
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0]],
y=[[5.0]],
)).batch(
5, drop_remainder=True)
server_state = iterative_process.initialize()
first_state, metric_outputs = iterative_process.next(server_state, [ds] * 2)
self.assertAllClose(
list(first_state.model.trainable), [[[0.0], [0.0]], 0.0])
self.assertEqual(metric_outputs['train']['num_examples'], 0)
self.assertTrue(tf.math.is_nan(metric_outputs['train']['loss']))
@parameterized.named_parameters([
('keras_opt', _get_keras_optimizer_fn),
('tff_opt', _get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_get_model_weights(self, client_optimizer):
iterative_process = federated_averaging.build_federated_averaging_process(
model_fn=model_examples.LinearRegression,
client_optimizer_fn=client_optimizer())
num_clients = 3
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0], [3.0, 4.0]],
y=[[5.0], [6.0]],
)).batch(2)
datasets = [ds] * num_clients
state = iterative_process.initialize()
self.assertIsInstance(
iterative_process.get_model_weights(state), model_utils.ModelWeights)
self.assertAllClose(state.model.trainable,
iterative_process.get_model_weights(state).trainable)
for _ in range(3):
state, _ = iterative_process.next(state, datasets)
self.assertIsInstance(
iterative_process.get_model_weights(state), model_utils.ModelWeights)
self.assertAllClose(state.model.trainable,
iterative_process.get_model_weights(state).trainable)
@parameterized.named_parameters([
('robust_tff_opt', model_update_aggregator.robust_aggregator,
_get_tff_optimizer),
('robust_keras_opt', model_update_aggregator.robust_aggregator,
_get_keras_optimizer_fn),
('dp_tff_opt', lambda: model_update_aggregator.dp_aggregator(1e-3, 3),
_get_tff_optimizer),
('dp_keras_opt', lambda: model_update_aggregator.dp_aggregator(1e-3, 3),
_get_keras_optimizer_fn),
('compression_tff_opt', model_update_aggregator.compression_aggregator,
_get_tff_optimizer),
('compression_keras_opt', model_update_aggregator.compression_aggregator,
_get_keras_optimizer_fn),
('secure_tff', model_update_aggregator.secure_aggregator,
_get_tff_optimizer),
('secure_keras_opt', model_update_aggregator.secure_aggregator,
_get_keras_optimizer_fn),
])
@test_utils.skip_test_for_multi_gpu
def test_recommended_aggregations_execute(self, default_aggregation,
client_optimizer):
process = federated_averaging.build_federated_averaging_process(
model_fn=model_examples.LinearRegression,
client_optimizer_fn=client_optimizer(),
model_update_aggregation_factory=default_aggregation())
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0], [3.0, 4.0]],
y=[[5.0], [6.0]],
)).batch(2)
num_clients = 3
state = process.initialize()
state, metrics = process.next(state, [ds] * num_clients)
self.assertNotEmpty(metrics['aggregation'])
def test_construction_calls_model_fn(self):
class FederatedAveragingTest(test_case.TestCase, parameterized.TestCase):
"""Tests construction of FedAvg training process."""
# pylint: disable=g-complex-comprehension
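# The decorator below takes the cross product of optimizer and aggregator
# parameterizations; each generated test name joins the per-factor names with
# underscores (e.g. 'keras_optimizer_robust_aggregator').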
@parameterized.named_parameters((
'_'.join(name for name, _ in named_params),
*(param for _, param in named_params),
) for named_params in itertools.product([
('keras_optimizer', tf.keras.optimizers.SGD),
('tff_optimizer', sgdm.build_sgdm(learning_rate=0.1)),
], [
('robust_aggregator', model_update_aggregator.robust_aggregator),
('dp_aggregator', lambda: model_update_aggregator.dp_aggregator(1e-3, 3)),
('compression_aggregator',
model_update_aggregator.compression_aggregator),
('secure_aggregator', model_update_aggregator.secure_aggregator),
]))
# pylint: enable=g-complex-comprehension
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
# Assert that building the process does not call `model_fn` too many
# times. `model_fn` can potentially be expensive (loading weights,
# processing, etc.).
mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
federated_averaging.build_federated_averaging_process(
model_fn=mock_model_fn, client_optimizer_fn=tf.keras.optimizers.SGD)
model_fn=mock_model_fn,
client_optimizer_fn=optimizer_fn,
model_update_aggregation_factory=aggregation_factory())
# TODO(b/186451541): reduce the number of calls to model_fn.
self.assertEqual(mock_model_fn.call_count, 3)
if __name__ == '__main__':
execution_contexts.set_test_execution_context()
test_case.main()
@@ -11,6 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for local client training implemented in ClientSgd.
Integration tests that include server averaging and alternative tff.aggregator
factories are in found in
tensorflow_federated/python/tests/federated_sgd_integration_test.py.
"""
import collections
from unittest import mock
@@ -21,22 +27,12 @@ import tensorflow as tf
from tensorflow_federated.python.common_libs import test_utils
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.learning import client_weight_lib
from tensorflow_federated.python.learning import federated_sgd
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.framework import dataset_reduce
from tensorflow_federated.python.learning.optimizers import sgdm
def _get_tff_optimizer(learning_rate=0.1):
return sgdm.build_sgdm(learning_rate=learning_rate)
def _get_keras_optimizer_fn(learning_rate=0.1):
return lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate)
class FederatedSgdTest(test_case.TestCase, parameterized.TestCase):
@@ -137,184 +133,29 @@ class FederatedSgdTest(test_case.TestCase, parameterized.TestCase):
mock_method.assert_called()
class FederatedSGDTffTest(test_case.TestCase, parameterized.TestCase):
@parameterized.named_parameters([
('unweighted_keras_opt', client_weight_lib.ClientWeighting.UNIFORM,
_get_keras_optimizer_fn),
('example_weighted_keras_opt',
client_weight_lib.ClientWeighting.NUM_EXAMPLES, _get_keras_optimizer_fn),
('custom_weighted_keras_opt', lambda _: tf.constant(1.5),
_get_keras_optimizer_fn),
('unweighted_tff_opt', client_weight_lib.ClientWeighting.UNIFORM,
_get_tff_optimizer),
('example_weighted_tff_opt',
client_weight_lib.ClientWeighting.NUM_EXAMPLES, _get_tff_optimizer),
('custom_weighted_tff_opt', lambda _: tf.constant(1.5),
_get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_orchestration_execute(self, client_weighting, server_optimizer):
iterative_process = federated_sgd.build_federated_sgd_process(
model_fn=model_examples.LinearRegression,
server_optimizer_fn=server_optimizer(),
client_weighting=client_weighting)
# Some data points along [x_1 + 2*x_2 + 3 = y], expecting to learn
# kernel = [1, 2], bias = [3].
ds1 = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[0.0, 0.0], [0.0, 1.0]],
y=[[3.0], [5.0]],
)).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0], [3.0, 4.0], [1.0, 0.0], [-1.0, -1.0]],
y=[[8.0], [14.0], [4.00], [0.0]],
)).batch(2)
federated_ds = [ds1, ds2]
server_state = iterative_process.initialize()
prev_loss = np.inf
num_iterations = 3
for _ in range(num_iterations):
server_state, metric_outputs = iterative_process.next(
server_state, federated_ds)
train_metrics = metric_outputs['train']
self.assertEqual(train_metrics['num_examples'],
num_iterations * len(federated_ds))
loss = train_metrics['loss']
self.assertLess(loss, prev_loss)
prev_loss = loss
@parameterized.named_parameters([
('functional_model_keras_opt',
model_examples.build_linear_regression_keras_functional_model,
_get_keras_optimizer_fn),
('sequential_model_keras_opt',
model_examples.build_linear_regression_keras_sequential_model,
_get_keras_optimizer_fn),
('functional_model_tff_opt',
model_examples.build_linear_regression_keras_functional_model,
_get_tff_optimizer),
('sequential_model_tff_opt',
model_examples.build_linear_regression_keras_sequential_model,
_get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_orchestration_execute_from_keras(self, build_keras_model_fn,
server_optimizer):
# Some data points along [x_1 + 2*x_2 + 3 = y], expecting to learn
# kernel = [1, 2], bias = [3].
ds1 = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[0.0, 0.0], [0.0, 1.0]],
y=[[3.0], [5.0]],
)).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0], [3.0, 4.0], [1.0, 0.0], [-1.0, -1.0]],
y=[[8.0], [14.0], [4.00], [0.0]],
)).batch(2)
federated_ds = [ds1, ds2]
def model_fn():
# Note: we don't compile with an optimizer here; FedSGD does not use it.
keras_model = build_keras_model_fn(feature_dims=2)
return keras_utils.from_keras_model(
keras_model,
input_spec=ds1.element_spec,
loss=tf.keras.losses.MeanSquaredError())
iterative_process = federated_sgd.build_federated_sgd_process(
model_fn=model_fn, server_optimizer_fn=server_optimizer())
server_state = iterative_process.initialize()
prev_loss = np.inf
num_iterations = 3
for _ in range(num_iterations):
server_state, metrics = iterative_process.next(server_state, federated_ds)
new_loss = metrics['train']['loss']
self.assertLess(new_loss, prev_loss)
prev_loss = new_loss
@parameterized.named_parameters([
('keras_opt', _get_keras_optimizer_fn),
('tff_opt', _get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_execute_empty_data(self, server_optimizer):
iterative_process = federated_sgd.build_federated_sgd_process(
model_fn=model_examples.LinearRegression,
server_optimizer_fn=server_optimizer())
# Results in empty dataset with correct types and shapes.
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(x=[[1.0, 2.0]], y=[[5.0]])).batch(
5, drop_remainder=True) # No batches of size 5 can be created.
federated_ds = [ds] * 2
server_state = iterative_process.initialize()
first_state, metric_outputs = iterative_process.next(
server_state, federated_ds)
self.assertAllClose(
list(first_state.model.trainable), [[[0.0], [0.0]], 0.0])
self.assertEqual(
list(metric_outputs.keys()),
['broadcast', 'aggregation', 'train', 'stat'])
self.assertEmpty(metric_outputs['broadcast'])
self.assertEqual(metric_outputs['aggregation'],
collections.OrderedDict(mean_value=(), mean_weight=()))
self.assertEqual(metric_outputs['train']['num_examples'], 0)
self.assertTrue(tf.math.is_nan(metric_outputs['train']['loss']))
@parameterized.named_parameters([
('keras_opt', _get_keras_optimizer_fn),
('tff_opt', _get_tff_optimizer),
])
@test_utils.skip_test_for_multi_gpu
def test_get_model_weights(self, server_optimizer):
iterative_process = federated_sgd.build_federated_sgd_process(
model_fn=model_examples.LinearRegression,
server_optimizer_fn=server_optimizer())
num_clients = 3
ds = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict(
x=[[1.0, 2.0], [3.0, 4.0]],
y=[[5.0], [6.0]],
)).batch(2)
datasets = [ds] * num_clients
state = iterative_process.initialize()
self.assertIsInstance(
iterative_process.get_model_weights(state), model_utils.ModelWeights)
self.assertAllClose(state.model.trainable,
iterative_process.get_model_weights(state).trainable)
for _ in range(3):
state, _ = iterative_process.next(state, datasets)
self.assertIsInstance(
iterative_process.get_model_weights(state), model_utils.ModelWeights)
self.assertAllClose(state.model.trainable,
iterative_process.get_model_weights(state).trainable)
class FederatedSGDTest(test_case.TestCase, parameterized.TestCase):
"""Tests construction of FedSGD training process."""
@parameterized.named_parameters([
('keras_opt', _get_keras_optimizer_fn),
('tff_opt', _get_tff_optimizer),
])
def test_construction_calls_model_fn(self, server_optimizer):
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(
('robust_aggregator', model_update_aggregator.robust_aggregator),
('dp_aggregator', lambda: model_update_aggregator.dp_aggregator(1e-3, 3)),
('compression_aggregator',
model_update_aggregator.compression_aggregator),
('secure_aggregator', model_update_aggregator.secure_aggregator),
)