Skip to content

Commit 5f9edf1

Browse files
authored
Merge pull request #134 from ttngu207/dev_pytorch
feat(pytorch_dlc): support the latest DLC with pytorch engine
2 parents eabfe37 + d0f84a7 commit 5f9edf1

File tree

3 files changed

+60
-16
lines changed

3 files changed

+60
-16
lines changed

element_deeplabcut/model.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ class Model(dj.Manual):
308308
snapshotindex (int): Which snapshot for prediction (if -1, latest).
309309
shuffle (int): Which shuffle of the training dataset.
310310
trainingsetindex (int): Which training set fraction to generate model.
311+
engine (str): Engine used for model. Either 'tensorflow' or 'pytorch'.
311312
scorer ( varchar(64) ): Scorer/network name - DLC's GetScorerName().
312313
config_template (longblob): Dictionary of the config for analyze_videos().
313314
project_path ( varchar(255) ): DLC's project_path in config relative to root.
@@ -329,7 +330,8 @@ class Model(dj.Manual):
329330
snapshotindex : int # which snapshot for prediction (if -1, latest)
330331
shuffle : int # Shuffle (1) or not (0)
331332
trainingsetindex : int # Index of training fraction list in config.yaml
332-
unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex)
333+
engine='tensorflow' : varchar(16) # Engine used for model. Either 'tensorflow' or 'pytorch'
334+
unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex, engine)
333335
scorer : varchar(64) # Scorer/network name - DLC's GetScorerName()
334336
config_template : longblob # Dictionary of the config for analyze_videos()
335337
project_path : varchar(255) # DLC's project_path in config relative to root
@@ -378,9 +380,6 @@ def insert_new_model(
378380
prompt (bool): Optional. Prompt the user with all info before inserting.
379381
params (dict): Optional. If dlc_config is path, dict of override items
380382
"""
381-
382-
from deeplabcut.utils.auxiliaryfunctions import GetScorerName # isort:skip
383-
384383
# handle dlc_config being a yaml file
385384
dlc_config_fp = find_full_path(get_dlc_root_data_dir(), Path(dlc_config))
386385
assert dlc_config_fp.exists(), (
@@ -409,16 +408,37 @@ def insert_new_model(
409408
for attribute in needed_attributes:
410409
assert attribute in dlc_config, f"Couldn't find {attribute} in config"
411410

412-
# ---- Get scorer name ----
413-
# "or 'f'" below covers case where config returns None. str_to_bool handles else
414-
scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f"))
411+
engine = dlc_config.get("engine")
412+
if engine is None:
413+
logger.warning(
414+
"DLC engine not specified in config file. Defaulting to TensorFlow."
415+
)
416+
engine = "tensorflow"
417+
418+
if engine == "tensorflow":
419+
from deeplabcut.utils.auxiliaryfunctions import GetScorerName # isort:skip
420+
421+
# ---- Get scorer name ----
422+
# "or 'f'" below covers case where config returns None. str_to_bool handles else
423+
scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f"))
424+
dlc_scorer = GetScorerName(
425+
cfg=dlc_config,
426+
shuffle=shuffle,
427+
trainFraction=dlc_config["TrainingFraction"][int(trainingsetindex)],
428+
modelprefix=model_prefix,
429+
)[scorer_legacy]
430+
elif engine == "pytorch":
431+
from deeplabcut.pose_estimation_pytorch.apis.utils import get_scorer_name
432+
433+
dlc_scorer = get_scorer_name(
434+
cfg=dlc_config,
435+
shuffle=shuffle,
436+
train_fraction=dlc_config["TrainingFraction"][int(trainingsetindex)],
437+
modelprefix=model_prefix,
438+
)
439+
else:
440+
raise ValueError(f"Unknow engine type {engine}")
415441

416-
dlc_scorer = GetScorerName(
417-
cfg=dlc_config,
418-
shuffle=shuffle,
419-
trainFraction=dlc_config["TrainingFraction"][int(trainingsetindex)],
420-
modelprefix=model_prefix,
421-
)[scorer_legacy]
422442
if dlc_config["snapshotindex"] == -1:
423443
dlc_scorer = "".join(dlc_scorer.split("_")[:-1])
424444

@@ -433,6 +453,7 @@ def insert_new_model(
433453
"snapshotindex": dlc_config["snapshotindex"],
434454
"shuffle": shuffle,
435455
"trainingsetindex": int(trainingsetindex),
456+
"engine": engine,
436457
"project_path": project_path.relative_to(root_dir).as_posix(),
437458
"paramset_idx": paramset_idx,
438459
"config_template": dlc_config,
@@ -719,7 +740,16 @@ def make(self, key):
719740
PoseEstimationTask.update1(
720741
{**key, "pose_estimation_output_dir": output_dir.as_posix()}
721742
)
722-
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)
743+
744+
try:
745+
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)
746+
except FileNotFoundError as e:
747+
if task_mode == "trigger":
748+
processed_dir = Path(get_dlc_processed_data_dir())
749+
output_dir = processed_dir / output_dir
750+
output_dir.mkdir(parents=True, exist_ok=True)
751+
else:
752+
raise e
723753

724754
# Trigger PoseEstimation
725755
if task_mode == "trigger":
@@ -756,7 +786,18 @@ def make(self, key):
756786
output_directory=output_dir,
757787
)
758788
def do_analyze_videos():
759-
from deeplabcut.pose_estimation_tensorflow import analyze_videos
789+
engine = dlc_model_.get("engine")
790+
if engine is None:
791+
logger.warning(
792+
"DLC engine not specified in config file. Defaulting to TensorFlow."
793+
)
794+
engine = "tensorflow"
795+
if engine == "pytorch":
796+
from deeplabcut.pose_estimation_pytorch import analyze_videos
797+
elif engine == "tensorflow":
798+
from deeplabcut.pose_estimation_tensorflow import analyze_videos
799+
else:
800+
raise ValueError(f"Unknow engine type {engine}")
760801

761802
# ---- Build and save DLC configuration (yaml) file ----
762803
dlc_config = dlc_model_["config_template"]

element_deeplabcut/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
Package metadata
33
"""
44

5-
__version__ = "0.3.3"
5+
__version__ = "0.4.0"

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,8 @@
4646
"element-interface @ git+https://github.com/datajoint/element-interface.git",
4747
],
4848
"tests": ["pytest", "pytest-cov", "shutils"],
49+
"dlc-pytorch": [
50+
"deeplabcut @ git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc"
51+
],
4952
},
5053
)

0 commit comments

Comments
 (0)