Commit 97135890 authored by Michael Reneer's avatar Michael Reneer Committed by tensorflow-copybara
Browse files

Upstream function to create a set of managers for a simulation to TFF.

* Added function to create a set of managers for a simulation to TFF.
* Fixed a bug in `CSVFileReleaseManager`: if you created an instance with a file that did not exist, the `CSVFileReleaseManager` would create the file for you, but if the folder did not exist, it would not create the folder for you.
* Added test to cover bug in `CSVFileReleaseManager`.

PiperOrigin-RevId: 413759504
parent 7681cec7
......@@ -109,6 +109,9 @@ class CSVFileReleaseManager(release_manager.ReleaseManager):
py_typecheck.check_type(save_mode, CSVSaveMode)
if not file_path:
raise ValueError('Expected `file_path` to not be an empty string.')
file_dir = os.path.dirname(file_path)
if not tf.io.gfile.exists(file_dir):
tf.io.gfile.makedirs(file_dir)
self._file_path = file_path
self._save_mode = save_mode
......
......@@ -49,7 +49,7 @@ def _write_values_to_csv(file_path: os.PathLike, fieldnames: Sequence[str],
class CSVFileReleaseManagerInitTest(parameterized.TestCase):
def test_creates_root_dir(self):
def test_creates_file_path(self):
temp_file = self.create_tempfile()
os.remove(temp_file)
self.assertFalse(os.path.exists(temp_file))
......@@ -58,6 +58,16 @@ class CSVFileReleaseManagerInitTest(parameterized.TestCase):
self.assertTrue(os.path.exists(temp_file))
def test_creates_file_dir(self):
  """Checks that the manager creates `file_path` when its directory is missing."""
  # Create and immediately delete a directory to obtain a path that is known
  # not to exist.
  missing_dir = self.create_tempdir()
  shutil.rmtree(missing_dir)
  self.assertFalse(os.path.exists(missing_dir))

  csv_path = os.path.join(missing_dir, 'a')
  file_release_manager.CSVFileReleaseManager(file_path=csv_path)

  self.assertTrue(os.path.exists(csv_path))
def test_initializes_with_empty_file(self):
temp_file = self.create_tempfile()
_write_values_to_csv(
......
......@@ -26,7 +26,7 @@ from tensorflow_federated.python.program import test_utils
class TensorboardReleaseManagerInitTest(parameterized.TestCase):
def test_creates_root_dir(self):
def test_creates_summary_dir(self):
temp_dir = self.create_tempdir()
summary_dir = os.path.join(temp_dir, 'test')
self.assertFalse(os.path.exists(summary_dir))
......
......@@ -168,8 +168,12 @@ py_library(
deps = [
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/program:file_program_state_manager",
"//tensorflow_federated/python/program:file_release_manager",
"//tensorflow_federated/python/program:logging_release_manager",
"//tensorflow_federated/python/program:program_state_manager",
"//tensorflow_federated/python/program:release_manager",
"//tensorflow_federated/python/program:tensorboard_release_manager",
"//tensorflow_federated/python/simulation:checkpoint_manager",
"//tensorflow_federated/python/simulation:metrics_manager",
],
......@@ -185,6 +189,10 @@ py_test(
":training_loop",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/program:file_program_state_manager",
"//tensorflow_federated/python/program:file_release_manager",
"//tensorflow_federated/python/program:logging_release_manager",
"//tensorflow_federated/python/program:tensorboard_release_manager",
"//tensorflow_federated/python/simulation:checkpoint_manager",
"//tensorflow_federated/python/simulation:metrics_manager",
],
......
......@@ -14,6 +14,7 @@
"""Training loops for iterative process simulations."""
import collections
import os
import pprint
import time
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple
......@@ -22,8 +23,12 @@ from absl import logging
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.program import file_program_state_manager as file_program_state_manager_lib
from tensorflow_federated.python.program import file_release_manager as file_release_manager_lib
from tensorflow_federated.python.program import logging_release_manager as logging_release_manager_lib
from tensorflow_federated.python.program import program_state_manager as program_state_manager_lib
from tensorflow_federated.python.program import release_manager as release_manager_lib
from tensorflow_federated.python.program import tensorboard_release_manager as tensorboard_release_manager_lib
from tensorflow_federated.python.simulation import checkpoint_manager
from tensorflow_federated.python.simulation import metrics_manager as metrics_manager_lib
......@@ -42,6 +47,54 @@ EVALUATION_METRICS_PREFIX = 'evaluation/'
EVALUATION_TIME_KEY = 'evaluation_time_in_seconds'
def create_managers(
    root_dir: str,
    experiment_name: str,
    csv_save_mode: file_release_manager_lib.CSVSaveMode = (
        file_release_manager_lib.CSVSaveMode.APPEND),
) -> Tuple[file_program_state_manager_lib.FileProgramStateManager,
           List[release_manager_lib.ReleaseManager]]:
  """Builds the program state manager and release managers for a simulation.

  The managers that are created, and how they are configured, are intended to
  be used with `tff.simulation.run_training_process` to run a simulation.

  Every output location is derived from `root_dir` and `experiment_name`:

  *   program state: `<root_dir>/program_state/<experiment_name>`
  *   CSV metrics:
      `<root_dir>/metrics/<experiment_name>/experiment.metrics.csv`
  *   TensorBoard summaries: `<root_dir>/logdir/<experiment_name>`

  Args:
    root_dir: A string representing the root output directory for the
      simulation.
    experiment_name: A unique identifier for the simulation, used to create
      appropriate subdirectories in `root_dir`.
    csv_save_mode: A `tff.program.CSVSaveMode` specifying the save mode for the
      `tff.program.CSVFileReleaseManager`.

  Returns:
    A tuple of a `tff.program.FileProgramStateManager` and a list of
    `tff.program.ReleaseManager`s. The list holds, in order, a
    `tff.program.LoggingReleaseManager`, a `tff.program.CSVFileReleaseManager`,
    and a `tff.program.TensorboardReleaseManager`.
  """
  # Resolve all output locations first; the managers are constructed below in
  # a fixed order: program state, logging, CSV, TensorBoard.
  state_dir = os.path.join(root_dir, 'program_state', experiment_name)
  metrics_file = os.path.join(root_dir, 'metrics', experiment_name,
                              'experiment.metrics.csv')
  tensorboard_dir = os.path.join(root_dir, 'logdir', experiment_name)

  state_manager = file_program_state_manager_lib.FileProgramStateManager(
      root_dir=state_dir)
  release_managers = [
      logging_release_manager_lib.LoggingReleaseManager(),
      file_release_manager_lib.CSVFileReleaseManager(
          file_path=metrics_file, save_mode=csv_save_mode),
      tensorboard_release_manager_lib.TensorboardReleaseManager(
          summary_dir=tensorboard_dir),
  ]
  return state_manager, release_managers
def _load_initial_checkpoint(
template_state: Any,
file_checkpoint_manager: FileCheckpointManager) -> Tuple[Any, int]:
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import collections
import os
from unittest import mock
from absl.testing import absltest
......@@ -20,11 +21,70 @@ from absl.testing import parameterized
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.program import file_program_state_manager as file_program_state_manager_lib
from tensorflow_federated.python.program import file_release_manager as file_release_manager_lib
from tensorflow_federated.python.program import logging_release_manager as logging_release_manager_lib
from tensorflow_federated.python.program import tensorboard_release_manager as tensorboard_release_manager_lib
from tensorflow_federated.python.simulation import checkpoint_manager
from tensorflow_federated.python.simulation import metrics_manager as metrics_manager_lib
from tensorflow_federated.python.simulation import training_loop
class CreateManagersTest(parameterized.TestCase):
  """Tests for `training_loop.create_managers`."""

  def test_create_managers_returns_managers(self):
    """Checks the types and order of the managers that are returned."""
    root_dir = self.create_tempdir()

    file_program_state_manager, release_managers = training_loop.create_managers(
        root_dir=root_dir, experiment_name='test')

    self.assertIsInstance(
        file_program_state_manager,
        file_program_state_manager_lib.FileProgramStateManager)
    self.assertLen(release_managers, 3)
    self.assertIsInstance(release_managers[0],
                          logging_release_manager_lib.LoggingReleaseManager)
    self.assertIsInstance(release_managers[1],
                          file_release_manager_lib.CSVFileReleaseManager)
    self.assertIsInstance(
        release_managers[2],
        tensorboard_release_manager_lib.TensorboardReleaseManager)

  # `mock.patch` decorators are applied bottom-up, so the mock parameters are
  # listed in the reverse order of the decorators.
  @mock.patch('tensorflow_federated.python.program.'
              'tensorboard_release_manager.TensorboardReleaseManager')
  @mock.patch('tensorflow_federated.python.program.'
              'file_release_manager.CSVFileReleaseManager')
  @mock.patch('tensorflow_federated.python.program.'
              'logging_release_manager.LoggingReleaseManager')
  @mock.patch('tensorflow_federated.python.program.'
              'file_program_state_manager.FileProgramStateManager')
  def test_create_managers_creates_managers(self,
                                            mock_file_program_state_manager,
                                            mock_logging_release_manager,
                                            mock_csv_file_release_manager,
                                            mock_tensorboard_release_manager):
    """Checks that each manager is constructed once with the expected args."""
    root_dir = self.create_tempdir()
    experiment_name = 'test'
    csv_save_mode = file_release_manager_lib.CSVSaveMode.APPEND

    training_loop.create_managers(
        root_dir=root_dir,
        experiment_name=experiment_name,
        csv_save_mode=csv_save_mode)

    program_state_dir = os.path.join(root_dir, 'program_state', experiment_name)
    # Use the "once" variant, as for the other managers, so that an accidental
    # second construction is caught as well.
    mock_file_program_state_manager.assert_called_once_with(
        root_dir=program_state_dir)
    mock_logging_release_manager.assert_called_once_with()
    csv_file_path = os.path.join(root_dir, 'metrics', experiment_name,
                                 'experiment.metrics.csv')
    mock_csv_file_release_manager.assert_called_once_with(
        file_path=csv_file_path, save_mode=csv_save_mode)
    summary_dir = os.path.join(root_dir, 'logdir', experiment_name)
    mock_tensorboard_release_manager.assert_called_once_with(
        summary_dir=summary_dir)
class LoadInitialCheckpointTest(parameterized.TestCase):
def test_returns_input_state_and_zero_if_checkpoint_is_none(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment