diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index cdc420045109e..5376b417e86b2 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623)) +- Fixed `WeightAveraging` swapping in the un-updated average model during validation before its first update, which evaluated the untrained weights during a delayed-start warmup ([#21724](https://github.com/Lightning-AI/pytorch-lightning/issues/21724)) + --- ## [2.6.4] - 2026-05-20 diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py index 0640efed3d87b..649647322c3b3 100644 --- a/src/lightning/pytorch/callbacks/weight_averaging.py +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -223,7 +223,9 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. """ - if self._average_model is not None: + # Only swap in the averaged weights once the average model has actually been updated. Until then it only holds + # the copy of the initial weights made in setup(), so validating with it would discard the trained weights. + if self._average_model is not None and self._average_model.n_averaged > 0: self._swap_models(pl_module) @override @@ -237,7 +239,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. """ - if self._average_model is not None: + if self._average_model is not None and self._average_model.n_averaged > 0: self._swap_models(pl_module) @override diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py index cfb066f023af0..736c16e0505de 100644 --- a/tests/tests_pytorch/callbacks/test_weight_averaging.py +++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py @@ -129,8 +129,6 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(avg_fn=get_swa_avg_fn(), **kwargs) self.swap_calls = 0 self.copy_calls = 0 - # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0. - self.first_epoch: Optional[int] = None def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool: return epoch_idx in (3, 5, 7) @@ -148,14 +146,6 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: assert self.swap_calls == 0 assert self.copy_calls == 0 - def on_train_epoch_start(self, trainer: Trainer, *args: Any) -> None: - super().on_train_epoch_start(trainer, *args) - # Since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will not update the - # model and will just call the epoch-level hooks. For that reason, we check that we are not restarting before - # choosing the first epoch. - if self.first_epoch is None and not trainer.fit_loop.restarting: - self.first_epoch = trainer.current_epoch - def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None: super().on_train_epoch_end(trainer, *args) if trainer.current_epoch < 3: @@ -166,13 +156,16 @@ def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None: assert self._average_model.n_averaged == 2 else: assert self._average_model.n_averaged == 3 - assert self.swap_calls == (trainer.current_epoch + 1 - self.first_epoch) * 2 + # The model is only swapped during validation once the average model has been updated. The first update happens + # at the end of epoch 3, so the validation epochs that swap are those from epoch 4 onwards (two swaps each). + assert self.swap_calls == max(0, trainer.current_epoch - 3) * 2 assert self.copy_calls == 0 def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_end(trainer, pl_module) assert self._average_model.n_averaged == 3 - assert self.swap_calls == (trainer.max_epochs - self.first_epoch) * 2 + # Validation epochs 4 to ``max_epochs - 1`` each swap twice (see ``on_train_epoch_end``). + assert self.swap_calls == max(0, trainer.max_epochs - 1 - 3) * 2 assert self.copy_calls == 1 @@ -388,6 +381,63 @@ def test_ema_weight_averaging_starting_epoch(tmp_path): assert callback._average_model is not None +def test_weight_averaging_no_swap_before_first_update(tmp_path): + """Validation must use the current (trained) weights while the average model has not been updated yet. + + Before the first update, the ``AveragedModel`` only holds the copy of the initial weights made in ``setup()``. + Swapping it in for validation during a delayed-start warmup would discard the trained weights and evaluate the + untrained snapshot instead. See https://github.com/Lightning-AI/pytorch-lightning/issues/21724. + + """ + + class SwapProbeModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.layer = nn.Linear(32, 2) + self.initial_weight: Optional[Tensor] = None + self.validation_weight: Optional[Tensor] = None + + def on_train_start(self) -> None: + # Snapshot of the weights the average model copies in ``setup()``. + self.initial_weight = self.layer.weight.detach().clone() + + def on_validation_epoch_start(self) -> None: + self.validation_weight = self.layer.weight.detach().clone() + + def configure_optimizers(self): + # A large learning rate guarantees the weights move noticeably away from their initial values. + return torch.optim.SGD(self.parameters(), lr=1.0) + + model = SwapProbeModel() + dataset = RandomDataset(32, 32) + # The update threshold is never reached during this run, so the average model is never updated. + callback = EMAWeightAveraging(update_every_n_steps=1, update_starting_at_step=1000) + trainer = Trainer( + accelerator="cpu", + devices=1, + logger=False, + callbacks=callback, + max_epochs=1, + num_sanity_val_steps=0, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + limit_train_batches=4, + limit_val_batches=1, + deterministic=True, + default_root_dir=tmp_path, + ) + dataloader = DataLoader(dataset, batch_size=4, shuffle=False) + trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=dataloader) + + assert callback._average_model is not None + assert callback._average_model.n_averaged == 0 + assert model.initial_weight is not None + assert model.validation_weight is not None + # Validation must not have swapped in the un-updated (frozen) average model. + assert not torch.allclose(model.validation_weight, model.initial_weight) + + def test_ema_weight_averaging_should_update(tmp_path): """Test the should_update logic of EMAWeightAveraging.""" # Test with step-based updates