|
24 | 24 | import levanter.tracker |
25 | 25 | from levanter.callbacks.state_adapter import StateCallbackRunner |
26 | 26 | from levanter.callbacks.watch import WatchConfig, compute_watch_stats |
27 | | -from levanter.checkpoint import load_checkpoint |
| 27 | +from experiments.grug.checkpointing import restore_grug_state_from_checkpoint |
28 | 28 | from levanter.data import AsyncDataset, DataLoader |
29 | 29 | from levanter.data.mixture import MixtureDataset, rescale_mixture_schedule_for_batch_schedule |
30 | 30 | from levanter.data.text import GrugLmExample, LmDataConfig |
@@ -389,23 +389,13 @@ def _init_state(model_rng): |
389 | 389 | checkpoint_path = trainer.load_checkpoint_path |
390 | 390 | if checkpoint_path is None and checkpointer is not None: |
391 | 391 | checkpoint_path = trainer.checkpointer.expanded_path(run_id) |
392 | | - if checkpoint_path is None: |
393 | | - if trainer.load_checkpoint: |
394 | | - raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.") |
395 | | - elif trainer.load_checkpoint is not False: |
396 | | - try: |
397 | | - state = load_checkpoint( |
398 | | - state, |
399 | | - checkpoint_path, |
400 | | - discover_latest=True, |
401 | | - axis_mapping=None, |
402 | | - mesh=mesh, |
403 | | - allow_partial=trainer.allow_partial_checkpoint, |
404 | | - ) |
405 | | - except FileNotFoundError: |
406 | | - if trainer.load_checkpoint is True: |
407 | | - raise |
408 | | - logger.info(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.") |
| 392 | + state = restore_grug_state_from_checkpoint( |
| 393 | + state, |
| 394 | + checkpoint_path=checkpoint_path, |
| 395 | + load_checkpoint_setting=trainer.load_checkpoint, |
| 396 | + mesh=mesh, |
| 397 | + allow_partial=trainer.allow_partial_checkpoint, |
| 398 | + ) |
409 | 399 |
|
410 | 400 | levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) |
411 | 401 |
|
@@ -513,11 +503,11 @@ def _init_state(model_rng): |
513 | 503 | levanter.tracker.log(watch_stats, step=step) |
514 | 504 |
|
515 | 505 | if checkpointer is not None: |
516 | | - checkpointer.on_step(tree={"train_state": state}, step=int(state.step)) |
| 506 | + checkpointer.on_step(tree=state, step=int(state.step)) |
517 | 507 | finally: |
518 | 508 | state_callbacks.run(state, loss=last_loss, step_duration=last_step_duration, force=True) |
519 | 509 | if checkpointer is not None: |
520 | | - checkpointer.on_step(tree={"train_state": state}, step=int(state.step), force=True) |
| 510 | + checkpointer.on_step(tree=state, step=int(state.step), force=True) |
521 | 511 | checkpointer.wait_until_finished() |
522 | 512 |
|
523 | 513 | levanter.tracker.current_tracker().finish() |
|
0 commit comments