11import math
2+ from collections .abc import Mapping
23from copy import deepcopy
34from typing import Any
45
56import lightning .pytorch as pl
67import torch
78from lightning .pytorch .utilities .types import STEP_OUTPUT
9+ from loguru import logger
810from torch import nn
911
12+ from luxonis_train .utils .checkpoint import filter_checkpoint_state_dict
13+
1014
1115class ModelEma (nn .Module ):
1216 """Model Exponential Moving Average.
@@ -65,13 +69,13 @@ def update(self, model: pl.LightningModule) -> None:
6569 else :
6670 decay = self .decay
6771
72+ model_state_dict = model .state_dict ()
6873 ema_lerp_values = []
6974 model_lerp_values = []
70- for ema_v , model_v in zip (
71- self .state_dict_ema .values (),
72- model .state_dict ().values (),
73- strict = True ,
74- ):
75+ for key , ema_v in self .state_dict_ema .items ():
76+ model_v = model_state_dict .get (key )
77+ if model_v is None :
78+ continue
7579 if ema_v .is_floating_point ():
7680 ema_lerp_values .append (ema_v )
7781 model_lerp_values .append (model_v )
@@ -115,8 +119,13 @@ def __init__(
115119
116120 self ._ema = None
117121 self .loaded_ema_state_dict = None
122+ self .loaded_ema_updates = None
118123 self .collected_state_dict = None
119124
125+ @staticmethod
126+ def _format_key_list (keys : set [str ]) -> str :
127+ return ", " .join (sorted (keys )) if keys else "<none>"
128+
120129 @property
121130 def ema (self ) -> ModelEma :
122131 if self ._ema is None :
@@ -144,12 +153,54 @@ def on_fit_start(
144153 target_device = next (
145154 iter (self ._ema .state_dict_ema .values ())
146155 ).device
147- self .loaded_ema_state_dict = {
148- k : v .to (target_device )
149- for k , v in self .loaded_ema_state_dict .items ()
156+ current_state_dict = self ._ema .state_dict_ema
157+ comparable_current_state_dict = filter_checkpoint_state_dict (
158+ current_state_dict
159+ )
160+ comparable_loaded_state_dict = filter_checkpoint_state_dict (
161+ self .loaded_ema_state_dict
162+ )
163+ current_keys = set (comparable_current_state_dict )
164+ loaded_keys = set (comparable_loaded_state_dict )
165+ missing_in_checkpoint = current_keys - loaded_keys
166+ extra_in_checkpoint = loaded_keys - current_keys
167+ incompatible_shapes = {
168+ key
169+ for key in current_keys & loaded_keys
170+ if comparable_current_state_dict [key ].shape
171+ != comparable_loaded_state_dict [key ].shape
150172 }
151- self ._ema .state_dict_ema = self .loaded_ema_state_dict
173+
174+ if missing_in_checkpoint :
175+ logger .warning (
176+ "EMA checkpoint is missing keys present in the current model. "
177+ "Keeping freshly initialized EMA values for: "
178+ f"{ self ._format_key_list (missing_in_checkpoint )} "
179+ )
180+ if extra_in_checkpoint :
181+ logger .warning (
182+ "EMA checkpoint contains keys not present in the current model. "
183+ "Ignoring: "
184+ f"{ self ._format_key_list (extra_in_checkpoint )} "
185+ )
186+ if incompatible_shapes :
187+ logger .warning (
188+ "EMA checkpoint contains keys with incompatible shapes. "
189+ "Ignoring: "
190+ f"{ self ._format_key_list (incompatible_shapes )} "
191+ )
192+
193+ for key , value in comparable_loaded_state_dict .items ():
194+ if (
195+ key in current_state_dict
196+ and key not in incompatible_shapes
197+ ):
198+ current_state_dict [key ] = value .to (target_device )
199+ self ._ema .state_dict_ema = current_state_dict
200+ if self .loaded_ema_updates is not None :
201+ self ._ema .updates = self .loaded_ema_updates
152202 self .loaded_ema_state_dict = None
203+ self .loaded_ema_updates = None
153204
154205 def on_train_batch_end (
155206 self ,
@@ -248,7 +299,7 @@ def on_save_checkpoint(
248299 trainer : pl .Trainer ,
249300 pl_module : pl .LightningModule ,
250301 checkpoint : dict ,
251- ) -> None : # or dict?
302+ ) -> None :
252303 """Save the EMA state dictionary into the checkpoint.
253304
254305 @type trainer: L{pl.Trainer}
@@ -261,6 +312,19 @@ def on_save_checkpoint(
261312 if self ._ema is not None :
262313 checkpoint ["state_dict" ] = self ._ema .state_dict_ema
263314
315+ def state_dict (self ) -> dict [str , Any ]:
316+ if self ._ema is None :
317+ return {}
318+ return {
319+ "ema_state_dict" : filter_checkpoint_state_dict (
320+ self ._ema .state_dict_ema
321+ ),
322+ "updates" : self ._ema .updates ,
323+ }
324+
325+ def load_state_dict (self , state_dict : dict [str , Any ]) -> None :
326+ self ._load_ema_state (state_dict )
327+
264328 def on_load_checkpoint (
265329 self ,
266330 trainer : pl .Trainer ,
@@ -272,8 +336,18 @@ def on_load_checkpoint(
272336 @type callback_state: dict
273337 @param callback_state: Pytorch Lightning callback state.
274338 """
275- if callback_state and "state_dict" in callback_state :
276- self .loaded_ema_state_dict = callback_state ["state_dict" ]
339+ self ._load_ema_state (callback_state )
340+
341+ def _load_ema_state (self , state_dict : dict [str , Any ]) -> None :
342+ if state_dict :
343+ loaded_state_dict = state_dict .get (
344+ "ema_state_dict" , state_dict .get ("state_dict" )
345+ )
346+ if isinstance (loaded_state_dict , Mapping ):
347+ self .loaded_ema_state_dict = loaded_state_dict
348+ updates = state_dict .get ("updates" )
349+ if isinstance (updates , int ):
350+ self .loaded_ema_updates = updates
277351
278352 def _swap_to_ema_weights (self , pl_module : pl .LightningModule ) -> None :
279353 """Swap the current model weights with the EMA weights.
0 commit comments