Commit 8b8ff87a authored by Michael Reneer, committed by tensorflow-copybara
Browse files

Update the `ProgramStateManager` to support user-defined classes.

This change promotes support for user-defined classes from the concrete `FileProgramStateManager` implementation to the abstract `ProgramStateManager` interface, which makes it easier to write program loops whose state contains user-defined classes.

PiperOrigin-RevId: 413696492
parent 032d4713
......@@ -34,12 +34,6 @@ from tensorflow_federated.python.program import program_state_manager
from tensorflow_federated.python.program import value_reference
# TODO(b/199737690): Update `FileProgramStateManager` to not require a structure
# to load program state.
class FileProgramStateManagerStructureError(Exception):
pass
class FileProgramStateManager(program_state_manager.ProgramStateManager):
"""A `tff.program.ProgramStateManager` that is backed by a file system.
......@@ -99,20 +93,6 @@ class FileProgramStateManager(program_state_manager.ProgramStateManager):
self._prefix = prefix
self._keep_total = keep_total
self._keep_first = keep_first
self._structure = None
# TODO(b/199737690): Update `FileProgramStateManager` to not require a
# structure to load program state.
def set_structure(self, structure: Any):
"""Configures a structure to use when loading program state.
The structure must be set before calling `load`.
Args:
structure: A nested structure which `tf.convert_to_tensor` supports to use
as a template when calling `load`.
"""
self._structure = structure
def versions(self) -> Optional[List[int]]:
"""Returns a list of saved versions or `None`.
......@@ -169,16 +149,20 @@ class FileProgramStateManager(program_state_manager.ProgramStateManager):
basename = f'{self._prefix}{version}'
return os.path.join(self._root_dir, basename)
def load(self, version: int) -> Any:
def load(self, version: int, structure: Any) -> Any:
"""Returns the program state for the given `version`.
Args:
version: An integer representing the version of a saved program state.
structure: The nested structure of the saved program state for the given
`version` used to support serialization and deserialization of
user-defined classes in the structure.
Raises:
ProgramStateManagerStateNotFoundError: If there is no program state for
the given `version`.
FileProgramStateManagerStructureError: If `structure` has not been set.
ProgramStateManagerStructureError: If `structure` does not match the value
loaded for the given `version`.
"""
py_typecheck.check_type(version, int)
path = self._get_path_for_version(version)
......@@ -186,15 +170,15 @@ class FileProgramStateManager(program_state_manager.ProgramStateManager):
raise program_state_manager.ProgramStateManagerStateNotFoundError(
f'No program state found for version: {version}')
module = tf.saved_model.load(path)
flattened_value = module()
flattened_state = module()
try:
program_state = tree.unflatten_as(self._structure, flattened_value)
program_state = tree.unflatten_as(structure, flattened_state)
except ValueError as e:
raise FileProgramStateManagerStructureError(
f'The structure of type {type(self._structure)}:\n'
f'{self._structure}\n'
f'does not match the value of type {type(flattened_value)}:\n'
f'{flattened_value}\n') from e
raise program_state_manager.ProgramStateManagerStructureError(
f'The structure of type {type(structure)}:\n'
f'{structure}\n'
f'does not match the value of type {type(flattened_state)}:\n'
f'{flattened_state}\n') from e
logging.info('Program state loaded: %s', path)
return program_state
......@@ -237,8 +221,8 @@ class FileProgramStateManager(program_state_manager.ProgramStateManager):
if tf.io.gfile.exists(path):
raise program_state_manager.ProgramStateManagerStateAlreadyExistsError(
f'Program state already exists for version: {version}')
materialized_value = value_reference.materialize_value(program_state)
flattened_value = tree.flatten(materialized_value)
module = file_utils.ValueModule(flattened_value)
materialized_state = value_reference.materialize_value(program_state)
flattened_state = tree.flatten(materialized_state)
module = file_utils.ValueModule(flattened_state)
file_utils.write_saved_model(module, path)
self._remove_old_program_state()
......@@ -318,10 +318,10 @@ class FileProgramStateManagerLoadTest(parameterized.TestCase, tf.test.TestCase):
temp_dir = self.create_tempdir()
program_state_mngr = file_program_state_manager.FileProgramStateManager(
root_dir=temp_dir, prefix='a_')
program_state_mngr.set_structure(program_state)
program_state_mngr.save(program_state, 1)
structure = program_state
actual_program_state = program_state_mngr.load(1)
actual_program_state = program_state_mngr.load(1, structure)
self.assertEqual(type(actual_program_state), type(expected_program_state))
self.assertAllEqual(actual_program_state, expected_program_state)
......@@ -335,11 +335,11 @@ class FileProgramStateManagerLoadTest(parameterized.TestCase, tf.test.TestCase):
temp_dir = self.create_tempdir()
program_state_mngr = file_program_state_manager.FileProgramStateManager(
root_dir=temp_dir, prefix='a_')
program_state_mngr.set_structure('state')
for i in range(3):
program_state_mngr.save(f'state_{i}', i)
structure = 'state'
actual_program_state = program_state_mngr.load(version)
actual_program_state = program_state_mngr.load(version, structure)
expected_program_state = f'state_{version}'
self.assertEqual(actual_program_state, expected_program_state)
......@@ -348,33 +348,32 @@ class FileProgramStateManagerLoadTest(parameterized.TestCase, tf.test.TestCase):
temp_dir = self.create_tempdir()
program_state_mngr = file_program_state_manager.FileProgramStateManager(
root_dir=temp_dir, prefix='a_')
program_state_mngr.set_structure('state')
with self.assertRaises(
program_state_manager.ProgramStateManagerStateNotFoundError):
_ = program_state_mngr.load(0)
_ = program_state_mngr.load(0, None)
def test_raises_version_not_found_error_with_unknown_version(self):
temp_dir = self.create_tempdir()
program_state_mngr = file_program_state_manager.FileProgramStateManager(
root_dir=temp_dir, prefix='a_')
program_state_mngr.set_structure('state')
program_state_mngr.save('state_1', 1)
structure = 'state'
with self.assertRaises(
program_state_manager.ProgramStateManagerStateNotFoundError):
program_state_mngr.load(10)
program_state_mngr.load(10, structure)
def test_raises_structure_error(self):
temp_dir = self.create_tempdir()
program_state_mngr = file_program_state_manager.FileProgramStateManager(
root_dir=temp_dir, prefix='a_')
program_state_mngr.set_structure([])
program_state_mngr.save('state_1', 1)
structure = []
with self.assertRaises(
file_program_state_manager.FileProgramStateManagerStructureError):
program_state_mngr.load(1)
program_state_manager.ProgramStateManagerStructureError):
program_state_mngr.load(1, structure)
@parameterized.named_parameters(
('none', None),
......@@ -387,7 +386,7 @@ class FileProgramStateManagerLoadTest(parameterized.TestCase, tf.test.TestCase):
root_dir=temp_dir, prefix='a_')
with self.assertRaises(TypeError):
program_state_mngr.load(version)
program_state_mngr.load(version, None)
class FileProgramStateManagerRemoveTest(parameterized.TestCase):
......
......@@ -862,7 +862,9 @@ class SavedModelFileReleaseManagerReleaseTest(parameterized.TestCase,
('numpy_nested',
{'a': [np.bool(True), np.int32(1)], 'b': [np.str_('a')]},
[tf.constant(True), tf.constant(1), tf.constant('a')]),
('server_array_reference', test_utils.TestServerArrayReference(1), [1]),
('server_array_reference',
test_utils.TestServerArrayReference(1),
[tf.constant(1)]),
('server_array_reference_nested',
{'a': [test_utils.TestServerArrayReference(True),
test_utils.TestServerArrayReference(1)],
......
......@@ -25,6 +25,10 @@ class ProgramStateManagerStateNotFoundError(Exception):
pass
class ProgramStateManagerStructureError(Exception):
pass
class ProgramStateManager(metaclass=abc.ABCMeta):
"""An interface for saving and loading program state in a federated program.
......@@ -43,11 +47,14 @@ class ProgramStateManager(metaclass=abc.ABCMeta):
raise NotImplementedError
@abc.abstractmethod
def load(self, version: int) -> Any:
def load(self, version: int, structure: Any) -> Any:
"""Returns the saved program state for the given `version`.
Args:
version: An integer representing the version of a saved program state.
structure: The nested structure of the saved program state for the given
`version` used to support serialization and deserialization of
user-defined classes in the structure.
Raises:
ProgramStateManagerStateNotFoundError: If there is no program state for
......@@ -55,9 +62,14 @@ class ProgramStateManager(metaclass=abc.ABCMeta):
"""
raise NotImplementedError
def load_latest(self) -> Tuple[Any, int]:
def load_latest(self, structure: Any) -> Tuple[Any, int]:
"""Returns the latest saved program state and version or (`None`, 0).
Args:
structure: The nested structure of the latest saved program state, used to
support serialization and deserialization of user-defined classes in the
structure.
Returns:
A tuple of the latest saved (program state, version) or (`None`, 0) if
there is no latest saved program state.
......@@ -67,7 +79,7 @@ class ProgramStateManager(metaclass=abc.ABCMeta):
return None, 0
latest_version = max(versions)
try:
return self.load(latest_version), latest_version
return self.load(latest_version, structure), latest_version
except ProgramStateManagerStateNotFoundError:
return None, 0
......
......@@ -33,7 +33,8 @@ class _TestProgramStateManager(program_state_manager.ProgramStateManager):
def save(self, program_state: Any, version: int):
del program_state, version # Unused.
def load(self, version: int) -> Any:
def load(self, version: int, structure: Any) -> Any:
del structure # Unused.
if self._values is None or version not in self._values:
raise program_state_manager.ProgramStateManagerStateNotFoundError()
return self._values[version]
......@@ -43,28 +44,31 @@ class ProgramStateManagerTest(absltest.TestCase):
def test_load_latest_with_saved_program_state(self):
values = {x: f'test{x}' for x in range(5)}
structure = values[0]
program_state_mngr = _TestProgramStateManager(values)
(program_state, version) = program_state_mngr.load_latest()
(program_state, version) = program_state_mngr.load_latest(structure)
self.assertEqual(program_state, 'test4')
self.assertEqual(version, 4)
def test_load_latest_with_no_saved_program_state(self):
structure = None
program_state_mngr = _TestProgramStateManager()
(program_state, version) = program_state_mngr.load_latest()
(program_state, version) = program_state_mngr.load_latest(structure)
self.assertIsNone(program_state)
self.assertEqual(version, 0)
def test_load_latest_with_load_failure(self):
values = {x: f'test{x}' for x in range(5)}
structure = values[0]
program_state_mngr = _TestProgramStateManager(values)
program_state_mngr.load = mock.MagicMock(
side_effect=program_state_manager.ProgramStateManagerStateNotFoundError)
(program_state, version) = program_state_mngr.load_latest()
(program_state, version) = program_state_mngr.load_latest(structure)
self.assertIsNone(program_state)
self.assertEqual(version, 0)
......
......@@ -168,7 +168,6 @@ py_library(
deps = [
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/program:file_program_state_manager",
"//tensorflow_federated/python/program:program_state_manager",
"//tensorflow_federated/python/program:release_manager",
"//tensorflow_federated/python/simulation:checkpoint_manager",
......
......@@ -16,14 +16,12 @@
import collections
import pprint
import time
import typing
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple
from absl import logging
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.program import file_program_state_manager as file_program_state_manager_lib
from tensorflow_federated.python.program import program_state_manager as program_state_manager_lib
from tensorflow_federated.python.program import release_manager as release_manager_lib
from tensorflow_federated.python.simulation import checkpoint_manager
......@@ -536,20 +534,9 @@ def run_training_process(
The `state` of the training process after training.
"""
logging.info('Running training process')
# TODO(b/199737690): Update `FileProgramStateManager` to not require a
# structure to load program state; once this is fixed, we can move the
# initialize invocation down so it's only called if required.
initial_state = training_process.initialize()
if isinstance(program_state_manager,
file_program_state_manager_lib.FileProgramStateManager):
file_program_state_manager = typing.cast(
file_program_state_manager_lib.FileProgramStateManager,
program_state_manager)
file_program_state_manager.set_structure(initial_state)
if program_state_manager is not None:
program_state, version = program_state_manager.load_latest()
structure = training_process.initialize()
program_state, version = program_state_manager.load_latest(structure)
else:
program_state = None
if program_state is not None:
......@@ -558,7 +545,7 @@ def run_training_process(
start_round = version
else:
logging.info('Initializing training process')
state = initial_state
state = training_process.initialize()
start_round = 1
if evaluation_fn is not None and evaluation_selection_fn is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment