Skip to content

Commit cbf2a77

Browse files
authored
Add EMA callback (#74)
* 👽 Vendor the pytorch WeightAveraging callback, which will become available in pytorch-lightning 2.5.4
* 🚨 Fix linting issues in the WeightAveraging callback
* ✨ Add an EMAWeightAveragingCallback that uses the upstream WeightAveraging callback but has a simpler configuration interface
* 🔧 Add the EMAWeightAveragingCallback to our set of default callbacks
* ✨ Make decay rate configurable
1 parent 9e9a79b commit cbf2a77

5 files changed

Lines changed: 469 additions & 1 deletion

File tree

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
from .ema_weight_averaging_callback import EMAWeightAveragingCallback
12
from .metric_summary_callback import MetricSummaryCallback
23
from .plotting_callback import PlottingCallback
34
from .unconditional_checkpoint import UnconditionalCheckpoint
45

5-
__all__ = ["MetricSummaryCallback", "PlottingCallback", "UnconditionalCheckpoint"]
6+
__all__ = [
7+
"EMAWeightAveragingCallback",
8+
"MetricSummaryCallback",
9+
"PlottingCallback",
10+
"UnconditionalCheckpoint",
11+
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from torch.optim.swa_utils import get_ema_multi_avg_fn
2+
3+
from .weight_averaging import WeightAveraging
4+
5+
6+
class EMAWeightAveragingCallback(WeightAveraging):
7+
"""A callback that updates an averaged model for Exponential Moving Average (EMA) after each training step."""
8+
9+
def __init__(
10+
self,
11+
*,
12+
decay_rate: float,
13+
every_n_epochs: int | None = None,
14+
every_n_steps: int | None = None,
15+
) -> None:
16+
"""Summarise metrics during evaluation.
17+
18+
Args:
19+
decay_rate: Parameter update decay rate.
20+
every_n_epochs: How many epochs to wait before updating.
21+
every_n_steps: How many steps to wait before updating.
22+
23+
"""
24+
super().__init__(
25+
multi_avg_fn=get_ema_multi_avg_fn(decay_rate), use_buffers=True
26+
)
27+
self.every_n_epochs = every_n_epochs
28+
self.every_n_steps = every_n_steps
29+
30+
def should_update(
31+
self, step_idx: int | None = None, epoch_idx: int | None = None
32+
) -> bool:
33+
"""Update if we are at the requested number of steps or epochs."""
34+
if self.every_n_epochs and epoch_idx:
35+
return epoch_idx % self.every_n_epochs == 0
36+
37+
if self.every_n_steps and step_idx:
38+
return step_idx % self.every_n_steps == 0
39+
40+
return False

0 commit comments

Comments
 (0)