fix(checkpoint): honor train_time_interval under manual_optimization #21699

Open
wietzesuijker wants to merge 1 commit into Lightning-AI:master from wietzesuijker:fix/manual-opt-train-time-interval

@wietzesuijker wietzesuijker commented May 4, 2026

What does this PR do?

ModelCheckpoint(train_time_interval=...) silently no-ops under automatic_optimization=False. The manual-opt branch in on_train_batch_end only inspects every_n_train_steps, so a callback configured with only a time interval never fires mid-run. last.ckpt still appears at end-of-fit via on_train_end, which masks the bug at the file level but breaks any workflow that depends on mid-run saves (chained SLURM segments, spot/preempt training).
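
The failure mode described above can be illustrated with a toy model. This is a simplified stand-in, not the Lightning source: the function name and its arguments are hypothetical, and it assumes that every_n_train_steps is 0 when only a time interval is configured.

```python
from typing import Optional


def should_save_manual_opt_before_fix(
    global_step: int,
    every_n_train_steps: int,
    train_time_interval_s: Optional[float],
    seconds_since_last_check: float,
) -> bool:
    """Toy model of the pre-fix manual-optimization branch (hypothetical helper)."""
    # The time-related arguments are accepted but never consulted,
    # which is the essence of the bug described in this PR.
    del train_time_interval_s, seconds_since_last_check
    return every_n_train_steps >= 1 and global_step % every_n_train_steps == 0


# Time-only configuration: no step trigger (assumed to be 0), a 60 s interval,
# and a full hour elapsed since the last check on every batch.
saves = [
    should_save_manual_opt_before_fix(step, 0, 60.0, 3600.0)
    for step in range(1, 101)
]
print(any(saves))
```

Under these assumptions no batch ever triggers a save, no matter how much time has elapsed.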

Surfaced while debugging a manual-opt GAN run on chained 3h SLURM segments, where every segment threw away the previous one's work. No prior issue was filed.

Fix

Mirror the auto-opt branch's skip_batch + skip_time logic inside the manual-opt branch, reusing the existing _train_time_interval / _last_time_checked / trainer.strategy.broadcast plumbing. Pre-optimization-state warning path unchanged.
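
A minimal single-process sketch of the combined skip_batch + skip_time decision, for readers who do not have the diff open. The class SaveTrigger and the FakeClock helper are hypothetical names, the clock is injected so the example is deterministic, and the cross-rank broadcast is left out here; it is not a copy of the patched code.

```python
from datetime import timedelta
from typing import Callable, Optional


class FakeClock:
    """Deterministic clock for the example (stands in for a monotonic timer)."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now


class SaveTrigger:
    """Hypothetical stand-in for the step/time decision, single process only."""

    def __init__(
        self,
        every_n_train_steps: int,
        train_time_interval: Optional[timedelta],
        clock: Callable[[], float],
    ) -> None:
        self.every_n_train_steps = every_n_train_steps
        self.train_time_interval = train_time_interval
        self.clock = clock
        # Assumption: the timer starts when training starts.
        self.last_time_checked: Optional[float] = clock()

    def should_save(self, global_step: int) -> bool:
        n = self.every_n_train_steps
        skip_batch = n < 1 or global_step % n != 0
        skip_time = True
        now = self.clock()
        if self.train_time_interval is not None:
            prev = self.last_time_checked
            interval_s = self.train_time_interval.total_seconds()
            skip_time = prev is None or (now - prev) < interval_s
        if skip_batch and skip_time:
            return False
        if not skip_time:
            # Restart the interval only when the time trigger fired.
            self.last_time_checked = now
        return True


clock = FakeClock()
trigger = SaveTrigger(0, timedelta(seconds=60), clock)
saved_steps = []
for step in range(1, 7):
    clock.now += 20.0  # pretend each batch takes 20 seconds
    if trigger.should_save(step):
        saved_steps.append(step)
print(saved_steps)
```

With 20-second batches and a 60-second interval, the sketch saves on steps 3 and 6 even though no step trigger is configured.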

Test

test_model_checkpoint_manual_opt_train_time_interval installs a spy callback that observes _last_global_step_saved after each batch. End-of-fit file checks miss the bug because on_train_end saves once anyway. Verified test fails on master, passes with the patch. Broader checkpoint suite (test_model_checkpoint.py, test_checkpoint_callback_frequency.py, test_model_checkpoint_manual_opt.py) all green locally.
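
The reasoning behind the spy can be sketched without Lightning. Everything below is hypothetical (FakeCheckpoint, run_fit, and the use of a batch count in place of elapsed time); it only illustrates why a check made after fit cannot distinguish the buggy and fixed behaviors while a per-batch observation can.

```python
from typing import List, Tuple


class FakeCheckpoint:
    """Hypothetical stand-in for a checkpoint callback, not ModelCheckpoint."""

    def __init__(self, honors_time_interval: bool, save_every_n_batches: int = 2):
        self.honors_time_interval = honors_time_interval
        # A batch count stands in for "enough wall-clock time has elapsed".
        self.save_every_n_batches = save_every_n_batches
        self._last_global_step_saved = 0

    def on_train_batch_end(self, global_step: int) -> None:
        due = global_step % self.save_every_n_batches == 0
        if self.honors_time_interval and due:
            self._last_global_step_saved = global_step

    def on_train_end(self, global_step: int) -> None:
        # The end-of-fit save happens regardless of the mid-run behavior.
        self._last_global_step_saved = global_step


def run_fit(ckpt: FakeCheckpoint, num_batches: int = 5) -> Tuple[List[int], int]:
    observed = []
    for step in range(1, num_batches + 1):
        ckpt.on_train_batch_end(step)
        # The "spy": record the saved step right after each batch.
        observed.append(ckpt._last_global_step_saved)
    ckpt.on_train_end(num_batches)
    return observed, ckpt._last_global_step_saved


buggy_mid, buggy_final = run_fit(FakeCheckpoint(honors_time_interval=False))
fixed_mid, fixed_final = run_fit(FakeCheckpoint(honors_time_interval=True))
print(buggy_mid, buggy_final)
print(fixed_mid, fixed_final)
```

Both runs end with the same final saved step, so only the per-batch observations reveal the missing mid-run saves.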

Before submitting
  • Read contributor guideline
  • PR does one thing
  • Tests pass locally
  • CHANGELOG entry added under unreleased Fixed. No breaking changes.

AI (Claude) supported my development of this PR.

github-actions bot added the pl (Generic label for PyTorch Lightning package) label on May 4, 2026
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
skip_time = trainer.strategy.broadcast(skip_time)
Member

We should probably always do that, regardless of whether train_time_interval is None. Worst case we broadcast False, but at least that way we prevent a case where some ranks could theoretically hang because they don't expect the broadcast operation (e.g. if the code gets out of sync).

Author

Thanks, fixed in 82ffbf1. Broadcast now runs unconditionally; when train_time_interval is None we broadcast True (skip), which is a no-op for saving but keeps every rank on the collective. Added a test that wraps strategy.broadcast and confirms it fires per batch under manual_opt. Rebased on master.
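
A rough sketch of the unconditional broadcast and of the call-counting check, under stated assumptions: CountingStrategy and compute_skip_time are hypothetical names, and the broadcast here is a single-process pass-through rather than a real collective.

```python
from datetime import timedelta
from typing import Any, Optional


class CountingStrategy:
    """Hypothetical single-process stand-in for trainer.strategy."""

    def __init__(self) -> None:
        self.broadcast_calls = 0

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        # A real strategy would send rank 0's value to every rank.
        self.broadcast_calls += 1
        return obj


def compute_skip_time(
    strategy: CountingStrategy,
    train_time_interval: Optional[timedelta],
    prev_time_check: Optional[float],
    now: float,
) -> bool:
    skip_time = True  # default: the time trigger does not fire
    if train_time_interval is not None:
        interval_s = train_time_interval.total_seconds()
        skip_time = prev_time_check is None or (now - prev_time_check) < interval_s
    # Broadcast outside the `if`, so every rank joins the collective
    # even when no time interval is configured.
    return strategy.broadcast(skip_time)


strategy = CountingStrategy()
decisions = [
    compute_skip_time(strategy, None, 0.0, float(batch)) for batch in range(4)
]
print(decisions, strategy.broadcast_calls)
```

With no interval configured, each batch still issues exactly one broadcast and the decision is always to skip, which matches the behavior described in the reply above.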

The manual-opt branch of on_train_batch_end only checked
every_n_train_steps; train_time_interval silently no-op'd mid-run.
last.ckpt only appeared at fit completion via on_train_end, which
breaks chained-SLURM and spot/preempt training that rely on
periodic saves.

Mirror the auto-opt skip_batch + skip_time logic. Broadcast
skip_time unconditionally so all ranks reach the collective even
when train_time_interval is None.

Tests:
- train_time_interval fires mid-run (spy on _last_global_step_saved)
- strategy.broadcast called per batch under manual_optimization
wietzesuijker force-pushed the fix/manual-opt-train-time-interval branch from 832130c to 82ffbf1 on May 20, 2026 19:22

2 participants