Skip to content

Commit 72ffd63

Browse files
committed
Centralize checkpoint search paths
1 parent a93bf7a commit 72ffd63

5 files changed

Lines changed: 30 additions & 30 deletions

File tree

experiments/grug/base/train.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -372,18 +372,9 @@ def _init_state(model_rng):
372372
state = _init_state(model_key)
373373

374374
checkpointer = trainer.checkpointer.create(run_id)
375-
if trainer.load_checkpoint_path is not None:
376-
checkpoint_search_paths = [trainer.load_checkpoint_path]
377-
elif checkpointer is not None:
378-
checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)]
379-
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
380-
if temp_path is not None:
381-
checkpoint_search_paths.append(temp_path)
382-
else:
383-
checkpoint_search_paths = []
384375
state = restore_grug_state_from_checkpoint(
385376
state,
386-
checkpoint_search_paths=checkpoint_search_paths,
377+
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
387378
load_checkpoint_setting=trainer.load_checkpoint,
388379
mesh=mesh,
389380
allow_partial=trainer.allow_partial_checkpoint,

experiments/grug/modular_opt/train.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -372,18 +372,9 @@ def _init_state(model_rng):
372372
state = _init_state(model_key)
373373

374374
checkpointer = trainer.checkpointer.create(run_id)
375-
if trainer.load_checkpoint_path is not None:
376-
checkpoint_search_paths = [trainer.load_checkpoint_path]
377-
elif checkpointer is not None:
378-
checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)]
379-
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
380-
if temp_path is not None:
381-
checkpoint_search_paths.append(temp_path)
382-
else:
383-
checkpoint_search_paths = []
384375
state = restore_grug_state_from_checkpoint(
385376
state,
386-
checkpoint_search_paths=checkpoint_search_paths,
377+
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
387378
load_checkpoint_setting=trainer.load_checkpoint,
388379
mesh=mesh,
389380
allow_partial=trainer.allow_partial_checkpoint,

experiments/grug/moe/train.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -410,18 +410,9 @@ def _init_state(model_rng):
410410
state = _init_state(model_key)
411411

412412
checkpointer = trainer.checkpointer.create(run_id)
413-
if trainer.load_checkpoint_path is not None:
414-
checkpoint_search_paths = [trainer.load_checkpoint_path]
415-
elif checkpointer is not None:
416-
checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)]
417-
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
418-
if temp_path is not None:
419-
checkpoint_search_paths.append(temp_path)
420-
else:
421-
checkpoint_search_paths = []
422413
state = restore_grug_state_from_checkpoint(
423414
state,
424-
checkpoint_search_paths=checkpoint_search_paths,
415+
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
425416
load_checkpoint_setting=trainer.load_checkpoint,
426417
mesh=mesh,
427418
allow_partial=trainer.allow_partial_checkpoint,

lib/levanter/src/levanter/trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,17 @@ def batch_axis_name(self) -> str | None:
854854
"""if None (default), we'll load a checkpoint if it exists. If true, we must load a checkpoint"""
855855
load_checkpoint_path: Optional[str] = None
856856
"""can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path."""
857+
858+
def checkpoint_search_paths(self, run_id: str) -> list[str]:
    """Return the ordered list of paths to search for a checkpoint to restore.

    An explicitly configured ``load_checkpoint_path`` pins the search to that
    single path (it may be a parent directory, in which case the latest
    checkpoint under it is found, or a specific checkpoint). Otherwise the
    checkpointer's permanent path for ``run_id`` is searched first, followed
    by its temporary path when one is configured.

    Args:
        run_id: identifier of the run whose checkpoint locations are expanded.

    Returns:
        Candidate checkpoint paths, in priority order (possibly length 1).
    """
    # An explicit load path overrides any locations derived from the
    # checkpointer configuration.
    if self.load_checkpoint_path is not None:
        return [self.load_checkpoint_path]

    # NOTE(review): the call sites this centralizes guarded on the *created*
    # checkpointer (`trainer.checkpointer.create(run_id)`) being non-None and
    # returned [] otherwise; this method always derives paths from
    # self.checkpointer — confirm self.checkpointer is always set here.
    paths = [self.checkpointer.expanded_path(run_id)]
    temp_path = self.checkpointer.expanded_temporary_path(run_id)
    if temp_path is not None:
        paths.append(temp_path)
    return paths
867+
857868
initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from
858869
"""Load and continue training from a checkpoint. If None, will initialize from model_init."""
859870
allow_partial_checkpoint: bool = False

lib/levanter/tests/test_checkpoint.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
save_checkpoint,
4040
unregister_debug_checkpointer_state_provider,
4141
)
42+
from levanter.trainer import TrainerConfig
4243
from levanter.trainer_state import TrainerState
4344

4445

@@ -348,6 +349,21 @@ def test_checkpointer_config_no_temporary_base_path():
348349
assert config.expanded_temporary_path("run1") is None
349350

350351

352+
def test_trainer_config_checkpoint_search_paths():
    """checkpoint_search_paths returns the permanent path then the temporary
    path for a run, unless load_checkpoint_path pins the search to one path."""
    config = dataclasses.replace(
        TrainerConfig(),
        checkpointer=CheckpointerConfig(
            base_path="/tmp/test-perm",
            temporary_base_path="/tmp/test-temp",
            append_run_id_to_base_path=True,
        ),
    )
    # With append_run_id_to_base_path, both bases are expanded with the run id.
    assert config.checkpoint_search_paths("run1") == [
        "/tmp/test-perm/run1",
        "/tmp/test-temp/run1",
    ]

    # An explicit load_checkpoint_path short-circuits the derived locations.
    pinned_config = dataclasses.replace(
        config, load_checkpoint_path="/tmp/test-perm/run1/step-100"
    )
    assert pinned_config.checkpoint_search_paths("run1") == [
        "/tmp/test-perm/run1/step-100"
    ]
365+
366+
351367
def test_checkpointer_config_propagates_debug_settings():
352368
config = CheckpointerConfig(
353369
base_path="/tmp/checkpoints",

0 commit comments

Comments
 (0)