提交 032d4713 编辑于 作者: Michael Reneer's avatar Michael Reneer 提交者: tensorflow-copybara
浏览文件

Remove `run_simulation_with_callbacks` from the public API.

This API was adding complexity and it was difficult to reason about fault tolerance and initialization. This change is one step in reducing this complexity and make it easier to migrate to the Federated Program API.

PiperOrigin-RevId: 413486363
上级 be2832d0
......@@ -31,7 +31,6 @@ from tensorflow_federated.python.simulation.training_loop import EVALUATION_TIME
from tensorflow_federated.python.simulation.training_loop import ROUND_NUMBER_KEY
from tensorflow_federated.python.simulation.training_loop import ROUND_TIME_KEY
from tensorflow_federated.python.simulation.training_loop import run_simulation
from tensorflow_federated.python.simulation.training_loop import run_simulation_with_callbacks
from tensorflow_federated.python.simulation.training_loop import run_stateless_simulation
from tensorflow_federated.python.simulation.training_loop import run_training_process
from tensorflow_federated.python.simulation.training_loop import TRAINING_TIME_KEY
......
......@@ -287,12 +287,12 @@ def run_simulation(
metrics_managers, validation_fn)
on_round_end = _create_on_round_end_fn(file_checkpoint_manager,
metrics_managers, validation_fn)
return run_simulation_with_callbacks(process, client_selection_fn,
total_rounds, on_loop_start,
on_round_end)
return _run_simulation_with_callbacks(process, client_selection_fn,
total_rounds, on_loop_start,
on_round_end)
def run_simulation_with_callbacks(
def _run_simulation_with_callbacks(
process: iterative_process.IterativeProcess,
client_selection_fn: Callable[[int], Any],
total_rounds: int,
......
......@@ -365,7 +365,7 @@ class CreateOnRoundEndTest(absltest.TestCase):
class RunSimulationTest(parameterized.TestCase):
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop.run_simulation_with_callbacks')
'training_loop._run_simulation_with_callbacks')
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop._create_on_round_end_fn')
@mock.patch('tensorflow_federated.python.simulation.'
......@@ -398,7 +398,7 @@ class RunSimulationTest(parameterized.TestCase):
('optional_inputs_7', 'arg1', 'arg2', 'arg3'),
)
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop.run_simulation_with_callbacks')
'training_loop._run_simulation_with_callbacks')
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop._create_on_round_end_fn')
@mock.patch('tensorflow_federated.python.simulation.'
......@@ -438,7 +438,7 @@ class RunSimulationTest(parameterized.TestCase):
('optional_inputs_7', 'arg1', 'arg2', 'arg3'),
)
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop.run_simulation_with_callbacks')
'training_loop._run_simulation_with_callbacks')
@mock.patch('tensorflow_federated.python.simulation.'
'training_loop._create_on_round_end_fn')
@mock.patch('tensorflow_federated.python.simulation.'
......@@ -484,8 +484,8 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
process = mock.create_autospec(iterative_process.IterativeProcess)
process.next.return_value = ('0', {})
client_selection_fn = mock.MagicMock()
training_loop.run_simulation_with_callbacks(process, client_selection_fn,
total_rounds)
training_loop._run_simulation_with_callbacks(process, client_selection_fn,
total_rounds)
self.assertEqual(process.next.call_count, total_rounds)
self.assertEqual(client_selection_fn.call_count, total_rounds)
......@@ -499,8 +499,8 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
process = mock.create_autospec(iterative_process.IterativeProcess)
process.next.return_value = ('0', {})
client_selection_fn = mock.MagicMock()
training_loop.run_simulation_with_callbacks(process, client_selection_fn,
total_rounds)
training_loop._run_simulation_with_callbacks(process, client_selection_fn,
total_rounds)
expected_calls = [mock.call(i) for i in range(1, total_rounds + 1)]
self.assertEqual(expected_calls, client_selection_fn.mock_calls)
......@@ -519,7 +519,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
client_selection_fn = mock.MagicMock()
on_round_end = mock.MagicMock()
on_round_end.return_value = (3.0, {'validation/metric': 5})
training_loop.run_simulation_with_callbacks(
training_loop._run_simulation_with_callbacks(
process, client_selection_fn, total_rounds, on_round_end=on_round_end)
for i in range(1, total_rounds + 1):
round_end_call_args = on_round_end.call_args_list[i - 1][0]
......@@ -544,7 +544,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
client_selection_fn = mock.MagicMock()
on_loop_start = mock.MagicMock()
on_loop_start.return_value = (0, 0)
training_loop.run_simulation_with_callbacks(
training_loop._run_simulation_with_callbacks(
process, client_selection_fn, total_rounds, on_loop_start=on_loop_start)
on_loop_start.assert_called_once_with(initialize_return_value)
......@@ -560,7 +560,7 @@ class RunSimulationWithCallbacksTest(parameterized.TestCase):
# We use `on_round_end` to pass the metrics as an output
on_round_end = mock.MagicMock()
on_round_end.return_value = ((), {})
training_loop.run_simulation_with_callbacks(
training_loop._run_simulation_with_callbacks(
process, client_selection_fn, 1, on_round_end=on_round_end)
expected_metrics_passed_to_round_end = {
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册