Skip to content

Commit 832130c

Browse files
committed
fix(checkpoint): honor train_time_interval under manual_optimization
ModelCheckpoint silently dropped train_time_interval when the LightningModule used manual optimization. The manual-opt branch in on_train_batch_end only checked every_n_train_steps, so a callback configured with `train_time_interval=timedelta(minutes=15)` and no step trigger never fired mid-run. last.ckpt did still appear at fit completion via on_train_end, which made the bug invisible to most tests but broke any workflow that relies on mid-run saves -- chained SLURM segments resuming from epoch 0 every time, spot/preemptible training losing all in-flight progress, etc. The fix mirrors the auto-opt branch's skip_batch + skip_time logic so a save fires when either trigger is satisfied. The new regression test uses a spy callback to observe _last_global_step_saved during fit, since checking the file at end-of-run misses the bug entirely.
1 parent 0e20e15 commit 832130c

3 files changed

Lines changed: 75 additions & 3 deletions

File tree

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2626

2727
- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))
2828

29+
- Fixed `ModelCheckpoint(train_time_interval=...)` being silently ignored under `automatic_optimization=False`; the manual-optimization branch in `on_train_batch_end` now mirrors the automatic-optimization branch and saves when either `every_n_train_steps` or `train_time_interval` is satisfied
30+
2931
---
3032

3133
## [2.6.2] - 2026-03-19

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,27 @@ def on_train_batch_end(
348348
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` or `train_time_interval`."""
349349
# For manual optimization, we need to handle saving differently
350350
if not pl_module.automatic_optimization:
351-
# Skip if we don't need to save at this step
352-
if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
351+
# Mirror the auto-opt branch: a save fires when EITHER `every_n_train_steps`
352+
# OR `train_time_interval` is satisfied. Without this, `train_time_interval`
353+
# silently no-ops under manual optimization and `last.ckpt` is never written
354+
# mid-run when `every_n_train_steps` is not also configured.
355+
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
356+
357+
train_time_interval = self._train_time_interval
358+
skip_time = True
359+
now = time.monotonic()
360+
# Important: allow zero timedelta as a valid interval
361+
if train_time_interval is not None:
362+
prev_time_check = self._last_time_checked
363+
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
364+
# in case we have time differences across ranks
365+
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
366+
skip_time = trainer.strategy.broadcast(skip_time)
367+
368+
if skip_batch and skip_time:
353369
return
370+
if not skip_time:
371+
self._last_time_checked = now
354372

355373
# Check if we should skip due to trainer/callback state
356374
if self._should_skip_saving_checkpoint(trainer):

tests/tests_pytorch/callbacks/test_model_checkpoint_manual_opt.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import warnings
44
from contextlib import contextmanager
55
from copy import deepcopy
6+
from datetime import timedelta
67
from pathlib import Path
78

89
import torch
910
from torch.utils.data import DataLoader, Dataset
1011

11-
from lightning.pytorch import LightningModule, Trainer
12+
from lightning.pytorch import Callback, LightningModule, Trainer
1213
from lightning.pytorch.callbacks import ModelCheckpoint
1314

1415

@@ -180,3 +181,54 @@ def training_step(self, batch, batch_idx):
180181
# Verify our warning was raised
181182
assert len(manual_opt_warnings) > 0, "Expected warning about manual optimization not found"
182183
assert "The checkpoint will contain the model state AFTER optimization" in manual_opt_warnings[0]
184+
185+
186+
def test_model_checkpoint_manual_opt_train_time_interval():
187+
"""Regression: ``train_time_interval`` must fire mid-run under manual optimization.
188+
189+
Before the fix, the manual-optimization branch in ``on_train_batch_end`` only
190+
inspected ``every_n_train_steps`` and silently no-op'd when ``train_time_interval``
191+
was the only configured trigger. ``last.ckpt`` was still written by ``on_train_end``,
192+
so end-of-run state checks miss the bug -- this test asserts the mid-run save by
193+
observing ``_last_global_step_saved`` from a spy callback registered alongside the
194+
``ModelCheckpoint``.
195+
"""
196+
saved_steps_during_training = []
197+
198+
class _Spy(Callback):
199+
def __init__(self, ckpt: ModelCheckpoint) -> None:
200+
self.ckpt = ckpt
201+
202+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
203+
saved_steps_during_training.append(self.ckpt._last_global_step_saved)
204+
205+
with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
206+
dataset = FakeDataset()
207+
train_dataloader = DataLoader(dataset, batch_size=1)
208+
model = SimpleModule()
209+
ckpt = ModelCheckpoint(
210+
dirpath=tmpdir,
211+
save_top_k=0,
212+
save_last=True,
213+
train_time_interval=timedelta(seconds=0),
214+
save_weights_only=True,
215+
)
216+
trainer = Trainer(
217+
max_epochs=1,
218+
callbacks=[ckpt, _Spy(ckpt)],
219+
log_every_n_steps=1,
220+
num_sanity_val_steps=0,
221+
logger=False,
222+
)
223+
try:
224+
trainer.fit(model, train_dataloader)
225+
finally:
226+
trainer._teardown()
227+
228+
# With ``train_time_interval=0``, the callback must fire on every batch.
229+
# Pre-fix the value stayed at 0 until ``on_train_end`` saved once.
230+
assert any(step > 0 for step in saved_steps_during_training), (
231+
"ModelCheckpoint(train_time_interval=...) silently no-op'd mid-run under manual_optimization; "
232+
f"observed _last_global_step_saved values during training: {saved_steps_during_training}"
233+
)
234+
assert (Path(tmpdir) / "last.ckpt").exists()

0 commit comments

Comments
 (0)