Skip to content

Commit a5b2b51

Browse files
Calvin-Xuclaude
andcommitted
[levanter] Wait for final checkpoint async save before process exit
The final checkpoint save in Levanter is async — the GCS write runs in a background thread while the process continues. In train_dpo.py, the post-training code attempted to wait via `wait_until_finished()`, but called it on a *newly created* Checkpointer instance (which had nothing to wait for) instead of the original one that holds the pending save. train_lm.py had no wait at all. If preemption (or a normal exit) kills the process while the async write is still in flight, the only checkpoint at the target training step is silently lost. This is especially damaging when the target step is not a multiple of the permanent checkpoint interval (e.g. step 9917 with `keep=[dict(every=10000)]`), since every checkpoint below 10000 is temporary and subject to deletion on restart. Fix: - Store the Checkpointer as `self._checkpointer` on the Trainer so it is accessible after training. - train_lm.py: add `trainer._checkpointer.wait_until_finished()` after training completes. - train_dpo.py: replace the broken new-instance wait with `trainer._checkpointer.wait_until_finished()`. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c2df1e7 commit a5b2b51

3 files changed

Lines changed: 15 additions & 6 deletions

File tree

lib/levanter/src/levanter/main/train_dpo.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -488,12 +488,14 @@ def save_policy_hf_checkpoint(step):
488488
else:
489489
train_loader = train_loader.iter_from_step(0)
490490

491-
last_info = trainer.train(state, train_loader)
491+
trainer.train(state, train_loader)
492492

493+
# Ensure the final checkpoint's async GCS write completes before the
494+
# process exits. Without this, preemption (or a normal exit race) can
495+
# kill the process while the write is still in flight, silently losing
496+
# the only checkpoint at the target training step.
493497
if trainer.config.checkpointer is not None:
494-
trainer.run_hooks(last_info, force=True)
495-
checkpointer = trainer.config.checkpointer.create(trainer.run_id)
496-
checkpointer.wait_until_finished()
498+
trainer._checkpointer.wait_until_finished()
497499

498500
trainer.tracker.finish()
499501

lib/levanter/src/levanter/main/train_lm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,13 @@ def compute_logits(model: LmHeadModel, example: LmExample):
319319
## OK, actually run training!
320320
trainer.train(state, train_loader)
321321

322+
# Ensure the final checkpoint's async GCS write completes before the
323+
# process exits. Without this, preemption (or a normal exit race) can
324+
# kill the process while the write is still in flight, silently losing
325+
# the only checkpoint at the target training step.
326+
if trainer.config.checkpointer is not None:
327+
trainer._checkpointer.wait_until_finished()
328+
322329
# This isn't necessary except when Levanter is run in a subprocess (as happens w/ ray)
323330
trainer.tracker.finish()
324331

lib/levanter/src/levanter/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,10 +579,10 @@ def _add_default_hooks(self):
579579
self.add_hook(levanter.callbacks.pbar_logger(total=self.config.num_train_steps), every=1)
580580
self.add_hook(levanter.callbacks.log_step_info(self.config.num_train_steps), every=1)
581581
# engine.add_hook(callbacks.log_memory_usage(), every=1)
582-
checkpointer = self.config.checkpointer.create(self.run_id)
582+
self._checkpointer = self.config.checkpointer.create(self.run_id)
583583

584584
def checkpoint_hook(info, force=False):
585-
checkpointer.on_step(tree=info.state.saveable_state, step=info.step, force=force)
585+
self._checkpointer.on_step(tree=info.state.saveable_state, step=info.step, force=force)
586586

587587
self.add_hook(checkpoint_hook, every=1) # checkpointer manages its own frequency
588588

0 commit comments

Comments
 (0)