Commit 032d4713 authored by Michael Reneer's avatar Michael Reneer Committed by tensorflow-copybara
Browse files

Remove `run_simulation_with_callbacks` from the public API.

This API was adding complexity and it was difficult to reason about fault tolerance and initialization. This change is one step in reducing this complexity and make it easier to migrate to the Federated Program API.

PiperOrigin-RevId: 413486363
parent be2832d0
......@@ -31,7 +31,6 @@ from tensorflow_federated.python.simulation.training_loop import EVALUATION_TIME
from tensorflow_federated.python.simulation.training_loop import ROUND_NUMBER_KEY
from tensorflow_federated.python.simulation.training_loop import ROUND_TIME_KEY
from tensorflow_federated.python.simulation.training_loop import run_simulation
from tensorflow_federated.python.simulation.training_loop import run_simulation_with_callbacks
from tensorflow_federated.python.simulation.training_loop import run_stateless_simulation
from tensorflow_federated.python.simulation.training_loop import run_training_process
from tensorflow_federated.python.simulation.training_loop import TRAINING_TIME_KEY
......
......@@ -287,12 +287,12 @@ def run_simulation(
metrics_managers, validation_fn)
on_round_end = _create_on_round_end_fn(file_checkpoint_manager,
metrics_managers, validation_fn)
return run_simulation_with_callbacks(process, client_selection_fn,
return _run_simulation_with_callbacks(process, client_selection_fn,
total_rounds, on_loop_start,
on_round_end)
def run_simulation_with_callbacks(
def _run_simulation_with_callbacks(
process: iterative_process.IterativeProcess,
client_selection_fn: Callable[[int], Any],
total_rounds: int,
......
......@@ -365,7 +365,7 @@ class CreateOnRoundEndTest(absltest.TestCase):
class RunSimulationTest(parameterized.TestCase):
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop.run_simulation_with_callbacks')
'training_loop._run_simulation_with_callbacks')
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop._create_on_round_end_fn')
@mock.patch('tensorflow_federated.python.simulation.'
......@@ -398,7 +398,7 @@ class RunSimulationTest(parameterized.TestCase):
('optional_inputs_7', 'arg1', 'arg2', 'arg3'),
)
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop.run_simulation_with_callbacks')
'training_loop._run_simulation_with_callbacks')
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop._create_on_round_end_fn')
@mock.patch('tensorflow_federated.python.simulation.'
......@@ -438,7 +438,7 @@ class RunSimulationTest(parameterized.TestCase):
('optional_inputs_7', 'arg1', 'arg2', 'arg3'),
)
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop.run_simulation_with_callbacks')
'training_loop._run_simulation_with_callbacks')
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop._create_on_round_end_fn')
@mock.patch('tensorflow_federated.python.simulation.'
......@@ -484,7 +484,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
process = mock.create_autospec(iterative_process.IterativeProcess)
process.next.return_value = ('0', {})
client_selection_fn = mock.MagicMock()
training_loop.run_simulation_with_callbacks(process, client_selection_fn,
training_loop._run_simulation_with_callbacks(process, client_selection_fn,
total_rounds)
self.assertEqual(process.next.call_count, total_rounds)
self.assertEqual(client_selection_fn.call_count, total_rounds)
......@@ -499,7 +499,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
process = mock.create_autospec(iterative_process.IterativeProcess)
process.next.return_value = ('0', {})
client_selection_fn = mock.MagicMock()
training_loop.run_simulation_with_callbacks(process, client_selection_fn,
training_loop._run_simulation_with_callbacks(process, client_selection_fn,
total_rounds)
expected_calls = [mock.call(i) for i in range(1, total_rounds + 1)]
self.assertEqual(expected_calls, client_selection_fn.mock_calls)
......@@ -519,7 +519,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
client_selection_fn = mock.MagicMock()
on_round_end = mock.MagicMock()
on_round_end.return_value = (3.0, {'validation/metric': 5})
training_loop.run_simulation_with_callbacks(
training_loop._run_simulation_with_callbacks(
process, client_selection_fn, total_rounds, on_round_end=on_round_end)
for i in range(1, total_rounds + 1):
round_end_call_args = on_round_end.call_args_list[i - 1][0]
......@@ -544,7 +544,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
client_selection_fn = mock.MagicMock()
on_loop_start = mock.MagicMock()
on_loop_start.return_value = (0, 0)
training_loop.run_simulation_with_callbacks(
training_loop._run_simulation_with_callbacks(
process, client_selection_fn, total_rounds, on_loop_start=on_loop_start)
on_loop_start.assert_called_once_with(initialize_return_value)
......@@ -560,7 +560,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
# We use `on_round_end` to pass the metrics as an output
on_round_end = mock.MagicMock()
on_round_end.return_value = ((), {})
training_loop.run_simulation_with_callbacks(
training_loop._run_simulation_with_callbacks(
process, client_selection_fn, 1, on_round_end=on_round_end)
expected_metrics_passed_to_round_end = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment