fix(checkpoint): honor train_time_interval under manual_optimization #21699
Open
wietzesuijker wants to merge 1 commit into
justusschock
approved these changes
May 5, 2026
    skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
    # in case we have time differences across ranks
    # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
    skip_time = trainer.strategy.broadcast(skip_time)
Member
We should probably always do that, independent of whether train_time_interval is None. Worst case we broadcast False, but at least that way we prevent a case where some ranks could theoretically hang because they don't expect the broadcast operation (e.g. if the code gets out of sync).
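To illustrate the hang described above, here is a minimal single-machine sketch. It is an analogy only: `threading.Barrier` stands in for a distributed collective, two threads stand in for ranks, and `run_ranks` is a hypothetical helper that is not part of Lightning. A timeout replaces the real-world hang so the example terminates.

```python
import threading
from typing import List


def run_ranks(rank1_joins: bool, timeout: float = 0.5) -> List[str]:
    """Simulate two ranks. Rank 0 always enters the collective; rank 1 only if rank1_joins."""
    barrier = threading.Barrier(2)  # stand-in for a collective that needs both ranks
    results = ["", ""]

    def rank(idx: int, joins: bool) -> None:
        if not joins:
            # This rank took a code path that never reaches the collective.
            results[idx] = "skipped collective"
            return
        try:
            barrier.wait(timeout=timeout)
            results[idx] = "completed"
        except threading.BrokenBarrierError:
            # In a real distributed job there is typically no such timeout: the rank just waits.
            results[idx] = "timed out waiting"

    threads = [
        threading.Thread(target=rank, args=(0, True)),
        threading.Thread(target=rank, args=(1, rank1_joins)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


both_join = run_ranks(rank1_joins=True)
one_skips = run_ranks(rank1_joins=False)
```

When every rank reaches the collective, both complete. When one rank skips it, the other is left waiting, which is the situation an unconditional broadcast is meant to rule out.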
Author
Thanks, fixed in 82ffbf1. Broadcast now runs unconditionally; when train_time_interval is None we broadcast True (skip), which is a no-op for saving but keeps every rank on the collective. Added a test that wraps strategy.broadcast and confirms it fires per batch under manual_opt. Rebased on master.
The manual-opt branch of on_train_batch_end only checked every_n_train_steps; train_time_interval silently no-op'd mid-run. last.ckpt only appeared at fit completion via on_train_end, which breaks chained-SLURM and spot/preempt training that rely on periodic saves.

Mirror the auto-opt skip_batch + skip_time logic. Broadcast skip_time unconditionally so all ranks reach the collective even when train_time_interval is None.

Tests:
- train_time_interval fires mid-run (spy on _last_global_step_saved)
- strategy.broadcast called per batch under manual_optimization
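The decision logic described in the commit message can be sketched in plain Python. This is not the patched Lightning code: `should_skip_save` and `recording_broadcast` are hypothetical names, the state is passed in as arguments rather than read from the callback, and the broadcast is an identity function that only records its calls. The `skip_time` expression follows the snippet quoted in the review.

```python
from datetime import timedelta
from typing import Callable, List, Optional


def should_skip_save(
    global_step: int,
    every_n_train_steps: int,
    train_time_interval: Optional[timedelta],
    last_time_checked: Optional[float],
    now: float,
    broadcast: Callable[[bool], bool],
) -> bool:
    """Return True when no checkpoint should be written for this batch."""
    # Step-based trigger: skip unless a positive step interval divides the current step.
    skip_batch = every_n_train_steps < 1 or (global_step % every_n_train_steps != 0)

    # Time-based trigger: default to skipping when no interval is configured.
    skip_time = True
    if train_time_interval is not None:
        skip_time = (
            last_time_checked is None
            or (now - last_time_checked) < train_time_interval.total_seconds()
        )

    # Broadcast unconditionally so every rank reaches the collective,
    # even when train_time_interval is None (the value sent is then True, i.e. skip).
    skip_time = broadcast(skip_time)

    return skip_batch and skip_time


calls: List[bool] = []


def recording_broadcast(value: bool) -> bool:
    """Hypothetical single-process stand-in for strategy.broadcast."""
    calls.append(value)
    return value


# No interval configured: nothing to save, but the broadcast still runs.
no_interval = should_skip_save(7, 0, None, None, 100.0, recording_broadcast)
# 50 seconds elapsed against a 30 second interval: a save is due.
interval_due = should_skip_save(7, 0, timedelta(seconds=30), 50.0, 100.0, recording_broadcast)
# 10 seconds elapsed against a 30 second interval: not due yet.
interval_not_due = should_skip_save(7, 0, timedelta(seconds=30), 90.0, 100.0, recording_broadcast)
```

In this sketch the broadcast is recorded once per call regardless of configuration, which is the property the second test in the commit message checks for.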
832130c to 82ffbf1
What does this PR do?
ModelCheckpoint(train_time_interval=...) silently no-ops under automatic_optimization=False. The manual-opt branch in on_train_batch_end only inspects every_n_train_steps, so a callback configured with only a time interval never fires mid-run. last.ckpt still appears at end-of-fit via on_train_end, which masks the bug at the file level but breaks any workflow that depends on mid-run saves (chained SLURM segments, spot/preempt training).

Surfaced debugging a manual-opt GAN run on chained 3h SLURM segments where every segment threw away the previous one's work. No prior issue filed.
Fix
Mirror the auto-opt branch's skip_batch + skip_time logic inside the manual-opt branch, reusing the existing _train_time_interval / _last_time_checked / trainer.strategy.broadcast plumbing. Pre-optimization-state warning path unchanged.
Test
test_model_checkpoint_manual_opt_train_time_interval installs a spy callback that observes _last_global_step_saved after each batch. End-of-fit file checks miss the bug because on_train_end saves once anyway. Verified test fails on master, passes with the patch. Broader checkpoint suite (test_model_checkpoint.py, test_checkpoint_callback_frequency.py, test_model_checkpoint_manual_opt.py) all green locally.
Before submitting
AI (Claude) supported my development of this PR.
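The reasoning in the Test section, that an end-of-fit check cannot distinguish the buggy callback from the fixed one while a per-batch spy can, can be sketched without Lightning. Everything here is a hypothetical stand-in: `FakeCheckpoint` and `run_fit` are not the PR's test code, and the initial value of 0 for `_last_global_step_saved` is an assumption of this sketch.

```python
from typing import List


class FakeCheckpoint:
    """Hypothetical stand-in for a checkpoint callback."""

    def __init__(self, saves_mid_run: bool) -> None:
        self.saves_mid_run = saves_mid_run
        self._last_global_step_saved = 0  # assumed initial value for this sketch

    def on_train_batch_end(self, global_step: int) -> None:
        # The buggy variant never saves here; the fixed variant does.
        if self.saves_mid_run:
            self._last_global_step_saved = global_step

    def on_train_end(self, global_step: int) -> None:
        # Both variants save once at the end of fit, which masks the bug.
        self._last_global_step_saved = global_step


def run_fit(ckpt: FakeCheckpoint, num_batches: int) -> List[int]:
    """Run a fake training loop and spy on the saved step after every batch."""
    observed = []
    for step in range(1, num_batches + 1):
        ckpt.on_train_batch_end(step)
        observed.append(ckpt._last_global_step_saved)  # the spy observation
    ckpt.on_train_end(num_batches)
    return observed


buggy = FakeCheckpoint(saves_mid_run=False)
fixed = FakeCheckpoint(saves_mid_run=True)
buggy_observed = run_fit(buggy, num_batches=3)
fixed_observed = run_fit(fixed, num_batches=3)
```

After the loop both objects report the same final saved step, so only the per-batch observations separate the two behaviours.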