
fix(checkpoint): honor train_time_interval under manual_optimization #21699

Open
wietzesuijker wants to merge 1 commit into Lightning-AI:master from wietzesuijker:fix/manual-opt-train-time-interval

Conversation


@wietzesuijker wietzesuijker commented May 4, 2026

What does this PR do?

ModelCheckpoint(train_time_interval=...) silently no-ops under automatic_optimization=False. The manual-opt branch in on_train_batch_end only inspects every_n_train_steps, so a callback configured with only a time interval never fires mid-run. A last.ckpt still appears at end-of-fit via on_train_end, which masks the bug at the file level, but any workflow that depends on mid-run saves (chained SLURM segments, spot/preempt training) silently loses its in-flight progress.
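A minimal repro sketch (the module name and dirpath are illustrative, not from the patch):

```python
from datetime import timedelta

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

# Only a time trigger, no step trigger: before the patch this callback
# never saved mid-run under manual optimization.
ckpt = ModelCheckpoint(dirpath="checkpoints/", train_time_interval=timedelta(minutes=15))

class MyModule(pl.LightningModule):  # hypothetical module standing in for the GAN
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # routes on_train_batch_end into the manual-opt branch

trainer = pl.Trainer(callbacks=[ckpt])
# trainer.fit(MyModule(), train_dataloaders=...)  # no mid-run checkpoints on master
```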

Surfaced while debugging a manual-opt GAN run on chained 3-hour SLURM segments, where every segment threw away the previous one's work. No prior issue filed.

Fix

Mirror the auto-opt branch's skip_batch + skip_time logic inside the manual-opt branch, reusing the existing _train_time_interval / _last_time_checked / trainer.strategy.broadcast plumbing. Pre-optimization-state warning path unchanged.
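For reference, the trigger logic being mirrored looks roughly like this (paraphrased from the auto-opt branch of ModelCheckpoint.on_train_batch_end; a sketch of the intent, not the exact diff):

```python
import time

# inside on_train_batch_end (sketch):
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

train_time_interval = self._train_time_interval
skip_time = True
now = time.monotonic()
if train_time_interval:
    prev_time_check = self._last_time_checked
    skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
    # broadcast the rank-0 decision so all ranks agree and none hangs
    skip_time = trainer.strategy.broadcast(skip_time)

if skip_batch and skip_time:
    return  # neither trigger fired
if not skip_time:
    self._last_time_checked = now  # reset the clock after a time-triggered save
```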

Test

test_model_checkpoint_manual_opt_train_time_interval installs a spy callback that observes _last_global_step_saved after each batch; end-of-fit file checks miss the bug because on_train_end saves once anyway. The test fails on master and passes with the patch. The broader checkpoint suite (test_model_checkpoint.py, test_checkpoint_callback_frequency.py, test_model_checkpoint_manual_opt.py) is all green locally.
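The spy pattern, roughly (the actual test body may differ):

```python
import lightning.pytorch as pl

class SaveSpy(pl.Callback):
    """Records the checkpoint callback's bookkeeping after every batch."""

    def __init__(self, ckpt_cb):
        self.ckpt_cb = ckpt_cb
        self.steps_when_saved = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.steps_when_saved.append(self.ckpt_cb._last_global_step_saved)

# After fit, assert a save happened *during* training rather than only at
# on_train_end, e.g.: assert any(s > 0 for s in spy.steps_when_saved[:-1])
```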

Before submitting
  • Read the contributor guidelines
  • PR does one thing
  • Tests pass locally
  • CHANGELOG entry added under the unreleased "Fixed" section; no breaking changes

AI (Claude) supported my development of this PR.

ModelCheckpoint silently dropped train_time_interval when the LightningModule used manual optimization. The manual-opt branch in on_train_batch_end only checked every_n_train_steps, so a callback configured with `train_time_interval=timedelta(minutes=15)` and no step trigger never fired mid-run. last.ckpt still appeared at fit completion via on_train_end, which made the bug invisible to most tests but broke any workflow that relies on mid-run saves: chained SLURM segments resuming from epoch 0 every time, spot/preempt training losing all in-flight progress, etc.

The fix mirrors the auto-opt branch's skip_batch + skip_time logic so a save fires when either trigger is satisfied. The new regression test uses a spy callback to observe _last_global_step_saved during fit, since checking the file at end-of-run misses the bug entirely.
@github-actions bot added the pl (Generic label for PyTorch Lightning package) label on May 4, 2026
```python
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
skip_time = trainer.strategy.broadcast(skip_time)
```
Member commented:

We should probably always do that, independent of whether train_time_interval is None. Worst case we broadcast False, but at least that way we prevent a case where some ranks could theoretically hang because they don't expect the broadcast operation (i.e. if code gets out of sync).
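Sketched, that suggestion hoists the broadcast out of the conditional (my reading of the comment, not a committed change):

```python
skip_time = True
now = time.monotonic()
if train_time_interval is not None:
    prev_time_check = self._last_time_checked
    skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# Always participate in the collective, even with no time trigger configured,
# so ranks can never disagree about whether a broadcast happens.
skip_time = trainer.strategy.broadcast(skip_time)
```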


Labels

pl Generic label for PyTorch Lightning package
