Skip to content

Commit b2c2b12

Browse files
author
Orbax Authors
committed
Fix step_prefix option to allow for arbitrary naming
PiperOrigin-RevId: 832504226
1 parent 9643a9a commit b2c2b12

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

checkpoint/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Fixed
11+
12+
- Fix `step_from_checkpoint_name` to allow the passed-in checkpoint name to
13+
include an arbitrary `step_prefix` containing any characters, such as underscores.
14+
1015
### Changed
1116

1217
- Validate checkpoints before writing merged OCDBT database using in-memory

checkpoint/orbax/checkpoint/_src/path/step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from orbax.checkpoint._src.path import gcs_utils
3939
from orbax.checkpoint._src.path import temporary_paths
4040

41+
# Allowed checkpoint step naming: any non-empty `step_prefix`, then `_`, then the step digits.
42+
ALLOWED_STEP_NAME_PATTERN = r'(.+)_(\d+)'
4143

4244
TMP_DIR_SUFFIX = temporary_paths.TMP_DIR_SUFFIX
4345
# prefix_1000.orbax-checkpoint-tmp-1010101
@@ -739,10 +741,8 @@ def step_from_checkpoint_name(name: str) -> int:
739741
"""Returns the step from a checkpoint name. Also works for tmp checkpoints."""
740742
if name.isdigit():
741743
return int(name)
742-
elif name.split('_')[-1].isdigit():
743-
split = name.split('_')
744-
if len(split) == 2 and split[0]:
745-
return int(split[-1])
744+
elif m := re.fullmatch(ALLOWED_STEP_NAME_PATTERN, name):
745+
return int(m.group(2))
746746
elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
747747
return int(tmp_match.group(1))
748748
raise ValueError(f'Unrecognized name format: {name}.')

checkpoint/orbax/checkpoint/_src/path/step_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,11 +471,13 @@ def test_is_path_temporary(self, step_prefix):
471471
('checkpoint_0000', 0),
472472
('checkpoint_003400', 3400),
473473
('foobar_1000', 1000),
474+
('foo_bar_1000', 1000),
474475
('0.orbax-checkpoint-tmp-1010101', 0),
475476
('0000.orbax-checkpoint-tmp-12323232', 0),
476477
('foobar_1.orbax-checkpoint-tmp-12424424', 1),
477478
('foobar_000505.orbax-checkpoint-tmp-13124', 505),
478479
('checkpoint_16.orbax-checkpoint-tmp-123214324', 16),
480+
('foo_bar_1000.orbax-checkpoint-tmp-123123', 1000),
479481
)
480482
def test_step_from_checkpoint_name(self, name, step):
481483
self.assertEqual(step_lib.step_from_checkpoint_name(name), step)

0 commit comments

Comments
 (0)