Skip to content
Snippets Groups Projects
Commit af7a2250 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by tensorflow-copybara
Browse files

Factor checkpoint serial number extraction out to module-level.

Use the serial number pattern to check target directory naming before saving a
checkpoint.

PiperOrigin-RevId: 272911215
parent 72fba747
No related branches found
No related tags found
No related merge requests found
......@@ -21,36 +21,43 @@ import re
import tensorflow as tf
def latest_checkpoint(root_output_dir, checkpoint_prefix='ckpt_'):
def get_serial_number(export_dir, prefix='ckpt_'):
  r"""Get the integer component of a checkpoint directory name.

  The basename of `export_dir` must consist of exactly `prefix` followed by
  one or more decimal digits, e.g. `ckpt_0012` for the default prefix.

  Args:
    export_dir: A checkpoint directory.
    prefix: Common prefix shared by all checkpoint directories. It is matched
      literally; regular expression metacharacters in it have no special
      meaning.

  Returns:
    The number extracted from the checkpoint directory, or -1 if the directory
    is not formatted correctly.
  """
  # Escape the prefix so characters such as '.', '+' or '(' are matched
  # verbatim instead of being interpreted as regular expression syntax.
  pattern = r'^{}(?P<num>\d+)$'.format(re.escape(prefix))
  matcher = re.match(pattern, os.path.basename(export_dir))
  return int(matcher.group('num')) if matcher else -1
def latest_checkpoint(root_output_dir, prefix='ckpt_'):
r"""Get the latest checkpoint name.
Searches `root_output_dir` for directories matching the regular expression
`checkpoint_prefix_\d+$` and returns the directory with the largest integer
suffix.
`prefix_\d+$` and returns the directory with the largest integer suffix.
Args:
root_output_dir: The directory where all checkpoints are stored.
checkpoint_prefix: The common prefix shared by all checkpoint directories.
prefix: The common prefix shared by all checkpoint directories.
Returns:
Dirname of the latest checkpoint.
"""
checkpoints = tf.io.gfile.glob(
os.path.join(root_output_dir, '{}*'.format(checkpoint_prefix)))
os.path.join(root_output_dir, '{}*'.format(prefix)))
if not checkpoints:
return None
return max(checkpoints, key=lambda ckpt: get_serial_number(ckpt, prefix))
checkpoint_regex = re.compile(
r'^(?P<prefix>{})(?P<num>\d+)$'.format(checkpoint_prefix))
def by_checkpoint_number(ckpt):
matcher = checkpoint_regex.match(os.path.basename(ckpt))
return int(matcher.group('num')) if matcher else -1
return max(checkpoints, key=by_checkpoint_number)
def save(obj, export_dir):
def save(obj, export_dir, prefix=None):
r"""Save a nested structure to `export_dir`.
NOTE: to be compatible with `latest_checkpoint`, the basename of `export_dir`
......@@ -60,7 +67,17 @@ def save(obj, export_dir):
Args:
obj: A nested structure which `tf.convert_to_tensor` supports.
export_dir: A directory in which to write the state.
prefix: The common prefix shared by all checkpoint directories. If provided,
we will fail if the export directory doesn't match this prefix. If not
provided, no check will be performed.
Raises:
ValueError: If `prefix` is provided and `export_dir` doesn't use the prefix.
"""
if prefix is not None and get_serial_number(export_dir, prefix) < 0:
raise ValueError('Checkpoint dir "{}" is not named like "{}XXXX!'.format(
export_dir, prefix))
model = tf.Module()
model.obj = tf.nest.flatten(obj)
model.build_obj_fn = tf.function(lambda: model.obj, input_signature=())
......
......@@ -77,7 +77,7 @@ class SavedStateTest(tf.test.TestCase):
for round_num in range(5):
export_dir = os.path.join(self.get_temp_dir(),
'{}{:03d}'.format(prefix, round_num))
checkpoint_utils.save(state, export_dir)
checkpoint_utils.save(state, export_dir, prefix)
latest_checkpoint_path = checkpoint_utils.latest_checkpoint(
self.get_temp_dir(), prefix)
self.assertEndsWith(
......@@ -95,7 +95,7 @@ class SavedStateTest(tf.test.TestCase):
self.get_temp_dir(), prefix)
self.assertEndsWith(
latest_checkpoint_path,
'{:03d}'.format(round_num - 1),
'{}{:03d}'.format(prefix, round_num - 1),
msg=latest_checkpoint_path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment