Skip to content

EarlyStopping: add `start_from_epoch` parameter to delay monitoring until a given epoch #20778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/lightning/pytorch/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from torch import Tensor
from typing_extensions import override

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_prefixed_message, rank_zero_warn

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,6 +64,8 @@ class EarlyStopping(Callback):
check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
If this is ``False``, then the check runs at the end of the validation.
log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.
start_from_epoch: the epoch from which to start monitoring for early stopping. Defaults to 0 (start from the
beginning). Set to a higher value to let the model train for a minimum number of epochs before monitoring.

Raises:
MisconfigurationException:
Expand All @@ -73,11 +75,15 @@ class EarlyStopping(Callback):

Example::

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import EarlyStopping
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])

>>> # Start monitoring only from epoch 5
>>> early_stopping = EarlyStopping('val_loss', start_from_epoch=5)
>>> trainer = Trainer(callbacks=[early_stopping])

.. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
following arguments:

Expand All @@ -104,6 +110,7 @@ def __init__(
divergence_threshold: Optional[float] = None,
check_on_train_epoch_end: Optional[bool] = None,
log_rank_zero_only: bool = False,
start_from_epoch: int = 0,
):
super().__init__()
self.monitor = monitor
Expand All @@ -119,6 +126,7 @@ def __init__(
self.stopped_epoch = 0
self._check_on_train_epoch_end = check_on_train_epoch_end
self.log_rank_zero_only = log_rank_zero_only
self.start_from_epoch = start_from_epoch

if self.mode not in self.mode_dict:
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
Expand Down Expand Up @@ -179,7 +187,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.patience = state_dict["patience"]

def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
    """Return ``True`` when the early-stopping check should be skipped.

    The check is skipped whenever the trainer is not in the ``FITTING`` state
    (e.g. during ``validate``/``test``/``predict``) or while running the
    sanity-check loop, where callback metrics are not meaningful yet.
    """
    # Local import avoids a circular import at module load time. Use the
    # canonical `lightning.pytorch` package path: this file lives under
    # `src/lightning/pytorch/` (the scraped diff left a stale duplicate
    # `pytorch_lightning` import here, which the test's mock target also
    # contradicts).
    from lightning.pytorch.trainer.states import TrainerFn

    return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

Expand All @@ -197,6 +205,10 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul

def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
"""Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
# Skip early stopping check if current epoch is less than start_from_epoch
if trainer.current_epoch < self.start_from_epoch:
return

logs = trainer.callback_metrics

if trainer.fast_dev_run or not self._validate_condition_metric( # disable early_stopping with fast_dev_run
Expand Down
32 changes: 32 additions & 0 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,35 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex
log_mock.assert_called_once_with(expected_log)
else:
log_mock.assert_not_called()


def test_early_stopping_start_from_epoch(tmp_path):
    """Test that early stopping checks only activate after ``start_from_epoch``.

    Trains for ``len(losses)`` epochs with a scripted, strictly decreasing
    ``val_loss`` and verifies that the stopping criteria are evaluated only
    from epoch ``start_from_epoch`` onwards, and with the expected metric
    value for each of those epochs.
    """
    losses = [6, 5, 4, 3, 2, 1]  # decreasing losses, so stopping is never triggered on its own
    start_from_epoch = 3

    class CurrentModel(BoringModel):
        def on_validation_epoch_end(self):
            # Log the scripted validation loss for the current epoch so the
            # callback has a deterministic metric to monitor.
            val_loss = losses[self.current_epoch]
            self.log("val_loss", val_loss)

    model = CurrentModel()

    # Mock `_evaluate_stopping_criteria` (the inner evaluation, not the outer
    # `_run_early_stopping_check`) to record exactly when the callback
    # evaluates the monitored metric.
    with mock.patch("lightning.pytorch.callbacks.early_stopping.EarlyStopping._evaluate_stopping_criteria") as es_mock:
        es_mock.return_value = (False, "")  # never request stopping; we only count calls
        early_stopping = EarlyStopping(monitor="val_loss", start_from_epoch=start_from_epoch)
        trainer = Trainer(
            default_root_dir=tmp_path,
            callbacks=[early_stopping],
            limit_train_batches=0.2,
            limit_val_batches=0.2,
            max_epochs=len(losses),
        )
        trainer.fit(model)

    # Check that _evaluate_stopping_criteria is not called for epochs before start_from_epoch
    assert es_mock.call_count == len(losses) - start_from_epoch
    # Check that only the correct epochs were processed
    for i, call_args in enumerate(es_mock.call_args_list):
        assert torch.allclose(call_args[0][0], torch.tensor(losses[i + start_from_epoch]))
Loading