File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -372,18 +372,9 @@ def _init_state(model_rng):
372372 state = _init_state (model_key )
373373
374374 checkpointer = trainer .checkpointer .create (run_id )
375- if trainer .load_checkpoint_path is not None :
376- checkpoint_search_paths = [trainer .load_checkpoint_path ]
377- elif checkpointer is not None :
378- checkpoint_search_paths = [trainer .checkpointer .expanded_path (run_id )]
379- temp_path = trainer .checkpointer .expanded_temporary_path (run_id )
380- if temp_path is not None :
381- checkpoint_search_paths .append (temp_path )
382- else :
383- checkpoint_search_paths = []
384375 state = restore_grug_state_from_checkpoint (
385376 state ,
386- checkpoint_search_paths = checkpoint_search_paths ,
377+ checkpoint_search_paths = trainer . checkpoint_search_paths ( run_id ) ,
387378 load_checkpoint_setting = trainer .load_checkpoint ,
388379 mesh = mesh ,
389380 allow_partial = trainer .allow_partial_checkpoint ,
Original file line number Diff line number Diff line change @@ -372,18 +372,9 @@ def _init_state(model_rng):
372372 state = _init_state (model_key )
373373
374374 checkpointer = trainer .checkpointer .create (run_id )
375- if trainer .load_checkpoint_path is not None :
376- checkpoint_search_paths = [trainer .load_checkpoint_path ]
377- elif checkpointer is not None :
378- checkpoint_search_paths = [trainer .checkpointer .expanded_path (run_id )]
379- temp_path = trainer .checkpointer .expanded_temporary_path (run_id )
380- if temp_path is not None :
381- checkpoint_search_paths .append (temp_path )
382- else :
383- checkpoint_search_paths = []
384375 state = restore_grug_state_from_checkpoint (
385376 state ,
386- checkpoint_search_paths = checkpoint_search_paths ,
377+ checkpoint_search_paths = trainer . checkpoint_search_paths ( run_id ) ,
387378 load_checkpoint_setting = trainer .load_checkpoint ,
388379 mesh = mesh ,
389380 allow_partial = trainer .allow_partial_checkpoint ,
Original file line number Diff line number Diff line change @@ -410,18 +410,9 @@ def _init_state(model_rng):
410410 state = _init_state (model_key )
411411
412412 checkpointer = trainer .checkpointer .create (run_id )
413- if trainer .load_checkpoint_path is not None :
414- checkpoint_search_paths = [trainer .load_checkpoint_path ]
415- elif checkpointer is not None :
416- checkpoint_search_paths = [trainer .checkpointer .expanded_path (run_id )]
417- temp_path = trainer .checkpointer .expanded_temporary_path (run_id )
418- if temp_path is not None :
419- checkpoint_search_paths .append (temp_path )
420- else :
421- checkpoint_search_paths = []
422413 state = restore_grug_state_from_checkpoint (
423414 state ,
424- checkpoint_search_paths = checkpoint_search_paths ,
415+ checkpoint_search_paths = trainer . checkpoint_search_paths ( run_id ) ,
425416 load_checkpoint_setting = trainer .load_checkpoint ,
426417 mesh = mesh ,
427418 allow_partial = trainer .allow_partial_checkpoint ,
Original file line number Diff line number Diff line change @@ -854,6 +854,17 @@ def batch_axis_name(self) -> str | None:
854854 """if None (default), we'll load a checkpoint if it exists. If true, we must load a checkpoint"""
855855 load_checkpoint_path : Optional [str ] = None
856856 """can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path."""
857+
858+ def checkpoint_search_paths (self , run_id : str ) -> list [str ]:
859+ if self .load_checkpoint_path is not None :
860+ return [self .load_checkpoint_path ]
861+
862+ paths = [self .checkpointer .expanded_path (run_id )]
863+ temp_path = self .checkpointer .expanded_temporary_path (run_id )
864+ if temp_path is not None :
865+ paths .append (temp_path )
866+ return paths
867+
857868 initialize_from : Optional [str ] = None # Levanter trainer checkpoint to initialize from
858869 """Load and continue training from a checkpoint. If None, will initialize from model_init."""
859870 allow_partial_checkpoint : bool = False
Original file line number Diff line number Diff line change 3939 save_checkpoint ,
4040 unregister_debug_checkpointer_state_provider ,
4141)
42+ from levanter .trainer import TrainerConfig
4243from levanter .trainer_state import TrainerState
4344
4445
@@ -348,6 +349,21 @@ def test_checkpointer_config_no_temporary_base_path():
348349 assert config .expanded_temporary_path ("run1" ) is None
349350
350351
352+ def test_trainer_config_checkpoint_search_paths ():
353+ config = dataclasses .replace (
354+ TrainerConfig (),
355+ checkpointer = CheckpointerConfig (
356+ base_path = "/tmp/test-perm" ,
357+ temporary_base_path = "/tmp/test-temp" ,
358+ append_run_id_to_base_path = True ,
359+ ),
360+ )
361+ assert config .checkpoint_search_paths ("run1" ) == ["/tmp/test-perm/run1" , "/tmp/test-temp/run1" ]
362+
363+ pinned_config = dataclasses .replace (config , load_checkpoint_path = "/tmp/test-perm/run1/step-100" )
364+ assert pinned_config .checkpoint_search_paths ("run1" ) == ["/tmp/test-perm/run1/step-100" ]
365+
366+
351367def test_checkpointer_config_propagates_debug_settings ():
352368 config = CheckpointerConfig (
353369 base_path = "/tmp/checkpoints" ,
You can’t perform that action at this time.
0 commit comments