Skip to content

Commit 317f6a6

Browse files
committed
Merge branch 'main' into fix/resume-training-ema
2 parents 959cf99 + 4c125c1 commit 317f6a6

19 files changed

Lines changed: 388 additions & 186 deletions

File tree

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ luxonis_train <command> --help
144144

145145
Specific usage examples can be found in the respective sections below.
146146

147+
> [!NOTE]
148+
> CLI commands `train`, `test`, and `tune` can be run with `--debug`
149+
> flag which allows the model to be used without a functional dataset.
150+
147151
<a name="configuration"></a>
148152

149153
## ⚙️ Configuration

luxonis_train/__main__.py

Lines changed: 68 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,25 @@
66
from functools import lru_cache
77
from importlib.metadata import version
88
from pathlib import Path
9-
from typing import TYPE_CHECKING, Annotated, Literal
9+
from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias
1010

1111
import yaml
1212
from cyclopts import App, Group, Parameter, validators
1313
from loguru import logger
1414
from luxonis_ml.typing import Params, PathType
1515

16-
from luxonis_train.config import Config
1716
from luxonis_train.upgrade import upgrade_config, upgrade_installation
1817

18+
OptsType: TypeAlias = Annotated[
19+
list[str] | None, Parameter(json_list=False, json_dict=False)
20+
]
21+
1922
if TYPE_CHECKING:
2023
import numpy as np
2124

2225
from luxonis_train import LuxonisModel
2326

