66from functools import lru_cache
77from importlib .metadata import version
88from pathlib import Path
9- from typing import TYPE_CHECKING , Annotated , Literal
9+ from typing import TYPE_CHECKING , Annotated , Literal , TypeAlias
1010
1111import yaml
1212from cyclopts import App , Group , Parameter , validators
1313from loguru import logger
1414from luxonis_ml .typing import Params , PathType
1515
16- from luxonis_train .config import Config
1716from luxonis_train .upgrade import upgrade_config , upgrade_installation
1817
18+ OptsType : TypeAlias = Annotated [
19+ list [str ] | None , Parameter (json_list = False , json_dict = False )
20+ ]
21+
1922if TYPE_CHECKING :
2023 import numpy as np
2124
2225 from luxonis_train import LuxonisModel
2326
27+
2428app = App (
2529 help = "Luxonis Train CLI" ,
2630 version = lambda : f"LuxonisTrain v{ version ('luxonis_train' )} " ,
@@ -42,48 +46,23 @@ def create_model(
4246 config : PathType | Params | None ,
4347 opts : list [str ] | None = None ,
4448 weights : PathType | None = None ,
45- debug_mode : bool = False ,
46- load_dataset_metadata : bool = True ,
49+ allow_empty_dataset : bool = False ,
4750) -> "LuxonisModel" :
4851 importlib .reload (sys .modules ["luxonis_train" ])
49- import torch
5052
5153 from luxonis_train import LuxonisModel
52- from luxonis_train .utils .dataset_metadata import DatasetMetadata
53-
54- if weights is not None and config is None :
55- ckpt = torch .load (weights , map_location = "cpu" ) # nosemgre
56- if "config" not in ckpt : # pragma: no cover
57- raise ValueError (
58- f"Checkpoint '{ weights } ' does not contain the 'config' key. "
59- "Cannot restore `LuxonisModel` from checkpoint."
60- )
61- cfg = Config .get_config (upgrade_config (ckpt ["config" ]), opts )
62- dataset_metadata = None
63- if load_dataset_metadata :
64- if "dataset_metadata" not in ckpt :
65- logger .error ("Checkpoint does not contain dataset metadata." )
66- else :
67- try :
68- dataset_metadata = DatasetMetadata (
69- ** ckpt ["dataset_metadata" ]
70- )
71- except Exception as e : # pragma: no cover
72- logger .error (
73- "Failed to load dataset metadata from the checkpoint. "
74- f"Error: { e } "
75- )
76-
77- return LuxonisModel (
78- cfg , debug_mode = debug_mode , dataset_metadata = dataset_metadata
79- )
8054
81- return LuxonisModel (config , opts , debug_mode = debug_mode )
55+ return LuxonisModel (
56+ config ,
57+ opts ,
58+ weights = weights ,
59+ allow_empty_dataset = allow_empty_dataset ,
60+ )
8261
8362
8463@app .command (group = training_group , sort_key = 1 )
8564def train (
86- opts : list [ str ] | None = None ,
65+ opts : OptsType = None ,
8766 / ,
8867 * ,
8968 config : str | None = None ,
@@ -99,29 +78,44 @@ def train(
9978 @type opts: list[str]
10079 @param opts: A list of optional CLI overrides of the config file.
10180 @type debug: bool
102- @param debug: If True, the training will run in debug mode which
103- suppresses some exceptions to allow training without a fully
104- defined model .
81+ @param debug: If true, allows the model to be constructed without
82+ a valid dataset by setting `allow_empty_dataset` to True. This can
83+ be useful for quick testing of the training loop .
10584 """
106- create_model (config , opts , weights , debug_mode = debug ). train (
107- weights = weights
108- )
85+ create_model (
86+ config , opts , weights = weights , allow_empty_dataset = debug
87+ ). train ( weights = weights )
10988
11089
11190@app .command (group = training_group , sort_key = 2 )
112- def tune (opts : list [str ] | None = None , / , * , config : str | None = None ):
91+ def tune (
92+ opts : OptsType = None ,
93+ / ,
94+ * ,
95+ config : str | None = None ,
96+ weights : str | None = None ,
97+ debug : bool = False ,
98+ ):
11399 """Start hyperparameter tuning.
114100
115101 @type config: str
116102 @param config: Path to the configuration file.
117103 @type opts: list[str]
118104 @param opts: A list of optional CLI overrides of the config file.
105+ @type weights: str
106+ @param weights: Path to the model weights.
107+ @type debug: bool
108+ @param debug: If true, allows the model to be constructed without
109+ a valid dataset by setting `allow_empty_dataset` to True. This can
110+ be useful for quick testing of the tuning.
119111 """
120- create_model (config , opts ).tune ()
112+ create_model (
113+ config , opts , weights = weights , allow_empty_dataset = debug
114+ ).tune ()
121115
122116
123117def _yield_visualizations (
124- opts : list [ str ] | None = None ,
118+ opts : OptsType = None ,
125119 config : str | None = None ,
126120 view : Literal ["train" , "val" , "test" ] = "train" ,
127121 size_multiplier : Annotated [
@@ -191,7 +185,7 @@ def get_visualization_item(
191185
192186@app .command (group = training_group , sort_key = 3 )
193187def inspect (
194- opts : list [ str ] | None = None ,
188+ opts : OptsType = None ,
195189 / ,
196190 * ,
197191 config : str | None = None ,
@@ -240,7 +234,7 @@ def get_window() -> str:
240234
241235@app .command (group = evaluation_group , sort_key = 1 )
242236def test (
243- opts : list [ str ] | None = None ,
237+ opts : OptsType = None ,
244238 / ,
245239 * ,
246240 config : str | None = None ,
@@ -261,18 +255,18 @@ def test(
261255 @type opts: list[str]
262256 @param opts: A list of optional CLI overrides of the config file.
263257 @type debug: bool
264- @param debug: If True, the training will run in debug mode which
265- suppresses some exceptions to allow training without a fully
266- defined model .
258+ @param debug: If true, allows the model to be constructed without
259+ a valid dataset by setting `allow_empty_dataset` to True. This can
260+ be useful for quick testing of the evaluation loop .
267261 """
268- create_model (config , opts , weights , debug_mode = debug ). test (
269- view = view , weights = weights
270- )
262+ create_model (
263+ config , opts , weights = weights , allow_empty_dataset = debug
264+ ). test ( view = view , weights = weights )
271265
272266
273267@app .command (group = evaluation_group , sort_key = 2 )
274268def infer (
275- opts : list [ str ] | None = None ,
269+ opts : OptsType = None ,
276270 / ,
277271 * ,
278272 config : str | None = None ,
@@ -302,7 +296,9 @@ def infer(
302296 @type opts: list[str]
303297 @param opts: A list of optional CLI overrides of the config file.
304298 """
305- create_model (config , opts , weights = weights , debug_mode = True ).infer (
299+ create_model (
300+ config , opts , weights = weights , allow_empty_dataset = True
301+ ).infer (
306302 view = view ,
307303 save_dir = save_dir ,
308304 source_path = source_path ,
@@ -312,7 +308,7 @@ def infer(
312308
313309@app .command (group = annotation_group , sort_key = 0 )
314310def annotate (
315- opts : list [ str ] | None = None ,
311+ opts : OptsType = None ,
316312 / ,
317313 * ,
318314 dir_path : Path ,
@@ -323,7 +319,6 @@ def annotate(
323319 delete_local : bool = True ,
324320 delete_remote : bool = True ,
325321 team_id : str | None = None ,
326- debug : bool = False ,
327322):
328323 """Run annotation on a custom directory of images.
329324
@@ -353,11 +348,7 @@ def annotate(
353348 @param opts: A list of optional CLI overrides of the config file.
354349 """
355350 model = create_model (
356- config ,
357- opts ,
358- weights = weights ,
359- load_dataset_metadata = True ,
360- debug_mode = debug ,
351+ config , opts , weights = weights , allow_empty_dataset = True
361352 )
362353
363354 model .annotate (
@@ -373,7 +364,7 @@ def annotate(
373364
374365@app .command (group = export_group , sort_key = 1 )
375366def export (
376- opts : list [ str ] | None = None ,
367+ opts : OptsType = None ,
377368 / ,
378369 * ,
379370 config : str | None = None ,
@@ -400,14 +391,14 @@ def export(
400391 @type opts: list[str]
401392 @param opts: A list of optional CLI overrides of the
402393 """
403- create_model (config , opts , weights = weights , debug_mode = True ). export (
404- save_path = save_path , weights = weights , ckpt_only = ckpt_only
405- )
394+ create_model (
395+ config , opts , weights = weights , allow_empty_dataset = True
396+ ). export ( save_path = save_path , weights = weights , ckpt_only = ckpt_only )
406397
407398
408399@app .command (group = export_group , sort_key = 2 )
409400def archive (
410- opts : list [ str ] | None = None ,
401+ opts : OptsType = None ,
411402 / ,
412403 * ,
413404 config : str | None ,
@@ -426,14 +417,14 @@ def archive(
426417 @type opts: list[str]
427418 @param opts: A list of optional CLI overrides of the config file.
428419 """
429- create_model (str ( config ), opts , weights = weights ). archive (
430- path = executable , weights = weights
431- )
420+ create_model (
421+ config , opts , weights = weights , allow_empty_dataset = True
422+ ). archive ( path = executable , weights = weights )
432423
433424
434425@app .command (group = export_group , sort_key = 3 )
435426def convert (
436- opts : list [ str ] | None = None ,
427+ opts : OptsType = None ,
437428 / ,
438429 * ,
439430 config : str | None = None ,
@@ -456,9 +447,9 @@ def convert(
456447 @type opts: list[str]
457448 @param opts: A list of optional CLI overrides of the config file.
458449 """
459- create_model (config , opts , weights = weights ). convert (
460- weights = weights , save_dir = save_dir
461- )
450+ create_model (
451+ config , opts , weights = weights , allow_empty_dataset = True
452+ ). convert ( save_dir = save_dir , weights = weights )
462453
463454
464455@upgrade_app .command ()
@@ -500,7 +491,7 @@ def config(
500491
501492@upgrade_app .command (name = ["checkpoint" , "ckpt" ])
502493def checkpoint (
503- opts : list [ str ] | None = None ,
494+ opts : OptsType = None ,
504495 / ,
505496 * ,
506497 path : Annotated [
@@ -518,11 +509,10 @@ def checkpoint(
518509 @param new: Where to save the upgraded checkpoint. If left empty,
519510 the old file will be overriden.
520511 """
512+ from luxonis_train import LuxonisModel
513+
521514 logger .info ("Performing a full checkpoint upgrade." )
522- cfg = None
523- if config is not None :
524- cfg = upgrade_config (config )
525- model = create_model (config = cfg , weights = path , opts = opts , debug_mode = True )
515+ model = LuxonisModel (config , opts , weights = path , allow_empty_dataset = True )
526516 model .lightning_module .load_checkpoint (path )
527517
528518 # Needs to be called in order to attach the model to the trainer
0 commit comments