
Commit efc77dc

Author: Seppo Enarvi (committed)

The user can specify after which steps or epochs the average model is updated

1 parent 11423fd commit efc77dc

File tree: 2 files changed, +53 -13 lines
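The diff below adds two optional scheduling hooks, update_on_step and update_on_epoch, to the WeightAveraging callback. As orientation before the diff, here is a minimal usage sketch. The callback, parameter names, and BoringModel come from the diff itself; the particular schedule functions and Trainer settings are illustrative assumptions, not part of the commit.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel

# Update the averaged model only on every 100th optimizer step...
ema_callback = WeightAveraging(update_on_step=lambda step: step % 100 == 0)

# ...or only from epoch 10 onwards (epoch numbers are zero-based).
swa_callback = WeightAveraging(update_on_epoch=lambda epoch: epoch >= 10)

# With neither argument given, the averaged model is updated after every
# optimizer step, matching the behavior before this commit.
default_callback = WeightAveraging()

trainer = Trainer(max_epochs=20, callbacks=[swa_callback])
trainer.fit(BoringModel())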

src/lightning/pytorch/callbacks/weight_averaging.py

+51 -3
@@ -30,10 +30,22 @@
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 
+def _return_true(x: int) -> bool:
+    return True
+
+
+def _return_false(x: int) -> bool:
+    return False
+
+
 class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.
 
+    The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average
+    model should be updated. If neither function is provided, the average model will be updated after every optimizer
+    step.
+
     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.
 
@@ -43,22 +55,39 @@ class WeightAveraging(Callback):
         avg_fn: The averaging function used to update the parameters. The function must take in an
             :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
             ``None``, an equally weighted average will be used.
+        update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
+            model should be updated.
+        update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
+            should be updated.
 
     """
 
     def __init__(
         self,
-        device: torch.device | str | None = torch.device("cpu"),
+        device: torch.device | int | None = torch.device("cpu"),
         avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
+        update_on_step: Callable[[int], bool] | None = None,
+        update_on_epoch: Callable[[int], bool] | None = None,
     ):
         self._device = device
         self._avg_fn = avg_fn
+
+        if (update_on_step is None) and (update_on_epoch is None):
+            self._update_on_step: Callable[[int], bool] = _return_true
+            self._update_on_epoch: Callable[[int], bool] = _return_false
+        else:
+            self._update_on_step = _return_false if update_on_step is None else update_on_step
+            self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
+
         self._average_model: AveragedModel | None = None
 
         # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
         # that the average model will be first updated after the first optimizer step, which takes place after N batches
         # when using accumulate_grad_batches=N.
         self._latest_update_step = 0
+        # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
+        # negative value means that if update_on_epoch(0) returns True, the first update is after the first epoch.
+        self._latest_update_epoch = -1
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins.
@@ -80,7 +109,7 @@ def on_train_batch_end(
     ) -> None:
         """Called when a training batch ends.
 
-        Updates the :class:`AveragedModel` parameters.
+        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.
 
         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -90,11 +119,26 @@ def on_train_batch_end(
             batch_idx: Index of the training batch.
 
         """
-        if trainer.global_step > self._latest_update_step:
+        if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_step = trainer.global_step
 
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when a training epoch ends.
+
+        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
+
+        Args:
+            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
+
+        """
+        if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
+            assert self._average_model is not None
+            self._average_model.update_parameters(pl_module)
+            self._latest_update_epoch = trainer.current_epoch
+
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when training ends.
 
@@ -173,6 +217,7 @@ def on_save_checkpoint(
             checkpoint: The checkpoint dictionary that will be saved.
 
         """
+        assert self._average_model is not None
         rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
         average_model_state = self._average_model.state_dict()
         checkpoint["current_model_state"] = checkpoint["state_dict"]
@@ -196,6 +241,7 @@ def on_load_checkpoint(
             checkpoint: The full checkpoint dictionary that got loaded by the Trainer.
 
         """
+        assert self._average_model is not None
         if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
             rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
@@ -216,6 +262,7 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
 
         """
+        assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
@@ -230,6 +277,7 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
 
         """
+        assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
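The avg_fn docstring above pins down the averaging contract: the function receives the averaged parameter, the current model parameter, and the number of models already averaged, and returns the new averaged value; passing None gives an equally weighted average. Here is a minimal sketch of an EMA-style avg_fn under that contract; the decay value and helper name are illustrative assumptions, not part of the commit.

from torch import Tensor

from lightning.pytorch.callbacks import WeightAveraging


def make_ema_avg_fn(decay: float = 0.999):
    # Builds a function matching the Callable[[Tensor, Tensor, Tensor | int], Tensor]
    # signature from the diff above. A fixed-decay EMA ignores the third
    # argument (the number of models already averaged).
    def ema_avg_fn(averaged: Tensor, current: Tensor, num_averaged: Tensor | int) -> Tensor:
        return decay * averaged + (1.0 - decay) * current

    return ema_avg_fn


# EMA weights combined with the new step-based update schedule.
callback = WeightAveraging(avg_fn=make_ema_avg_fn(), update_on_step=lambda step: step % 10 == 0)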

tests/tests_pytorch/callbacks/test_weight_averaging.py

+2 -10
@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from contextlib import AbstractContextManager
 from pathlib import Path
 from typing import Any, Optional
-from unittest import mock
 
 import pytest
 import torch
@@ -25,7 +23,6 @@
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import WeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
-from lightning.pytorch.strategies import Strategy
 from tests_pytorch.helpers.runif import RunIf
 
 
@@ -209,10 +206,9 @@ def _train(
     )
 
     if crash_on_epoch is None:
-        with _backward_patch(trainer):
-            trainer.fit(model, ckpt_path=checkpoint_path)
+        trainer.fit(model, ckpt_path=checkpoint_path)
     else:
-        with _backward_patch(trainer), pytest.raises(Exception, match="CRASH TEST"):
+        with pytest.raises(Exception, match="CRASH TEST"):
             trainer.fit(model, ckpt_path=checkpoint_path)
 
     assert trainer.lightning_module == model
@@ -230,7 +226,3 @@ def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False)
     checkpoint_path = str(checkpoint_dir / checkpoint_names[0])
 
     _train(tmp_path, strategy=strategy, devices=devices, checkpoint_path=checkpoint_path)
-
-
-def _backward_patch(trainer: Trainer) -> AbstractContextManager:
-    return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
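Back in weight_averaging.py, the on_save_checkpoint and on_load_checkpoint hunks imply a checkpoint layout in which state_dict carries the averaged weights while the raw weights are kept under current_model_state. A hedged inspection sketch follows; the checkpoint path is a placeholder, and the averaging_state key is inferred from the on_load_checkpoint condition rather than shown being written in these hunks.

import torch

ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")

# Per the hunks above, a checkpoint saved while WeightAveraging is active is
# expected to hold the averaged parameters in "state_dict", the non-averaged
# model under "current_model_state", and the callback's own "averaging_state".
print("current_model_state" in ckpt, "averaging_state" in ckpt)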
