Skip to content

Commit 6548c2d

Browse files
dtronmans and kozlov721 authored
Fix: resume_training: True and EMACallback active (#372)
Co-authored-by: Martin Kozlovsky <martin.kozlovsky@luxonis.com>
1 parent 9d4e264 commit 6548c2d

7 files changed

Lines changed: 179 additions & 30 deletions

File tree

luxonis_train/callbacks/ema.py

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import math
2+
from collections.abc import Mapping
23
from copy import deepcopy
34
from typing import Any
45

56
import lightning.pytorch as pl
67
import torch
78
from lightning.pytorch.utilities.types import STEP_OUTPUT
9+
from loguru import logger
810
from torch import nn
911

12+
from luxonis_train.utils.checkpoint import filter_checkpoint_state_dict
13+
1014

1115
class ModelEma(nn.Module):
1216
"""Model Exponential Moving Average.
@@ -65,13 +69,13 @@ def update(self, model: pl.LightningModule) -> None:
6569
else:
6670
decay = self.decay
6771

72+
model_state_dict = model.state_dict()
6873
ema_lerp_values = []
6974
model_lerp_values = []
70-
for ema_v, model_v in zip(
71-
self.state_dict_ema.values(),
72-
model.state_dict().values(),
73-
strict=True,
74-
):
75+
for key, ema_v in self.state_dict_ema.items():
76+
model_v = model_state_dict.get(key)
77+
if model_v is None:
78+
continue
7579
if ema_v.is_floating_point():
7680
ema_lerp_values.append(ema_v)
7781
model_lerp_values.append(model_v)
@@ -115,8 +119,13 @@ def __init__(
115119

116120
self._ema = None
117121
self.loaded_ema_state_dict = None
122+
self.loaded_ema_updates = None
118123
self.collected_state_dict = None
119124

125+
@staticmethod
def _format_key_list(keys: set[str]) -> str:
    """Render a set of state-dict keys as one string for log messages.

    @type keys: set[str]
    @param keys: Keys to list.
    @rtype: str
    @return: The keys in sorted order joined by C{", "}, or the
        placeholder C{"<none>"} when the set is empty.
    """
    if not keys:
        return "<none>"
    ordered = sorted(keys)
    return ", ".join(ordered)
128+
120129
@property
121130
def ema(self) -> ModelEma:
122131
if self._ema is None:
@@ -144,12 +153,54 @@ def on_fit_start(
144153
target_device = next(
145154
iter(self._ema.state_dict_ema.values())
146155
).device
147-
self.loaded_ema_state_dict = {
148-
k: v.to(target_device)
149-
for k, v in self.loaded_ema_state_dict.items()
156+
current_state_dict = self._ema.state_dict_ema
157+
comparable_current_state_dict = filter_checkpoint_state_dict(
158+
current_state_dict
159+
)
160+
comparable_loaded_state_dict = filter_checkpoint_state_dict(
161+
self.loaded_ema_state_dict
162+
)
163+
current_keys = set(comparable_current_state_dict)
164+
loaded_keys = set(comparable_loaded_state_dict)
165+
missing_in_checkpoint = current_keys - loaded_keys
166+
extra_in_checkpoint = loaded_keys - current_keys
167+
incompatible_shapes = {
168+
key
169+
for key in current_keys & loaded_keys
170+
if comparable_current_state_dict[key].shape
171+
!= comparable_loaded_state_dict[key].shape
150172
}
151-
self._ema.state_dict_ema = self.loaded_ema_state_dict
173+
174+
if missing_in_checkpoint:
175+
logger.warning(
176+
"EMA checkpoint is missing keys present in the current model. "
177+
"Keeping freshly initialized EMA values for: "
178+
f"{self._format_key_list(missing_in_checkpoint)}"
179+
)
180+
if extra_in_checkpoint:
181+
logger.warning(
182+
"EMA checkpoint contains keys not present in the current model. "
183+
"Ignoring: "
184+
f"{self._format_key_list(extra_in_checkpoint)}"
185+
)
186+
if incompatible_shapes:
187+
logger.warning(
188+
"EMA checkpoint contains keys with incompatible shapes. "
189+
"Ignoring: "
190+
f"{self._format_key_list(incompatible_shapes)}"
191+
)
192+
193+
for key, value in comparable_loaded_state_dict.items():
194+
if (
195+
key in current_state_dict
196+
and key not in incompatible_shapes
197+
):
198+
current_state_dict[key] = value.to(target_device)
199+
self._ema.state_dict_ema = current_state_dict
200+
if self.loaded_ema_updates is not None:
201+
self._ema.updates = self.loaded_ema_updates
152202
self.loaded_ema_state_dict = None
203+
self.loaded_ema_updates = None
153204

154205
def on_train_batch_end(
155206
self,
@@ -248,7 +299,7 @@ def on_save_checkpoint(
248299
trainer: pl.Trainer,
249300
pl_module: pl.LightningModule,
250301
checkpoint: dict,
251-
) -> None: # or dict?
302+
) -> None:
252303
"""Save the EMA state dictionary into the checkpoint.
253304
254305
@type trainer: L{pl.Trainer}
@@ -261,6 +312,19 @@ def on_save_checkpoint(
261312
if self._ema is not None:
262313
checkpoint["state_dict"] = self._ema.state_dict_ema
263314

315+
def state_dict(self) -> dict[str, Any]:
    """Build the state dictionary for this callback.

    @rtype: dict[str, Any]
    @return: An empty dictionary while no EMA object exists yet.
        Otherwise a dictionary with the EMA weights (passed through
        C{filter_checkpoint_state_dict}) under C{"ema_state_dict"}
        and the EMA update counter under C{"updates"}.
    """
    ema = self._ema
    if ema is None:
        return {}
    filtered_weights = filter_checkpoint_state_dict(ema.state_dict_ema)
    return {"ema_state_dict": filtered_weights, "updates": ema.updates}
324+
325+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore callback state by delegating to C{_load_ema_state}.

    @type state_dict: dict[str, Any]
    @param state_dict: Previously saved callback state.
    """
    self._load_ema_state(state_dict)
327+
264328
def on_load_checkpoint(
265329
self,
266330
trainer: pl.Trainer,
@@ -272,8 +336,18 @@ def on_load_checkpoint(
272336
@type callback_state: dict
273337
@param callback_state: Pytorch Lightning callback state.
274338
"""
275-
if callback_state and "state_dict" in callback_state:
276-
self.loaded_ema_state_dict = callback_state["state_dict"]
339+
self._load_ema_state(callback_state)
340+
341+
def _load_ema_state(self, state_dict: dict[str, Any]) -> None:
    """Stash EMA weights and the update counter found in a saved state.

    The values are only stored on the callback
    (C{loaded_ema_state_dict} / C{loaded_ema_updates}); nothing is
    applied to an EMA object here.

    @type state_dict: dict[str, Any]
    @param state_dict: Saved callback state. The weights are read
        from C{"ema_state_dict"}, falling back to the older
        C{"state_dict"} key when the former is absent. An empty or
        missing state is ignored.
    """
    if not state_dict:
        return

    # Prefer the new key; the old one is consulted only as a default.
    legacy_weights = state_dict.get("state_dict")
    weights = state_dict.get("ema_state_dict", legacy_weights)
    if isinstance(weights, Mapping):
        self.loaded_ema_state_dict = weights

    # Anything that is not an int is silently skipped.
    update_count = state_dict.get("updates")
    if isinstance(update_count, int):
        self.loaded_ema_updates = update_count
277351

278352
def _swap_to_ema_weights(self, pl_module: pl.LightningModule) -> None:
279353
"""Swap the current model weights with the EMA weights.

luxonis_train/lightning/luxonis_lightning.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
from collections import defaultdict
32
from collections.abc import Callable, Mapping
43
from pathlib import Path
@@ -26,6 +25,7 @@
2625
from luxonis_train.nodes import BaseNode
2726
from luxonis_train.typing import Labels, Packet
2827
from luxonis_train.utils import DatasetMetadata, LuxonisTrackerPL
28+
from luxonis_train.utils.checkpoint import filter_checkpoint_state_dict
2929

3030
from .luxonis_output import LuxonisOutput
3131
from .utils import (
@@ -1027,14 +1027,9 @@ def _strip_state_prefix(key: str) -> str:
10271027
def _add_custom_data_to_checkpoint(
10281028
self, checkpoint: dict[str, Any]
10291029
) -> None:
1030-
pattern = re.compile(
1031-
r"^nodes\.[^.]+\.(metrics|visualizers|losses)\..*_node\..*"
1030+
checkpoint["state_dict"] = filter_checkpoint_state_dict(
1031+
checkpoint["state_dict"]
10321032
)
1033-
checkpoint["state_dict"] = {
1034-
k: v
1035-
for k, v in checkpoint["state_dict"].items()
1036-
if not pattern.match(k)
1037-
}
10381033
checkpoint |= {
10391034
"version": luxonis_train.__version__,
10401035
"execution_order": get_model_execution_order(self),

luxonis_train/lightning/utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -493,16 +493,6 @@ def build_callbacks(
493493
"in the callbacks list. The `accumulate_grad_batches` "
494494
"parameter in the config will be ignored."
495495
)
496-
callbacks.append(
497-
ModelCheckpoint(
498-
dirpath=save_dir / "min_val_loss",
499-
filename=f"{model_name}_loss={{val/loss:.4f}}_{{epoch:02d}}",
500-
monitor="val/loss",
501-
auto_insert_metric_name=False,
502-
save_top_k=cfg.trainer.save_top_k,
503-
mode="min",
504-
),
505-
)
506496
if main_metric is not None:
507497
node_name, metric_name = main_metric
508498
formatted_node = nodes.formatted_name(node_name)
@@ -521,6 +511,16 @@ def build_callbacks(
521511
)
522512
)
523513

514+
callbacks.append(
515+
ModelCheckpoint(
516+
dirpath=save_dir / "min_val_loss",
517+
filename=f"{model_name}_loss={{val/loss:.4f}}_{{epoch:02d}}",
518+
monitor="val/loss",
519+
auto_insert_metric_name=False,
520+
save_top_k=cfg.trainer.save_top_k,
521+
mode="min",
522+
)
523+
)
524524
return callbacks
525525

526526

luxonis_train/utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
keypoints_to_bboxes,
1010
non_max_suppression,
1111
)
12+
from .checkpoint import (
13+
CHECKPOINT_FILTERED_STATE_DICT_PATTERN,
14+
filter_checkpoint_state_dict,
15+
)
1216
from .dataset_metadata import DatasetMetadata
1317
from .exceptions import IncompatibleError
1418
from .general import (
@@ -41,6 +45,7 @@
4145
from .tracker import LuxonisTrackerPL
4246

4347
__all__ = [
48+
"CHECKPOINT_FILTERED_STATE_DICT_PATTERN",
4449
"Counter",
4550
"DatasetMetadata",
4651
"IncompatibleError",
@@ -55,6 +60,7 @@
5560
"compute_pose_oks",
5661
"default_annotate",
5762
"dist2bbox",
63+
"filter_checkpoint_state_dict",
5864
"get_attribute_check_none",
5965
"get_batch_instances",
6066
"get_batch_instances",

luxonis_train/utils/checkpoint.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import re
2+
from collections.abc import Mapping
3+
4+
from torch import Tensor
5+
6+
# Keys of the form ``nodes.<name>.(metrics|visualizers|losses).…_node.…``
# are the ones removed by ``filter_checkpoint_state_dict``.
CHECKPOINT_FILTERED_STATE_DICT_PATTERN = re.compile(
    r"^nodes\.[^.]+\.(metrics|visualizers|losses)\..*_node\..*"
)


def filter_checkpoint_state_dict(
    state_dict: Mapping[str, Tensor],
) -> dict[str, Tensor]:
    """Return a copy of a state dict without the filtered entries.

    @type state_dict: Mapping[str, Tensor]
    @param state_dict: State dictionary to filter. It is not modified.
    @rtype: dict[str, Tensor]
    @return: A new dictionary holding, in the original order, every
        entry whose key does not match
        C{CHECKPOINT_FILTERED_STATE_DICT_PATTERN}.
    """
    is_filtered = CHECKPOINT_FILTERED_STATE_DICT_PATTERN.match
    kept: dict[str, Tensor] = {}
    for name, tensor in state_dict.items():
        # ``match`` returns None when the key is not one to drop.
        if is_filtered(name) is None:
            kept[name] = tensor
    return kept
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from pathlib import Path
2+
3+
from luxonis_ml.data import LuxonisDataset
4+
from luxonis_ml.typing import Params
5+
6+
from luxonis_train.core import LuxonisModel
7+
8+
9+
def test_resume_training_with_ema_does_not_crash(
    parking_lot_dataset: LuxonisDataset, opts: Params, tmp_path: Path
):
    """Train once with EMA enabled, then resume from the saved checkpoint.

    Nothing is asserted after the resumed run: the test passes as long
    as resuming with C{trainer.resume_training} set and an active
    C{EMACallback} does not raise.
    """
    cfg_path = "configs/detection_light_model.yaml"
    output_dir = tmp_path / "save-directory"

    # Only the EMA callback is active; the end-of-training callbacks
    # are all switched off.
    callbacks = [
        {
            "name": "EMACallback",
            "active": True,
            "params": {"decay": 0.9999},
        },
        {"name": "TestOnTrainEnd", "active": False},
        {"name": "ExportOnTrainEnd", "active": False},
        {"name": "ArchiveOnTrainEnd", "active": False},
        {"name": "ConvertOnTrainEnd", "active": False},
        {"name": "UploadCheckpoint", "active": False},
    ]
    overrides = {
        "loader.params.dataset_name": parking_lot_dataset.identifier,
        "loader.train_view": "train",
        "loader.val_view": "train",
        "loader.test_view": "train",
        "model.predefined_model.params.task_name": "vehicles",
        "trainer.overfit_batches": 1,
        "trainer.seed": 42,
        "trainer.deterministic": "warn",
        "trainer.epochs": 1,
        "trainer.validation_interval": 1,
        "tracker.save_directory": str(output_dir),
        "trainer.callbacks": callbacks,
    }
    train_opts = opts | overrides

    # First run: a single epoch that must leave a checkpoint behind.
    first_run = LuxonisModel(cfg_path, train_opts)
    first_run.train()

    ckpt_path = first_run.get_best_metric_checkpoint_path()
    assert ckpt_path, "No checkpoint found after initial training"

    # Second run: same options, resuming from that checkpoint with a
    # larger epoch budget.
    resume_overrides = {
        "trainer.resume_training": True,
        "trainer.epochs": 2,
    }
    resume_opts = train_opts | resume_overrides
    second_run = LuxonisModel(cfg_path, resume_opts)
    second_run.train(weights=ckpt_path)

tests/unittests/test_callbacks/test_ema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def test_ema_state_saved_to_checkpoint(
9797
ema_callback.on_save_checkpoint(trainer, model, checkpoint)
9898

9999
assert "state_dict" in checkpoint
100+
assert (
101+
checkpoint["state_dict"].keys()
102+
== ema_callback.ema.state_dict_ema.keys()
103+
)
100104

101105

102106
def test_load_from_checkpoint(

0 commit comments

Comments
 (0)