Skip to content

Commit d83f637

Browse files
Mohammad NorouziFlax Authors
Mohammad Norouzi
authored and
Flax Authors
committed
Add preceding zeros to checkpoint files so they appear sorted under a directory.
PiperOrigin-RevId: 411882335
1 parent df26680 commit d83f637

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

flax/training/checkpoints.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@
4949
def _checkpoint_path(ckpt_dir: str,
5050
step: Union[int, str],
5151
prefix: str = 'checkpoint_') -> str:
52-
return os.path.join(ckpt_dir, f'{prefix}{step}')
52+
if isinstance(step, int):
53+
ckpt_file = f'{prefix}{step:0>9d}'
54+
else:
55+
ckpt_file = f'{prefix}{step}'
56+
return os.path.join(ckpt_dir, ckpt_file)
5357

5458

5559
def _checkpoint_path_step(path: str) -> Optional[float]:

tests/checkpoints_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_save_restore_checkpoints(self):
130130
gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
131131
checkpoints.save_checkpoint(
132132
tmp_dir, test_object1, 0, prefix='test_', keep=1)
133-
self.assertIn('test_0', os.listdir(tmp_dir))
133+
self.assertIn('test_000000000', os.listdir(tmp_dir))
134134
new_object = checkpoints.restore_checkpoint(
135135
tmp_dir, test_object0, prefix='test_')
136136
jtu.check_eq(new_object, test_object1)
@@ -153,7 +153,7 @@ def test_save_restore_checkpoints(self):
153153
jtu.check_eq(new_object, test_object2)
154154
# Restore a specific path.
155155
new_object = checkpoints.restore_checkpoint(
156-
os.path.join(tmp_dir, 'test_3'), test_object0)
156+
os.path.join(tmp_dir, 'test_000000003'), test_object0)
157157
jtu.check_eq(new_object, test_object2)
158158
# If a specific path is specified, but it does not exist, the same behavior
159159
# as when a directory is empty should apply: the target is returned

0 commit comments

Comments
 (0)