diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py
index d108894f614e6..393d8f699f80e 100644
--- a/src/lightning/pytorch/callbacks/early_stopping.py
+++ b/src/lightning/pytorch/callbacks/early_stopping.py
@@ -64,6 +64,8 @@ class EarlyStopping(Callback):
         check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
             If this is ``False``, then the check runs at the end of the validation.
         log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.
+        start_from_epoch: the epoch from which to start monitoring for early stopping. Defaults to 0 (start from the
+            beginning). Set to a higher value to let the model train for a minimum number of epochs before monitoring.
 
     Raises:
         MisconfigurationException:
@@ -73,11 +75,15 @@ class EarlyStopping(Callback):
 
     Example::
 
         >>> from lightning.pytorch import Trainer
         >>> from lightning.pytorch.callbacks import EarlyStopping
         >>> early_stopping = EarlyStopping('val_loss')
         >>> trainer = Trainer(callbacks=[early_stopping])
 
+        >>> # Start monitoring only from epoch 5
+        >>> early_stopping = EarlyStopping('val_loss', start_from_epoch=5)
+        >>> trainer = Trainer(callbacks=[early_stopping])
+
     .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation
        in the following arguments:
@@ -104,6 +110,7 @@ def __init__(
         divergence_threshold: Optional[float] = None,
         check_on_train_epoch_end: Optional[bool] = None,
         log_rank_zero_only: bool = False,
+        start_from_epoch: int = 0,
     ):
         super().__init__()
         self.monitor = monitor
@@ -119,6 +126,7 @@ def __init__(
         self.stopped_epoch = 0
         self._check_on_train_epoch_end = check_on_train_epoch_end
         self.log_rank_zero_only = log_rank_zero_only
+        self.start_from_epoch = start_from_epoch
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -197,6 +205,10 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
 
     def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
+        # Skip the early stopping check while the current epoch is below start_from_epoch
+        if trainer.current_epoch < self.start_from_epoch:
+            return
+
         logs = trainer.callback_metrics
         if trainer.fast_dev_run or not self._validate_condition_metric(  # disable early_stopping with fast_dev_run
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index 9a87b3daaad6e..0ec23d33f3e7c 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -505,3 +505,35 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex
         log_mock.assert_called_once_with(expected_log)
     else:
         log_mock.assert_not_called()
+
+
+def test_early_stopping_start_from_epoch(tmp_path):
+    """Test that early stopping checks only activate after start_from_epoch."""
+    losses = [6.0, 5.0, 4.0, 3.0, 2.0, 1.0]  # steadily decreasing validation losses
+    start_from_epoch = 3
+
+    class CurrentModel(BoringModel):
+        def on_validation_epoch_end(self):
+            val_loss = losses[self.current_epoch]
+            self.log("val_loss", val_loss)
+
+    model = CurrentModel()
+
+    # Mock _evaluate_stopping_criteria to record when the early stopping check actually runs
+    with mock.patch("lightning.pytorch.callbacks.early_stopping.EarlyStopping._evaluate_stopping_criteria") as es_mock:
+        es_mock.return_value = (False, "")
+        early_stopping = EarlyStopping(monitor="val_loss", start_from_epoch=start_from_epoch)
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            callbacks=[early_stopping],
+            limit_train_batches=0.2,
+            limit_val_batches=0.2,
+            max_epochs=len(losses),
+        )
+        trainer.fit(model)
+
+    # The check must not run for any epoch before start_from_epoch
+    assert es_mock.call_count == len(losses) - start_from_epoch
+    # and each recorded call must see the metric from an epoch at or after start_from_epoch
+    for i, call_args in enumerate(es_mock.call_args_list):
+        assert torch.allclose(call_args[0][0], torch.tensor(losses[i + start_from_epoch]))
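
For reviewers who want to exercise the new argument end to end, here is a minimal usage sketch. It is illustrative only: the `DemoModel` subclass, the monitored metric name, and the use of `BoringModel` from `lightning.pytorch.demos.boring_classes` are assumptions for the example, not part of this PR.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.demos.boring_classes import BoringModel


class DemoModel(BoringModel):  # hypothetical model, only for this sketch
    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        # EarlyStopping reads trainer.callback_metrics, so the monitored key must be logged
        self.log("val_loss", loss)
        return loss


# Epochs 0-4 are a warm-up: no early stopping checks run there. From epoch 5
# onward the callback behaves as before, stopping after `patience` epochs
# without improvement in val_loss.
early_stopping = EarlyStopping(monitor="val_loss", patience=3, start_from_epoch=5)
trainer = Trainer(max_epochs=20, callbacks=[early_stopping], log_every_n_steps=1)
trainer.fit(DemoModel())
```

Note the boundary semantics implied by `trainer.current_epoch < self.start_from_epoch`: the first check runs at the end of epoch index `start_from_epoch` itself (epoch 5 above), and the `patience` counter only starts accumulating from that point.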