2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))

- Fixed `ModelCheckpoint(train_time_interval=...)` silently no-op'ing under `automatic_optimization=False`; the manual-optimization branch in `on_train_batch_end` now mirrors the auto-opt branch and fires on time-based saves as well as `every_n_train_steps`

---

## [2.6.2] - 2026-03-19
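For context, a minimal usage sketch of the configuration that changelog entry describes: a manual-optimization module checkpointed purely on a time interval. The module, dataset, and ten-minute interval below are illustrative placeholders, not code from this PR:

```python
from datetime import timedelta

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    # Illustrative dataset: 64 samples of 32 random features each.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class ManualOptModule(LightningModule):
    # Illustrative module that opts out of automatic optimization.
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    # Before this fix, a time-only trigger like this never fired mid-run
    # under manual optimization.
    ckpt = ModelCheckpoint(
        dirpath="checkpoints",
        save_last=True,
        train_time_interval=timedelta(minutes=10),
    )
    trainer = Trainer(max_epochs=1, callbacks=[ckpt], logger=False)
    trainer.fit(ManualOptModule(), DataLoader(RandomDataset(), batch_size=4))
```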
22 changes: 20 additions & 2 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -348,9 +348,27 @@ def on_train_batch_end(
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
# For manual optimization, we need to handle saving differently
if not pl_module.automatic_optimization:
# Skip if we don't need to save at this step
if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
# Mirror the auto-opt branch: a save fires when EITHER `every_n_train_steps`
# OR `train_time_interval` is satisfied. Without this, `train_time_interval`
# silently no-ops under manual optimization and `last.ckpt` is never written
# mid-run when `every_n_train_steps` is not also configured.
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

train_time_interval = self._train_time_interval
skip_time = True
now = time.monotonic()
# Important: allow zero timedelta as a valid interval
if train_time_interval is not None:
prev_time_check = self._last_time_checked
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
skip_time = trainer.strategy.broadcast(skip_time)
Member (review comment on the `trainer.strategy.broadcast(skip_time)` line):
We should probably always do that, independent of whether `train_time_interval` is None. Worst case we broadcast `False`, but at least that way we prevent a case where some ranks could theoretically hang because they don't expect the broadcast operation (i.e. if the code gets out of sync). See the sketch after this hunk.


if skip_batch and skip_time:
return
if not skip_time:
self._last_time_checked = now

# Check if we should skip due to trainer/callback state
if self._should_skip_saving_checkpoint(trainer):
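Following up on the review comment above, here is one way the unconditional broadcast could look, written as a hypothetical standalone helper rather than as a patch to `on_train_batch_end`; the helper name and its parameters are illustrative, and only the `strategy.broadcast` call and the interval arithmetic are taken from the hunk above:

```python
import time
from datetime import timedelta
from typing import Optional


def should_skip_time_based_save(
    strategy,  # object exposing `broadcast(obj)`, e.g. trainer.strategy
    train_time_interval: Optional[timedelta],
    last_time_checked: Optional[float],
    now: Optional[float] = None,
) -> bool:
    """Return True when the time-based save should be skipped on this rank.

    Mirrors the decision in the hunk above, but broadcasts unconditionally so
    that every rank always enters the collective, even when no interval is
    configured.
    """
    now = time.monotonic() if now is None else now
    skip_time = True
    if train_time_interval is not None:
        skip_time = (
            last_time_checked is None
            or (now - last_time_checked) < train_time_interval.total_seconds()
        )
    # Rank 0's decision wins; calling broadcast on every batch keeps all ranks
    # in sync even if their clocks (or code paths) drift apart. Worst case we
    # broadcast True (skip) when no interval is configured.
    return strategy.broadcast(skip_time)
```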
@@ -3,12 +3,13 @@
import warnings
from contextlib import contextmanager
from copy import deepcopy
from datetime import timedelta
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


@@ -180,3 +181,54 @@ def training_step(self, batch, batch_idx):
# Verify our warning was raised
assert len(manual_opt_warnings) > 0, "Expected warning about manual optimization not found"
assert "The checkpoint will contain the model state AFTER optimization" in manual_opt_warnings[0]


def test_model_checkpoint_manual_opt_train_time_interval():
"""Regression: ``train_time_interval`` must fire mid-run under manual optimization.

Before the fix, the manual-optimization branch in ``on_train_batch_end`` only
inspected ``every_n_train_steps`` and silently no-op'd when ``train_time_interval``
was the only configured trigger. ``last.ckpt`` was still written by ``on_train_end``,
so end-of-run state checks would miss the bug -- this test asserts the mid-run save by
observing ``_last_global_step_saved`` from a spy callback queued after the
``ModelCheckpoint``.
"""
saved_steps_during_training = []

class _Spy(Callback):
def __init__(self, ckpt: ModelCheckpoint) -> None:
self.ckpt = ckpt

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
saved_steps_during_training.append(self.ckpt._last_global_step_saved)

with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
dataset = FakeDataset()
train_dataloader = DataLoader(dataset, batch_size=1)
model = SimpleModule()
ckpt = ModelCheckpoint(
dirpath=tmpdir,
save_top_k=0,
save_last=True,
train_time_interval=timedelta(seconds=0),
save_weights_only=True,
)
trainer = Trainer(
max_epochs=1,
callbacks=[ckpt, _Spy(ckpt)],
log_every_n_steps=1,
num_sanity_val_steps=0,
logger=False,
)
try:
trainer.fit(model, train_dataloader)
finally:
trainer._teardown()

# With ``train_time_interval=0``, the callback must fire on every batch.
# Before the fix, the value stayed at 0 until ``on_train_end`` saved once.
assert any(step > 0 for step in saved_steps_during_training), (
"ModelCheckpoint(train_time_interval=...) silently no-op'd mid-run under manual_optimization; "
f"observed _last_global_step_saved values during training: {saved_steps_during_training}"
)
assert (Path(tmpdir) / "last.ckpt").exists()