@@ -308,6 +308,7 @@ class Model(dj.Manual):
308308 snapshotindex (int): Which snapshot for prediction (if -1, latest).
309309 shuffle (int): Which shuffle of the training dataset.
310310 trainingsetindex (int): Which training set fraction to generate model.
311+ engine (str): Engine used for model. Either 'tensorflow' or 'pytorch'.
311312 scorer ( varchar(64) ): Scorer/network name - DLC's GetScorerName().
312313 config_template (longblob): Dictionary of the config for analyze_videos().
313314 project_path ( varchar(255) ): DLC's project_path in config relative to root.
@@ -329,7 +330,8 @@ class Model(dj.Manual):
329330 snapshotindex : int # which snapshot for prediction (if -1, latest)
330331 shuffle : int # Shuffle (1) or not (0)
331332 trainingsetindex : int # Index of training fraction list in config.yaml
332- unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex)
333+ engine='tensorflow' : varchar(16) # Engine used for model. Either 'tensorflow' or 'pytorch'
334+ unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex, engine)
333335 scorer : varchar(64) # Scorer/network name - DLC's GetScorerName()
334336 config_template : longblob # Dictionary of the config for analyze_videos()
335337 project_path : varchar(255) # DLC's project_path in config relative to root
@@ -378,9 +380,6 @@ def insert_new_model(
378380 prompt (bool): Optional. Prompt the user with all info before inserting.
379381 params (dict): Optional. If dlc_config is path, dict of override items
380382 """
381-
382- from deeplabcut .utils .auxiliaryfunctions import GetScorerName # isort:skip
383-
384383 # handle dlc_config being a yaml file
385384 dlc_config_fp = find_full_path (get_dlc_root_data_dir (), Path (dlc_config ))
386385 assert dlc_config_fp .exists (), (
@@ -409,16 +408,37 @@ def insert_new_model(
409408 for attribute in needed_attributes :
410409 assert attribute in dlc_config , f"Couldn't find { attribute } in config"
411410
412- # ---- Get scorer name ----
413- # "or 'f'" below covers case where config returns None. str_to_bool handles else
414- scorer_legacy = str_to_bool (dlc_config .get ("scorer_legacy" , "f" ))
411+ engine = dlc_config .get ("engine" )
412+ if engine is None :
413+ logger .warning (
414+ "DLC engine not specified in config file. Defaulting to TensorFlow."
415+ )
416+ engine = "tensorflow"
417+
418+ if engine == "tensorflow" :
419+ from deeplabcut .utils .auxiliaryfunctions import GetScorerName # isort:skip
420+
421+ # ---- Get scorer name ----
422+ # "or 'f'" below covers case where config returns None. str_to_bool handles else
423+ scorer_legacy = str_to_bool (dlc_config .get ("scorer_legacy" , "f" ))
424+ dlc_scorer = GetScorerName (
425+ cfg = dlc_config ,
426+ shuffle = shuffle ,
427+ trainFraction = dlc_config ["TrainingFraction" ][int (trainingsetindex )],
428+ modelprefix = model_prefix ,
429+ )[scorer_legacy ]
430+ elif engine == "pytorch" :
431+ from deeplabcut .pose_estimation_pytorch .apis .utils import get_scorer_name
432+
433+ dlc_scorer = get_scorer_name (
434+ cfg = dlc_config ,
435+ shuffle = shuffle ,
436+ train_fraction = dlc_config ["TrainingFraction" ][int (trainingsetindex )],
437+ modelprefix = model_prefix ,
438+ )
439+ else :
440+ raise ValueError (f"Unknow engine type { engine } " )
415441
416- dlc_scorer = GetScorerName (
417- cfg = dlc_config ,
418- shuffle = shuffle ,
419- trainFraction = dlc_config ["TrainingFraction" ][int (trainingsetindex )],
420- modelprefix = model_prefix ,
421- )[scorer_legacy ]
422442 if dlc_config ["snapshotindex" ] == - 1 :
423443 dlc_scorer = "" .join (dlc_scorer .split ("_" )[:- 1 ])
424444
@@ -433,6 +453,7 @@ def insert_new_model(
433453 "snapshotindex" : dlc_config ["snapshotindex" ],
434454 "shuffle" : shuffle ,
435455 "trainingsetindex" : int (trainingsetindex ),
456+ "engine" : engine ,
436457 "project_path" : project_path .relative_to (root_dir ).as_posix (),
437458 "paramset_idx" : paramset_idx ,
438459 "config_template" : dlc_config ,
@@ -719,7 +740,16 @@ def make(self, key):
719740 PoseEstimationTask .update1 (
720741 {** key , "pose_estimation_output_dir" : output_dir .as_posix ()}
721742 )
722- output_dir = find_full_path (get_dlc_root_data_dir (), output_dir )
743+
744+ try :
745+ output_dir = find_full_path (get_dlc_root_data_dir (), output_dir )
746+ except FileNotFoundError as e :
747+ if task_mode == "trigger" :
748+ processed_dir = Path (get_dlc_processed_data_dir ())
749+ output_dir = processed_dir / output_dir
750+ output_dir .mkdir (parents = True , exist_ok = True )
751+ else :
752+ raise e
723753
724754 # Trigger PoseEstimation
725755 if task_mode == "trigger" :
@@ -756,7 +786,18 @@ def make(self, key):
756786 output_directory = output_dir ,
757787 )
758788 def do_analyze_videos ():
759- from deeplabcut .pose_estimation_tensorflow import analyze_videos
789+ engine = dlc_model_ .get ("engine" )
790+ if engine is None :
791+ logger .warning (
792+ "DLC engine not specified in config file. Defaulting to TensorFlow."
793+ )
794+ engine = "tensorflow"
795+ if engine == "pytorch" :
796+ from deeplabcut .pose_estimation_pytorch import analyze_videos
797+ elif engine == "tensorflow" :
798+ from deeplabcut .pose_estimation_tensorflow import analyze_videos
799+ else :
800+ raise ValueError (f"Unknow engine type { engine } " )
760801
761802 # ---- Build and save DLC configuration (yaml) file ----
762803 dlc_config = dlc_model_ ["config_template" ]
0 commit comments