Skip to content

Commit fa54eb7

Browse files
Helw150claude
andcommitted
Fix grug/moe checkpoint resume format
Use restore_grug_state_from_checkpoint and save plain state instead of {"train_state": state} wrapper. Matches PR #3790. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c8bf2c7 commit fa54eb7

1 file changed

Lines changed: 10 additions & 20 deletions

File tree

experiments/grug/moe/train.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import levanter.tracker
2525
from levanter.callbacks.state_adapter import StateCallbackRunner
2626
from levanter.callbacks.watch import WatchConfig, compute_watch_stats
27-
from levanter.checkpoint import load_checkpoint
27+
from experiments.grug.checkpointing import restore_grug_state_from_checkpoint
2828
from levanter.data import AsyncDataset, DataLoader
2929
from levanter.data.mixture import MixtureDataset, rescale_mixture_schedule_for_batch_schedule
3030
from levanter.data.text import GrugLmExample, LmDataConfig
@@ -389,23 +389,13 @@ def _init_state(model_rng):
389389
checkpoint_path = trainer.load_checkpoint_path
390390
if checkpoint_path is None and checkpointer is not None:
391391
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
392-
if checkpoint_path is None:
393-
if trainer.load_checkpoint:
394-
raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.")
395-
elif trainer.load_checkpoint is not False:
396-
try:
397-
state = load_checkpoint(
398-
state,
399-
checkpoint_path,
400-
discover_latest=True,
401-
axis_mapping=None,
402-
mesh=mesh,
403-
allow_partial=trainer.allow_partial_checkpoint,
404-
)
405-
except FileNotFoundError:
406-
if trainer.load_checkpoint is True:
407-
raise
408-
logger.info(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.")
392+
state = restore_grug_state_from_checkpoint(
393+
state,
394+
checkpoint_path=checkpoint_path,
395+
load_checkpoint_setting=trainer.load_checkpoint,
396+
mesh=mesh,
397+
allow_partial=trainer.allow_partial_checkpoint,
398+
)
409399

410400
levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})
411401

@@ -513,11 +503,11 @@ def _init_state(model_rng):
513503
levanter.tracker.log(watch_stats, step=step)
514504

515505
if checkpointer is not None:
516-
checkpointer.on_step(tree={"train_state": state}, step=int(state.step))
506+
checkpointer.on_step(tree=state, step=int(state.step))
517507
finally:
518508
state_callbacks.run(state, loss=last_loss, step_duration=last_step_duration, force=True)
519509
if checkpointer is not None:
520-
checkpointer.on_step(tree={"train_state": state}, step=int(state.step), force=True)
510+
checkpointer.on_step(tree=state, step=int(state.step), force=True)
521511
checkpointer.wait_until_finished()
522512

523513
levanter.tracker.current_tracker().finish()

0 commit comments

Comments
 (0)