fix(checkpoint): honor train_time_interval under manual_optimization (#21699)
Open
wietzesuijker wants to merge 1 commit into Lightning-AI:master from
Conversation
`ModelCheckpoint` silently dropped `train_time_interval` when the LightningModule used manual optimization. The manual-opt branch in `on_train_batch_end` only checked `every_n_train_steps`, so a callback configured with `train_time_interval=timedelta(minutes=15)` and no step trigger never fired mid-run. `last.ckpt` did still appear at fit completion via `on_train_end`, which made the bug invisible to most tests but broke any workflow that relies on mid-run saves: chained SLURM segments resuming from epoch 0 every time, spot/preempt training losing all in-flight progress, and so on. The fix mirrors the auto-opt branch's `skip_batch` + `skip_time` logic so a save fires when either trigger is satisfied. The new regression test uses a spy callback to observe `_last_global_step_saved` during fit, since checking the file at end-of-run misses the bug entirely.
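The combined trigger described above can be sketched as a pure function. This is an illustrative simplification, not Lightning's actual API: the name `should_skip_save` and the explicit parameters are hypothetical, and the real code also broadcasts the time decision across ranks.

```python
import time
from datetime import timedelta
from typing import Optional


def should_skip_save(
    global_step: int,
    every_n_train_steps: int,
    train_time_interval: Optional[timedelta],
    last_time_checked: Optional[float],
    now: float,
) -> bool:
    """Return True when neither the step trigger nor the time trigger fires."""
    # step trigger: disabled (< 1) or the step is not on the configured boundary
    skip_batch = every_n_train_steps < 1 or global_step % every_n_train_steps != 0
    # time trigger: disabled (None), or the configured interval has not elapsed
    skip_time = True
    if train_time_interval is not None:
        skip_time = (
            last_time_checked is None
            or (now - last_time_checked) < train_time_interval.total_seconds()
        )
    # save (i.e. do NOT skip) when either trigger is satisfied
    return skip_batch and skip_time
```

The pre-fix manual-opt branch effectively ignored the `skip_time` half, so a time-only configuration always skipped.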
justusschock
approved these changes
May 5, 2026
```python
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
skip_time = trainer.strategy.broadcast(skip_time)
```
We should probably always do that, independent of whether train_time_interval is None. Worst case we broadcast False, but that way we prevent a case where some ranks could theoretically hang because they don't expect the broadcast operation (i.e. if code gets out of sync or so).
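A rough sketch of that suggestion, with a stand-in single-device strategy (the real `trainer.strategy.broadcast` is a collective operation across ranks; names here are illustrative):

```python
from datetime import timedelta
from typing import Optional


class SingleDeviceStrategy:
    """Stand-in for trainer.strategy; real strategies broadcast from rank 0."""

    def broadcast(self, obj, src: int = 0):
        return obj


def decide_skip_time(
    strategy,
    train_time_interval: Optional[timedelta],
    last_time_checked: Optional[float],
    now: float,
) -> bool:
    skip_time = True
    if train_time_interval is not None:
        skip_time = (
            last_time_checked is None
            or (now - last_time_checked) < train_time_interval.total_seconds()
        )
    # Broadcast unconditionally: even when the answer is trivially True,
    # every rank executes the same collective call, so no rank can hang
    # waiting for a broadcast that the other ranks never issue.
    return strategy.broadcast(skip_time)
```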
What does this PR do?
`ModelCheckpoint(train_time_interval=...)` silently no-ops under `automatic_optimization=False`. The manual-opt branch in `on_train_batch_end` only inspects `every_n_train_steps`, so a callback configured with only a time interval never fires mid-run. `last.ckpt` still appears at end-of-fit via `on_train_end`, which masks the bug at the file level but breaks any workflow that depends on mid-run saves (chained SLURM segments, spot/preempt training). Surfaced while debugging a manual-opt GAN run on chained 3h SLURM segments where every segment threw away the previous one's work. No prior issue filed.
Fix
Mirror the auto-opt branch's `skip_batch` + `skip_time` logic inside the manual-opt branch, reusing the existing `_train_time_interval` / `_last_time_checked` / `trainer.strategy.broadcast` plumbing. The pre-optimization-state warning path is unchanged.
Test
`test_model_checkpoint_manual_opt_train_time_interval` installs a spy callback that observes `_last_global_step_saved` after each batch. End-of-fit file checks miss the bug because `on_train_end` saves once anyway. Verified: the test fails on master and passes with the patch. The broader checkpoint suite (`test_model_checkpoint.py`, `test_checkpoint_callback_frequency.py`, `test_model_checkpoint_manual_opt.py`) is all green locally.
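The spy pattern the test relies on can be sketched like this (hypothetical class name; the real test lives in the Lightning test suite and the hook signature matches Lightning's `Callback.on_train_batch_end`):

```python
class CheckpointSpy:
    """Records the ModelCheckpoint's _last_global_step_saved after every
    training batch, so a save that only happens in on_train_end cannot
    mask a missing mid-run save."""

    def __init__(self, checkpoint_cb):
        self.checkpoint_cb = checkpoint_cb
        self.saved_steps = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.saved_steps.append(self.checkpoint_cb._last_global_step_saved)
```

After `trainer.fit(...)`, the regression assertion is that a save happened during the run, e.g. `assert any(step > 0 for step in spy.saved_steps)`, rather than checking for `last.ckpt` on disk afterwards.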
AI (Claude) supported my development of this PR.