
Commit 0010492

Author: Seppo Enarvi
Weight averaging callback
* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
1 parent a944e77 commit 0010492
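
A minimal usage sketch (not part of the diff below): the EMA averaging function is assumed to come from torch.optim.swa_utils.get_ema_avg_fn, the 10-step update interval is an arbitrary illustration, and Lightning's demo BoringModel only stands in for a real LightningModule.

import torch
from torch.optim.swa_utils import get_ema_avg_fn

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel


def update_every_10_steps(step_number: int) -> bool:
    # Update the averaged model after every 10th optimizer step.
    return step_number % 10 == 0


# EMA with decay 0.999; leave avg_fn as None for an equally weighted (SWA-style) average.
callback = WeightAveraging(
    device=torch.device("cpu"),
    avg_fn=get_ema_avg_fn(decay=0.999),
    update_on_step=update_every_10_steps,
)
trainer = Trainer(max_epochs=10, callbacks=[callback])
trainer.fit(BoringModel())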

File tree

5 files changed (+581 lines, -0 lines)


docs/source-pytorch/api_references.rst (+1)

@@ -48,6 +48,7 @@ callbacks
     ThroughputMonitor
     Timer
     TQDMProgressBar
+    WeightAveraging

 cli
 -----

docs/source-pytorch/extensions/callbacks.rst (+1)

@@ -83,6 +83,7 @@ Lightning has a few built-in callbacks.
     StochasticWeightAveraging
     Timer
     TQDMProgressBar
+    WeightAveraging

 ----------

src/lightning/pytorch/callbacks/__init__.py (+2)

@@ -32,6 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
+from lightning.pytorch.callbacks.weight_averaging import WeightAveraging

 __all__ = [
     "BackboneFinetuning",
@@ -58,4 +59,5 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "WeightAveraging",
 ]
src/lightning/pytorch/callbacks/weight_averaging.py (new file, +288)

@@ -0,0 +1,288 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^
"""

import itertools
from copy import deepcopy
from typing import Any, Callable

import torch
from torch import Tensor
from torch.optim.swa_utils import AveragedModel

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT


def _return_true(x: int) -> bool:
    return True


def _return_false(x: int) -> bool:
    return False


class WeightAveraging(Callback):
    r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
    (EMA) after each training step.

    The user should provide either ``update_on_step`` or ``update_on_epoch``, a function that determines when the
    average model should be updated. If neither function is provided, the average model will be updated after every
    optimizer step.

    During validation and after training finishes, the current model parameters will be replaced with the averaged
    values.

    Args:
        device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None``, the device
            will be inferred from the original model.
        avg_fn: The averaging function used to update the parameters. The function must take in an
            :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged.
            If ``None``, an equally weighted average will be used.
        update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the
            average model should be updated.
        update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average
            model should be updated.

    """

    def __init__(
        self,
        device: torch.device | int | None = torch.device("cpu"),
        avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
        update_on_step: Callable[[int], bool] | None = None,
        update_on_epoch: Callable[[int], bool] | None = None,
    ):
        self._device = device
        self._avg_fn = avg_fn

        if (update_on_step is None) and (update_on_epoch is None):
            self._update_on_step: Callable[[int], bool] = _return_true
            self._update_on_epoch: Callable[[int], bool] = _return_false
        else:
            self._update_on_step = _return_false if update_on_step is None else update_on_step
            self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch

        self._average_model: AveragedModel | None = None

        # Number of optimizer steps taken when the average model was last updated. Initializing this with zero
        # ensures that the average model will be first updated after the first optimizer step, which takes place
        # after N batches when using accumulate_grad_batches=N.
        self._latest_update_step = 0
        # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
        # negative value means that if update_on_epoch(0) returns True, the first update is after the first epoch.
        self._latest_update_epoch = -1

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune begins.

        Creates an :class:`AveragedModel` when fit begins.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            stage: The :class:`~lightning.pytorch.trainer.trainer.Trainer` state.

        """
        if stage == "fit":
            device = self._device or pl_module.device
            self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when a training batch ends.

        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            outputs: Outputs from the training batch.
            batch: The training batch.
            batch_idx: Index of the training batch.

        """
        if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_step = trainer.global_step

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a training epoch ends.

        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_epoch = trainer.current_epoch

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when training ends.

        Transfers parameters from the :class:`AveragedModel` to the current model.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        assert self._average_model is not None
        self._copy_average_to_current(pl_module)

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a validation epoch begins.

        Transfers parameter values from the :class:`AveragedModel` to the current model.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        if self._average_model is not None:
            rank_zero_info("Loading the average model parameters for validation.")
            self._swap_models(pl_module)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a validation epoch ends.

        Recovers the current model parameters from the :class:`AveragedModel`.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        if self._average_model is not None:
            rank_zero_info("Recovering the current model parameters after validation.")
            self._swap_models(pl_module)

    def state_dict(self) -> dict[str, Any]:
        """Called when saving a checkpoint.

        Creates a ``state_dict`` of the callback state.

        Returns:
            A dictionary containing the callback state.

        """
        return {"latest_update_step": self._latest_update_step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint.

        Reloads the callback state given a ``state_dict``.

        Args:
            state_dict: A dictionary containing the callback state.

        """
        self._latest_update_step = state_dict["latest_update_step"]

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        r"""Called when saving a checkpoint.

        Moves the current model state to the key ``current_model_state``, and places the average model state in
        ``state_dict`` instead. Any other state variables of the ``AveragedModel`` will be saved in
        ``averaging_state``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            checkpoint: The checkpoint dictionary that will be saved.

        """
        if self._average_model is None:
            raise Exception("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do.")

        rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
        average_model_state = self._average_model.state_dict()
        checkpoint["current_model_state"] = checkpoint["state_dict"]
        checkpoint["state_dict"] = {
            name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
        }
        checkpoint["averaging_state"] = {
            name: value for name, value in average_model_state.items() if not name.startswith("module.")
        }

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        r"""Called when loading a model checkpoint.

        Loads the current model and the :class:`AveragedModel` parameters from the checkpoint.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            checkpoint: The full checkpoint dictionary that got loaded by the Trainer.

        """
        if self._average_model is None:
            raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.")

        if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
            rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
            average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
            average_model_state |= checkpoint["averaging_state"]
            self._average_model.load_state_dict(average_model_state)
            checkpoint["state_dict"] = checkpoint["current_model_state"]
        else:
            rank_zero_warn(
                "The checkpoint was not created with WeightAveraging. Both the current and the average model will be "
                "initialized with state_dict."
            )
            self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False)

    def _swap_models(self, pl_module: "pl.LightningModule") -> None:
        """Swaps the parameter values of the current model and the :class:`AveragedModel`.

        Args:
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        assert self._average_model is not None
        average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            tmp = average_param.data.clone()
            average_param.data.copy_(current_param.data)
            current_param.data.copy_(tmp)

    def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
        """Copies the parameter values from the :class:`AveragedModel` to the current model.

        Args:
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        assert self._average_model is not None
        average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            current_param.data.copy_(average_param.data)
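
A complementary sketch (again not part of the commit): update_on_epoch can implement an SWA-style schedule where averaging only begins late in training. The epoch threshold below is an arbitrary illustration.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging


def update_from_epoch_75(epoch_number: int) -> bool:
    # The callback passes the zero-based epoch number; start averaging from the 76th epoch onwards.
    return epoch_number >= 75


# When only update_on_epoch is given, per-step updates are disabled and the averaged model
# is refreshed once at the end of each selected epoch.
callback = WeightAveraging(update_on_epoch=update_from_epoch_75)
trainer = Trainer(max_epochs=100, callbacks=[callback])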

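For completeness, a sketch of what on_save_checkpoint implies for downstream use. The checkpoint path is a placeholder, and loading a full Lightning checkpoint with torch.load may require weights_only=False on recent PyTorch versions.

import torch

# Path is hypothetical; weights_only=False because a Lightning checkpoint stores more than tensors.
checkpoint = torch.load("example.ckpt", map_location="cpu", weights_only=False)

# The averaged parameters were written to the regular "state_dict" key, so the usual loading
# paths (e.g. LightningModule.load_from_checkpoint) pick up the averaged weights automatically.
averaged_state = checkpoint["state_dict"]

# The most recent non-averaged weights and the AveragedModel bookkeeping (such as n_averaged)
# are preserved under separate keys written by on_save_checkpoint.
raw_state = checkpoint["current_model_state"]
averaging_state = checkpoint["averaging_state"]
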
0 commit comments
