Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))
- Fixed `trainer.validate()`, `trainer.test()` and `trainer.predict()` restoring an incorrect epoch and global step when given a specific `ckpt_path` ([#20052](https://github.com/Lightning-AI/pytorch-lightning/issues/20052))

---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,13 @@ def restore_loops(self) -> None:
if self.trainer.state.fn == TrainerFn.FITTING:
fit_loop.load_state_dict(state_dict["fit_loop"])
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self._restore_fit_progress_for_evaluation(state_dict["fit_loop"])
self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
elif self.trainer.state.fn == TrainerFn.TESTING:
self._restore_fit_progress_for_evaluation(state_dict["fit_loop"])
self.trainer.test_loop.load_state_dict(state_dict["test_loop"])
elif self.trainer.state.fn == TrainerFn.PREDICTING:
self._restore_fit_progress_for_evaluation(state_dict["fit_loop"])
self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"])

if self.trainer.state.fn != TrainerFn.FITTING:
Expand All @@ -364,6 +367,23 @@ def restore_loops(self) -> None:
f" but you have set Trainer(max_epochs={self.trainer.max_epochs})."
)

def _restore_fit_progress_for_evaluation(self, fit_loop_state_dict: dict[str, Any]) -> None:
    """Restores the fit-loop counters that back ``trainer.current_epoch`` and ``trainer.global_step``.

    Evaluation-only entry points should not load the full fit loop because that marks the loop as restarting.
    However, loggers and hooks still read their epoch and step from the fit-loop progress.

    Entries that are absent from ``fit_loop_state_dict`` are skipped instead of raising ``KeyError``; the
    corresponding counters are then left untouched.

    Args:
        fit_loop_state_dict: The fit-loop state dict; callers pass ``state_dict["fit_loop"]``.

    """
    fit_loop = self.trainer.fit_loop
    epoch_loop = fit_loop.epoch_loop

    # Each pair maps a key of the flat fit-loop state dict to the progress tracker that consumes it.
    progress_trackers = (
        ("epoch_progress", fit_loop.epoch_progress),
        ("epoch_loop.automatic_optimization.optim_progress", epoch_loop.automatic_optimization.optim_progress),
        ("epoch_loop.manual_optimization.optim_step_progress", epoch_loop.manual_optimization.optim_step_progress),
    )
    for key, progress in progress_trackers:
        # NOTE(review): presumably only checkpoints written by older versions can lack these keys - confirm.
        if key in fit_loop_state_dict:
            progress.load_state_dict(fit_loop_state_dict[key])

    # The epoch loop's own state is handed to its checkpoint hook, again only when present.
    if "epoch_loop.state_dict" in fit_loop_state_dict:
        epoch_loop.on_load_checkpoint(fit_loop_state_dict["epoch_loop.state_dict"])

def restore_optimizers_and_schedulers(self) -> None:
"""Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
if not self._loaded_checkpoint:
Expand Down
38 changes: 38 additions & 0 deletions tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,44 @@ def validation_step(self, batch, batch_idx):
assert f"epoch={idx + 1}" in best_model_path


def test_test_ckpt_path_restores_fit_progress_for_test_hooks(tmp_path):
    """``trainer.test(ckpt_path=...)`` must expose the checkpoint's epoch and global step in ``on_test_start``."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            # Filled in by ``on_test_start`` with what the trainer reports at that point.
            self.on_test_start_current_epoch = None
            self.on_test_start_global_step = None

        def on_test_start(self):
            self.on_test_start_current_epoch = self.trainer.current_epoch
            self.on_test_start_global_step = self.trainer.global_step

    # ``save_top_k=-1`` keeps one checkpoint file per epoch, named after the epoch.
    every_epoch_checkpoint = ModelCheckpoint(dirpath=tmp_path, filename="{epoch}", save_top_k=-1)
    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[every_epoch_checkpoint],
        max_epochs=3,
        limit_train_batches=2,
        limit_val_batches=0,
        limit_test_batches=1,
        logger=False,
        enable_progress_bar=False,
    )
    trainer.fit(model)
    assert trainer.current_epoch == 3

    # Test from a checkpoint that is NOT the last one, so its counters differ from the trainer's final state.
    ckpt_file = tmp_path / "epoch=1.ckpt"
    saved = torch.load(ckpt_file, weights_only=False)

    trainer.test(model, ckpt_path=ckpt_file, verbose=False)

    observed = (model.on_test_start_current_epoch, model.on_test_start_global_step)
    expected = (saved["epoch"], saved["global_step"])
    assert observed == expected


def test_trainer_save_checkpoint_storage_options(tmp_path, xla_available):
"""This test validates that storage_options argument is properly passed to ``CheckpointIO``"""
model = BoringModel()
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _check_model_state_dict(self):

def _test_on_val_test_predict_start(self):
    """Check the trainer and model state at the start of an evaluation run.

    Presumably shared by the validation/test/predict start hooks (inferred from the name - confirm against
    the callers). ``state_dict`` is not defined in this method; it is read from the enclosing scope.
    """
    assert self.trainer.current_epoch == state_dict["epoch"]
    # The global step must equal the value stored in ``state_dict``; an earlier ``== 0`` expectation
    # contradicted this assertion for any non-zero step and has been removed.
    assert self.trainer.global_step == state_dict["global_step"]
    assert self._check_model_state_dict()

def on_train_start(self):
Expand Down