27+
2428
app = App(
2529
help="Luxonis Train CLI",
2630
version=lambda: f"LuxonisTrain v{version('luxonis_train')}",
@@ -42,48 +46,23 @@ def create_model(
4246
config: PathType | Params | None,
4347
opts: list[str] | None = None,
4448
weights: PathType | None = None,
45-
debug_mode: bool = False,
46-
load_dataset_metadata: bool = True,
49+
allow_empty_dataset: bool = False,
4750
) -> "LuxonisModel":
4851
importlib.reload(sys.modules["luxonis_train"])
49-
import torch
5052

5153
from luxonis_train import LuxonisModel
52-
from luxonis_train.utils.dataset_metadata import DatasetMetadata
53-
54-
if weights is not None and config is None:
55-
ckpt = torch.load(weights, map_location="cpu") # nosemgre
56-
if "config" not in ckpt: # pragma: no cover
57-
raise ValueError(
58-
f"Checkpoint '{weights}' does not contain the 'config' key. "
59-
"Cannot restore `LuxonisModel` from checkpoint."
60-
)
61-
cfg = Config.get_config(upgrade_config(ckpt["config"]), opts)
62-
dataset_metadata = None
63-
if load_dataset_metadata:
64-
if "dataset_metadata" not in ckpt:
65-
logger.error("Checkpoint does not contain dataset metadata.")
66-
else:
67-
try:
68-
dataset_metadata = DatasetMetadata(
69-
**ckpt["dataset_metadata"]
70-
)
71-
except Exception as e: # pragma: no cover
72-
logger.error(
73-
"Failed to load dataset metadata from the checkpoint. "
74-
f"Error: {e}"
75-
)
76-
77-
return LuxonisModel(
78-
cfg, debug_mode=debug_mode, dataset_metadata=dataset_metadata
79-
)
8054

81-
return LuxonisModel(config, opts, debug_mode=debug_mode)
55+
return LuxonisModel(
56+
config,
57+
opts,
58+
weights=weights,
59+
allow_empty_dataset=allow_empty_dataset,
60+
)
8261

8362

8463
@app.command(group=training_group, sort_key=1)
8564
def train(
86-
opts: list[str] | None = None,
65+
opts: OptsType = None,
8766
/,
8867
*,
8968
config: str | None = None,
@@ -99,29 +78,44 @@ def train(
9978
@type opts: list[str]
10079
@param opts: A list of optional CLI overrides of the config file.
10180
@type debug: bool
102-
@param debug: If True, the training will run in debug mode which
103-
suppresses some exceptions to allow training without a fully
104-
defined model.
81+
@param debug: If true, allows the model to be constructed without
82+
a valid dataset by setting `allow_empty_dataset` to True. This can
83+
be useful for quick testing of the training loop.
10584
"""
106-
create_model(config, opts, weights, debug_mode=debug).train(
107-
weights=weights
108-
)
85+
create_model(
86+
config, opts, weights=weights, allow_empty_dataset=debug
87+
).train(weights=weights)
10988

11089

11190
@app.command(group=training_group, sort_key=2)
112-
def tune(opts: list[str] | None = None, /, *, config: str | None = None):
91+
def tune(
92+
opts: OptsType = None,
93+
/,
94+
*,
95+
config: str | None = None,
96+
weights: str | None = None,
97+
debug: bool = False,
98+
):
11399
"""Start hyperparameter tuning.
114100
115101
@type config: str
116102
@param config: Path to the configuration file.
117103
@type opts: list[str]
118104
@param opts: A list of optional CLI overrides of the config file.
105+
@type weights: str
106+
@param weights: Path to the model weights.
107+
@type debug: bool
108+
@param debug: If true, allows the model to be constructed without
109+
a valid dataset by setting `allow_empty_dataset` to True. This can
110+
be useful for quick testing of the tuning.
119111
"""
120-
create_model(config, opts).tune()
112+
create_model(
113+
config, opts, weights=weights, allow_empty_dataset=debug
114+
).tune()
121115

122116

123117
def _yield_visualizations(
124-
opts: list[str] | None = None,
118+
opts: OptsType = None,
125119
config: str | None = None,
126120
view: Literal["train", "val", "test"] = "train",
127121
size_multiplier: Annotated[
@@ -191,7 +185,7 @@ def get_visualization_item(
191185

192186
@app.command(group=training_group, sort_key=3)
193187
def inspect(
194-
opts: list[str] | None = None,
188+
opts: OptsType = None,
195189
/,
196190
*,
197191
config: str | None = None,
@@ -240,7 +234,7 @@ def get_window() -> str:
240234

241235
@app.command(group=evaluation_group, sort_key=1)
242236
def test(
243-
opts: list[str] | None = None,
237+
opts: OptsType = None,
244238
/,
245239
*,
246240
config: str | None = None,
@@ -261,18 +255,18 @@ def test(
261255
@type opts: list[str]
262256
@param opts: A list of optional CLI overrides of the config file.
263257
@type debug: bool
264-
@param debug: If True, the training will run in debug mode which
265-
suppresses some exceptions to allow training without a fully
266-
defined model.
258+
@param debug: If true, allows the model to be constructed without
259+
a valid dataset by setting `allow_empty_dataset` to True. This can
260+
be useful for quick testing of the evaluation loop.
267261
"""
268-
create_model(config, opts, weights, debug_mode=debug).test(
269-
view=view, weights=weights
270-
)
262+
create_model(
263+
config, opts, weights=weights, allow_empty_dataset=debug
264+
).test(view=view, weights=weights)
271265

272266

273267
@app.command(group=evaluation_group, sort_key=2)
274268
def infer(
275-
opts: list[str] | None = None,
269+
opts: OptsType = None,
276270
/,
277271
*,
278272
config: str | None = None,
@@ -302,7 +296,9 @@ def infer(
302296
@type opts: list[str]
303297
@param opts: A list of optional CLI overrides of the config file.
304298
"""
305-
create_model(config, opts, weights=weights, debug_mode=True).infer(
299+
create_model(
300+
config, opts, weights=weights, allow_empty_dataset=True
301+
).infer(
306302
view=view,
307303
save_dir=save_dir,
308304
source_path=source_path,
@@ -312,7 +308,7 @@ def infer(
312308

313309
@app.command(group=annotation_group, sort_key=0)
314310
def annotate(
315-
opts: list[str] | None = None,
311+
opts: OptsType = None,
316312
/,
317313
*,
318314
dir_path: Path,
@@ -323,7 +319,6 @@ def annotate(
323319
delete_local: bool = True,
324320
delete_remote: bool = True,
325321
team_id: str | None = None,
326-
debug: bool = False,
327322
):
328323
"""Run annotation on a custom directory of images.
329324
@@ -353,11 +348,7 @@ def annotate(
353348
@param opts: A list of optional CLI overrides of the config file.
354349
"""
355350
model = create_model(
356-
config,
357-
opts,
358-
weights=weights,
359-
load_dataset_metadata=True,
360-
debug_mode=debug,
351+
config, opts, weights=weights, allow_empty_dataset=True
361352
)
362353

363354
model.annotate(
@@ -373,7 +364,7 @@ def annotate(
373364

374365
@app.command(group=export_group, sort_key=1)
375366
def export(
376-
opts: list[str] | None = None,
367+
opts: OptsType = None,
377368
/,
378369
*,
379370
config: str | None = None,
@@ -400,14 +391,14 @@ def export(
400391
@type opts: list[str]
401392
@param opts: A list of optional CLI overrides of the
402393
"""
403-
create_model(config, opts, weights=weights, debug_mode=True).export(
404-
save_path=save_path, weights=weights, ckpt_only=ckpt_only
405-
)
394+
create_model(
395+
config, opts, weights=weights, allow_empty_dataset=True
396+
).export(save_path=save_path, weights=weights, ckpt_only=ckpt_only)
406397

407398

408399
@app.command(group=export_group, sort_key=2)
409400
def archive(
410-
opts: list[str] | None = None,
401+
opts: OptsType = None,
411402
/,
412403
*,
413404
config: str | None,
@@ -426,14 +417,14 @@ def archive(
426417
@type opts: list[str]
427418
@param opts: A list of optional CLI overrides of the config file.
428419
"""
429-
create_model(str(config), opts, weights=weights).archive(
430-
path=executable, weights=weights
431-
)
420+
create_model(
421+
config, opts, weights=weights, allow_empty_dataset=True
422+
).archive(path=executable, weights=weights)
432423

433424

434425
@app.command(group=export_group, sort_key=3)
435426
def convert(
436-
opts: list[str] | None = None,
427+
opts: OptsType = None,
437428
/,
438429
*,
439430
config: str | None = None,
@@ -456,9 +447,9 @@ def convert(
456447
@type opts: list[str]
457448
@param opts: A list of optional CLI overrides of the config file.
458449
"""
459-
create_model(config, opts, weights=weights).convert(
460-
weights=weights, save_dir=save_dir
461-
)
450+
create_model(
451+
config, opts, weights=weights, allow_empty_dataset=True
452+
).convert(save_dir=save_dir, weights=weights)
462453

463454

464455
@upgrade_app.command()
@@ -500,7 +491,7 @@ def config(
500491

501492
@upgrade_app.command(name=["checkpoint", "ckpt"])
502493
def checkpoint(
503-
opts: list[str] | None = None,
494+
opts: OptsType = None,
504495
/,
505496
*,
506497
path: Annotated[
@@ -518,11 +509,10 @@ def checkpoint(
518509
@param new: Where to save the upgraded checkpoint. If left empty,
519510
the old file will be overriden.
520511
"""
512+
from luxonis_train import LuxonisModel
513+
521514
logger.info("Performing a full checkpoint upgrade.")
522-
cfg = None
523-
if config is not None:
524-
cfg = upgrade_config(config)
525-
model = create_model(config=cfg, weights=path, opts=opts, debug_mode=True)
515+
model = LuxonisModel(config, opts, weights=path, allow_empty_dataset=True)
526516
model.lightning_module.load_checkpoint(path)
527517

528518
# Needs to be called in order to attach the model to the trainer

luxonis_train/callbacks/ema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def on_save_checkpoint(
298298
trainer: pl.Trainer,
299299
pl_module: pl.LightningModule,
300300
checkpoint: dict,
301-
) -> dict[str, Any] | None:
301+
) -> None:
302302
"""Save the EMA state dictionary into the checkpoint.
303303
304304
@type trainer: L{pl.Trainer}

luxonis_train/callbacks/upload_checkpoint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from copy import copy
12
from pathlib import Path
23
from typing import Any
34

45
import lightning.pytorch as pl
56
import torch
67
from lightning.pytorch.callbacks import ModelCheckpoint
78
from loguru import logger
9+
from typing_extensions import override
810

911
import luxonis_train as lxt
1012
from luxonis_train.registry import CALLBACKS
@@ -24,12 +26,15 @@ def __init__(self):
2426
super().__init__()
2527
self.last_best_checkpoints = set()
2628

29+
@override
2730
def on_save_checkpoint(
2831
self,
2932
trainer: pl.Trainer,
3033
module: "lxt.LuxonisLightningModule",
3134
checkpoint: dict[str, Any],
3235
) -> None:
36+
checkpoint = copy(checkpoint)
37+
module._add_custom_data_to_checkpoint(checkpoint)
3338
checkpoint_paths = [
3439
c.best_model_path
3540
for c in trainer.checkpoint_callbacks

0 commit comments

Comments
 (0)