Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))

- Fixed `WeightAveraging` swapping in the not-yet-updated average model during validation before its first update, which caused validation during a delayed-start warmup to evaluate the initial (untrained) weights instead of the trained ones ([#21724](https://github.com/Lightning-AI/pytorch-lightning/issues/21724))

---

## [2.6.4] - 2026-05-20
Expand Down
6 changes: 4 additions & 2 deletions src/lightning/pytorch/callbacks/weight_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

"""
if self._average_model is not None:
# Only swap in the averaged weights once the average model has actually been updated. Until then it only holds
# the copy of the initial weights made in setup(), so validating with it would discard the trained weights.
if self._average_model is not None and self._average_model.n_averaged > 0:
self._swap_models(pl_module)

@override
Expand All @@ -237,7 +239,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin
pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

"""
if self._average_model is not None:
if self._average_model is not None and self._average_model.n_averaged > 0:
self._swap_models(pl_module)

@override
Expand Down
74 changes: 62 additions & 12 deletions tests/tests_pytorch/callbacks/test_weight_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ def __init__(self, **kwargs: Any) -> None:
super().__init__(avg_fn=get_swa_avg_fn(), **kwargs)
self.swap_calls = 0
self.copy_calls = 0
# Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0.
self.first_epoch: Optional[int] = None

def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
return epoch_idx in (3, 5, 7)
Expand All @@ -148,14 +146,6 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
assert self.swap_calls == 0
assert self.copy_calls == 0

def on_train_epoch_start(self, trainer: Trainer, *args: Any) -> None:
super().on_train_epoch_start(trainer, *args)
# Since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will not update the
# model and will just call the epoch-level hooks. For that reason, we check that we are not restarting before
# choosing the first epoch.
if self.first_epoch is None and not trainer.fit_loop.restarting:
self.first_epoch = trainer.current_epoch

def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None:
super().on_train_epoch_end(trainer, *args)
if trainer.current_epoch < 3:
Expand All @@ -166,13 +156,16 @@ def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None:
assert self._average_model.n_averaged == 2
else:
assert self._average_model.n_averaged == 3
assert self.swap_calls == (trainer.current_epoch + 1 - self.first_epoch) * 2
# The model is only swapped during validation once the average model has been updated. The first update happens
# at the end of epoch 3, so the validation epochs that swap are those from epoch 4 onwards (two swaps each).
assert self.swap_calls == max(0, trainer.current_epoch - 3) * 2
assert self.copy_calls == 0

def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_train_end(trainer, pl_module)
assert self._average_model.n_averaged == 3
assert self.swap_calls == (trainer.max_epochs - self.first_epoch) * 2
# Validation epochs 4 to ``max_epochs - 1`` each swap twice (see ``on_train_epoch_end``).
assert self.swap_calls == max(0, trainer.max_epochs - 1 - 3) * 2
assert self.copy_calls == 1


Expand Down Expand Up @@ -388,6 +381,63 @@ def test_ema_weight_averaging_starting_epoch(tmp_path):
assert callback._average_model is not None


def test_weight_averaging_no_swap_before_first_update(tmp_path):
    """Validation has to see the trained weights for as long as the average model has never been updated.

    The callback below only starts updating at step 1000, which this four-batch, one-epoch run never reaches, so the
    ``AveragedModel`` still holds nothing but the copy of the initial weights made in ``setup()``. If that copy were
    swapped in for validation, the untrained snapshot would be evaluated and the trained weights ignored. See
    https://github.com/Lightning-AI/pytorch-lightning/issues/21724.

    """

    class SwapProbeModel(BoringModel):
        """``BoringModel`` variant that records its layer weight at train start and at validation epoch start."""

        def __init__(self) -> None:
            super().__init__()
            self.layer = nn.Linear(32, 2)
            self.initial_weight: Optional[Tensor] = None
            self.validation_weight: Optional[Tensor] = None

        def _current_weight(self) -> Tensor:
            # Detached copy of the layer weight as it is right now.
            return self.layer.weight.detach().clone()

        def on_train_start(self) -> None:
            # These are the same values the average model copied in ``setup()``.
            self.initial_weight = self._current_weight()

        def on_validation_epoch_start(self) -> None:
            self.validation_weight = self._current_weight()

        def configure_optimizers(self):
            # With a learning rate this large, training moves the weights noticeably away from their initial values.
            return torch.optim.SGD(self.parameters(), lr=1.0)

    probe = SwapProbeModel()
    random_data = RandomDataset(32, 32)
    # Updates would only begin at step 1000, far beyond this run, so the average model is never updated.
    averaging = EMAWeightAveraging(update_every_n_steps=1, update_starting_at_step=1000)
    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="cpu",
        devices=1,
        max_epochs=1,
        limit_train_batches=4,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        deterministic=True,
        callbacks=averaging,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    loader = DataLoader(random_data, batch_size=4, shuffle=False)
    trainer.fit(probe, train_dataloaders=loader, val_dataloaders=loader)

    average_model = averaging._average_model
    assert average_model is not None
    assert average_model.n_averaged == 0

    starting_weight = probe.initial_weight
    trained_weight = probe.validation_weight
    assert starting_weight is not None
    assert trained_weight is not None
    # Had the un-updated (frozen) average model been swapped in, validation would have seen the initial weights.
    assert not torch.allclose(trained_weight, starting_weight)


def test_ema_weight_averaging_should_update(tmp_path):
"""Test the should_update logic of EMAWeightAveraging."""
# Test with step-based updates
Expand Down