step_prefix cannot contain _ -- Checkpoint manager does not recognize multiple _. #1499

Open
@scott-yj-yang

Bug Description:

When I create a CheckpointManager with options like the following,

options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
with ocp.CheckpointManager(
    ".../model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
    options=options,
) as mngr:
    mngr.restore(0)

and my checkpoint directory looks like this,

[Image: checkpoint directory listing]

it raises the following ValueError when instantiating the manager object.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 5
      1 import orbax.checkpoint as ocp
      4 options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
----> 5 with ocp.CheckpointManager(
      6     "/root/vast/scott-yang/track-mjx/model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
      7     options=options,
      8 ) as mngr:
      9     mngr.restore(0)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:685, in CheckpointManager.__init__(self, directory, checkpointers, options, metadata, item_names, item_handlers, logger, handler_registry)
    675   self._cleanup_tmp_directories()
    677 self._step_name_format = (
    678     self._options.step_name_format
    679     or step_lib.standard_name_format(
   (...)
    682     )
    683 )
--> 685 self._checkpoints = self._load_checkpoint_infos()
    687 self._metadata_checkpointer = Checkpointer(
    688     JsonCheckpointHandler(
    689         multiprocessing_options=self._multiprocessing_options
   (...)
    694     temporary_path_class=self._options.temporary_path_class,
    695 )
    696 if self._options.read_only and not self._metadata_path().exists():

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:1431, in CheckpointManager._load_checkpoint_infos(self)
   1423 """Loads a list of CheckpointInfo for existing checkpoints.
   1424 
   1425 If none are present, returns empty list.
   (...)
   1428   a list of CheckpointInfo, sorted by increasing step.
   1429 """
   1430 start = time.time()
-> 1431 steps = utils.checkpoint_steps(
   1432     self.directory, self._options.single_host_load_and_broadcast
   1433 )
   1434 steps.sort()  # Prefer in-place sort.
   1436 if not steps:

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:698, in checkpoint_steps(checkpoint_dir, single_host_load_and_broadcast)
    696   padded_step_list = multihost.broadcast_one_to_all(padded_step_list)
    697   return [step for step in padded_step_list if step >= 0]
--> 698 return _checkpoint_steps(checkpoint_dir)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:682, in checkpoint_steps.<locals>._checkpoint_steps(path)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
--> 682   return [
    683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:683, in <listcomp>(.0)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
    682   return [
--> 683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:645, in step_from_checkpoint_name(name)
    643 elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
    644   return int(tmp_match.group(1))
--> 645 raise ValueError(f'Unrecognized name format: {name}.')

ValueError: Unrecognized name format: ppo_networks_1024000.

Specifically, when I check step_from_checkpoint_name in step.py,

def step_from_checkpoint_name(name: str) -> int:
  """Returns the step from a checkpoint name. Also works for tmp checkpoints."""
  if name.isdigit():
    return int(name)
  elif name.split('_')[-1].isdigit():
    split = name.split('_')
    if len(split) == 2 and split[0]:
      return int(split[-1])
  elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
    return int(tmp_match.group(1))
  raise ValueError(f'Unrecognized name format: {name}.')

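it assumes that splitting the name on _ yields exactly two parts. A name such as ppo_networks_1024000 splits into three parts, so the len(split) == 2 check fails and the function falls through to the ValueError. Either the parser should handle prefixes containing _, or the prefix should be validated when the options are created.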
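
To make the two options concrete, here is a minimal sketch, not the Orbax implementation. Both function names (step_from_name_tolerant and validate_step_prefix) are hypothetical and do not exist in the library. The sketch assumes the step is always the last _-separated field, and it leaves out the TMP_DIR_STEP_PATTERN branch because that pattern is not shown here.

```python
from typing import Optional


def step_from_name_tolerant(name: str, step_prefix: Optional[str] = None) -> int:
    """Hypothetical parser: reads the step from the last '_'-separated field.

    Not part of Orbax. The tmp-checkpoint branch of the original function is
    intentionally omitted from this sketch.
    """
    if name.isdigit():
        return int(name)
    # rpartition splits only at the LAST underscore, so a prefix such as
    # 'ppo_networks' stays in one piece.
    prefix, sep, step = name.rpartition('_')
    if sep and prefix and step.isdigit():
        # If the caller knows the configured prefix, require an exact match.
        if step_prefix is None or prefix == step_prefix:
            return int(step)
    raise ValueError(f'Unrecognized name format: {name}.')


def validate_step_prefix(step_prefix: str) -> None:
    """Hypothetical early check for the alternative fix: reject '_' up front.

    Failing here, when options are created, would be easier to diagnose than
    an 'Unrecognized name format' error raised while listing checkpoints.
    """
    if '_' in step_prefix:
        raise ValueError(
            f"step_prefix must not contain '_', got: {step_prefix!r}"
        )


if __name__ == '__main__':
    print(step_from_name_tolerant('ppo_networks_1024000'))
    print(step_from_name_tolerant('ppo_networks_1024000', 'ppo_networks'))
```

The first function corresponds to relaxing the parser, the second to the input validation suggested above. Only one of the two would be needed, depending on whether prefixes containing _ are meant to be supported.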
