@@ -30,10 +30,22 @@
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 
+def _return_true(x: int) -> bool:
+    return True
+
+
+def _return_false(x: int) -> bool:
+    return False
+
+
 class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.
 
+    The user should provide either ``update_on_step`` or ``update_on_epoch``, a function that determines when the average
+    model should be updated. If neither function is provided, the average model will be updated after every optimizer
+    step.
+
     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.
 
@@ -43,22 +55,39 @@ class WeightAveraging(Callback):
         avg_fn: The averaging function used to update the parameters. The function must take in an
             :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
             ``None``, an equally weighted average will be used.
+        update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
+            model should be updated.
+        update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
+            should be updated.
 
     """
 
     def __init__(
         self,
-        device: torch.device | str | None = torch.device("cpu"),
+        device: torch.device | int | None = torch.device("cpu"),
         avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
+        update_on_step: Callable[[int], bool] | None = None,
+        update_on_epoch: Callable[[int], bool] | None = None,
     ):
         self._device = device
         self._avg_fn = avg_fn
+
+        if (update_on_step is None) and (update_on_epoch is None):
+            self._update_on_step: Callable[[int], bool] = _return_true
+            self._update_on_epoch: Callable[[int], bool] = _return_false
+        else:
+            self._update_on_step = _return_false if update_on_step is None else update_on_step
+            self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
+
         self._average_model: AveragedModel | None = None
 
         # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
         # that the average model will be first updated after the first optimizer step, which takes place after N batches
         # when using accumulate_grad_batches=N.
         self._latest_update_step = 0
+        # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
+        # negative value means that if update_on_epoch(0) returns True, the first update is after the first epoch.
+        self._latest_update_epoch = -1
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins.
@@ -80,7 +109,7 @@ def on_train_batch_end(
     ) -> None:
         """Called when a training batch ends.
 
-        Updates the :class:`AveragedModel` parameters.
+        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.
 
         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -90,11 +119,26 @@ def on_train_batch_end(
             batch_idx: Index of the training batch.
 
         """
-        if trainer.global_step > self._latest_update_step:
+        if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_step = trainer.global_step
 
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when a training epoch ends.
+
+        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
+
+        Args:
+            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
+
+        """
+        if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
+            assert self._average_model is not None
+            self._average_model.update_parameters(pl_module)
+            self._latest_update_epoch = trainer.current_epoch
+
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when training ends.
 
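As a sketch of how the two hooks divide the work, here is a classic SWA-style schedule that averages once per epoch after a warm-up. The epoch threshold below is a hypothetical example, not from this diff.

```python
# Sketch: SWA-style averaging, updating once per epoch from epoch 10 onward
# (the threshold is a hypothetical example). Epoch numbers are zero-based.
from lightning.pytorch.callbacks import WeightAveraging  # assumed export path

def swa_schedule(epoch: int) -> bool:
    return epoch >= 10

swa = WeightAveraging(update_on_epoch=swa_schedule)
# Because update_on_step was left as None, __init__ wires _update_on_step to
# _return_false, so on_train_batch_end never updates the average here.
```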
@@ -173,6 +217,7 @@ def on_save_checkpoint(
             checkpoint: The checkpoint dictionary that will be saved.
 
         """
+        assert self._average_model is not None
         rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
         average_model_state = self._average_model.state_dict()
         checkpoint["current_model_state"] = checkpoint["state_dict"]
@@ -196,6 +241,7 @@ def on_load_checkpoint(
             checkpoint: The full checkpoint dictionary that got loaded by the Trainer.
 
         """
+        assert self._average_model is not None
         if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
             rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
@@ -216,6 +262,7 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
 
         """
+        assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
@@ -230,6 +277,7 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
 
         """
+        assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
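End to end, here is a hedged sketch of wiring the callback into a ``Trainer``. ``MyModel`` and ``MyDataModule`` are hypothetical stand-ins for a user's own module and data; the export path is assumed as above.

```python
# Sketch: attaching the callback to a Trainer. MyModel and MyDataModule are
# hypothetical placeholders for a user's LightningModule and datamodule.
import lightning.pytorch as pl
from lightning.pytorch.callbacks import WeightAveraging  # assumed export path

trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[WeightAveraging(update_on_epoch=lambda epoch: epoch >= 10)],
)
trainer.fit(MyModel(), MyDataModule())
# During validation and after training ends, the averaged parameters are
# swapped into the LightningModule, per the class docstring.
```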