Skip to content
Snippets Groups Projects
Commit f272e5f2 authored by Zachary Garrett's avatar Zachary Garrett Committed by tensorflow-copybara
Browse files

Update federated_sgd_test unittests to reflect improvements made in #4ed574ab

PiperOrigin-RevId: 322178848
parent 4ed574ab
No related branches found
No related tags found
No related merge requests found
......@@ -109,7 +109,6 @@ py_test(
":keras_utils",
":model_examples",
":model_utils",
"//tensorflow_federated/python/common_libs:anonymous_tuple",
"//tensorflow_federated/python/common_libs:test",
],
)
......
......@@ -18,7 +18,6 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.common_libs import anonymous_tuple
from tensorflow_federated.python.common_libs import test
from tensorflow_federated.python.learning import federated_sgd
from tensorflow_federated.python.learning import keras_utils
......@@ -126,10 +125,12 @@ class FederatedSGDTffTest(test.TestCase, parameterized.TestCase):
for _ in range(num_iterations):
server_state, metric_outputs = iterative_process.next(
server_state, federated_ds)
self.assertEqual(metric_outputs.train.num_examples,
train_metrics = metric_outputs['train']
self.assertEqual(train_metrics['num_examples'],
num_iterations * len(federated_ds))
self.assertLess(metric_outputs.train.loss, prev_loss)
prev_loss = metric_outputs.train.loss
loss = train_metrics['loss']
self.assertLess(loss, prev_loss)
prev_loss = loss
@parameterized.named_parameters([
('functional_model',
......@@ -186,12 +187,11 @@ class FederatedSGDTffTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(
list(first_state.model.trainable), [[[0.0], [0.0]], 0.0])
self.assertEqual(
anonymous_tuple.name_list(metric_outputs),
['broadcast', 'aggregation', 'train'])
self.assertEmpty(metric_outputs.broadcast)
self.assertEmpty(metric_outputs.aggregation)
self.assertEqual(metric_outputs.train.num_examples, 0)
self.assertTrue(tf.math.is_nan(metric_outputs.train.loss))
list(metric_outputs.keys()), ['broadcast', 'aggregation', 'train'])
self.assertEmpty(metric_outputs['broadcast'])
self.assertEmpty(metric_outputs['aggregation'])
self.assertEqual(metric_outputs['train']['num_examples'], 0)
self.assertTrue(tf.math.is_nan(metric_outputs['train']['loss']))
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment