File tree Expand file tree Collapse file tree 3 files changed +11
-4
lines changed
orbax/checkpoint/_src/path Expand file tree Collapse file tree 3 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
88## [Unreleased]
99
10+ ### Fixed
11+
12+ - Fix `step_from_checkpoint_name` to allow the passed-in checkpoint name to
13+ include an arbitrary `step_prefix` containing any character(s), such as underscores.
14+
1015### Changed
1116
1217- Validate checkpoints before writing merged OCDBT database using in-memory
Original file line number Diff line number Diff line change 3838from orbax .checkpoint ._src .path import gcs_utils
3939from orbax .checkpoint ._src .path import temporary_paths
4040
41+ # Allowed checkpoint step naming using any non-empty `step_prefix`.
42+ ALLOWED_STEP_NAME_PATTERN = r'(.+)_(\d+)'
4143
4244TMP_DIR_SUFFIX = temporary_paths .TMP_DIR_SUFFIX
4345# prefix_1000.orbax-checkpoint-tmp-1010101
@@ -739,10 +741,8 @@ def step_from_checkpoint_name(name: str) -> int:
739741 """Returns the step from a checkpoint name. Also works for tmp checkpoints."""
740742 if name .isdigit ():
741743 return int (name )
742- elif name .split ('_' )[- 1 ].isdigit ():
743- split = name .split ('_' )
744- if len (split ) == 2 and split [0 ]:
745- return int (split [- 1 ])
744+ elif m := re .fullmatch (ALLOWED_STEP_NAME_PATTERN , name ):
745+ return int (m .group (2 ))
746746 elif tmp_match := re .match (TMP_DIR_STEP_PATTERN , name ):
747747 return int (tmp_match .group (1 ))
748748 raise ValueError (f'Unrecognized name format: { name } .' )
Original file line number Diff line number Diff line change @@ -471,11 +471,13 @@ def test_is_path_temporary(self, step_prefix):
471471 ('checkpoint_0000' , 0 ),
472472 ('checkpoint_003400' , 3400 ),
473473 ('foobar_1000' , 1000 ),
474+ ('foo_bar_1000' , 1000 ),
474475 ('0.orbax-checkpoint-tmp-1010101' , 0 ),
475476 ('0000.orbax-checkpoint-tmp-12323232' , 0 ),
476477 ('foobar_1.orbax-checkpoint-tmp-12424424' , 1 ),
477478 ('foobar_000505.orbax-checkpoint-tmp-13124' , 505 ),
478479 ('checkpoint_16.orbax-checkpoint-tmp-123214324' , 16 ),
480+ ('foo_bar_1000.orbax-checkpoint-tmp-123123' , 1000 ),
479481 )
480482 def test_step_from_checkpoint_name (self , name , step ):
481483 self .assertEqual (step_lib .step_from_checkpoint_name (name ), step )
You can’t perform that action at this time.
0 commit comments