-
Notifications
You must be signed in to change notification settings - Fork 54
Expand file tree
/
Copy pathutils.py
More file actions
189 lines (149 loc) · 6.64 KB
/
utils.py
File metadata and controls
189 lines (149 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Utils for the DLC-Live Model Zoo
"""
# NOTE JR 2026-01-23: This file contains code duplicated from the DeepLabCut main repository.
# The duplicates should be removed once a way to share this code between the repositories is found.
import copy
import logging
from pathlib import Path
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
from ruamel.yaml import YAML
from dlclive.modelzoo.resolve_config import update_config
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import SUPPORTED_TORCHVISION_DETECTORS
# Directory containing this module. The model configs ("model_configs/"),
# project configs ("project_configs/") and snapshots ("snapshots/") are all
# resolved relative to it.
_MODELZOO_PATH: Path = Path(__file__).parent
def get_super_animal_model_config_path(model_name: str) -> Path:
    """Return the path of a model configuration file shipped with the model zoo.

    Args:
        model_name: Name of the model; it resolves to
            ``model_configs/<model_name>.yaml`` next to this module.

    Returns:
        Path to the existing YAML configuration file.

    Raises:
        FileNotFoundError: If no configuration file exists for ``model_name``.
            The message lists the available model names.
    """
    config_file = _MODELZOO_PATH.joinpath("model_configs", f"{model_name}.yaml")
    if config_file.exists():
        return config_file
    raise FileNotFoundError(
        f"Modelzoo model configuration file not found: {config_file} "
        f"Available models: {list_available_models()}"
    )
def get_super_animal_project_config_path(super_animal: str) -> Path:
    """Return the path of a SuperAnimal project configuration file.

    Args:
        super_animal: Name of the SuperAnimal project; it resolves to
            ``project_configs/<super_animal>.yaml`` next to this module.

    Returns:
        Path to the existing YAML project configuration file.

    Raises:
        FileNotFoundError: If no project configuration exists for
            ``super_animal``. The message lists the available project names.
    """
    config_file = _MODELZOO_PATH.joinpath("project_configs", f"{super_animal}.yaml")
    if config_file.exists():
        return config_file
    raise FileNotFoundError(
        f"Modelzoo project configuration file not found: {config_file} "
        f"Available projects: {list_available_projects()}"
    )
def get_snapshot_folder_path() -> Path:
    """Return the directory where model-zoo snapshots are stored.

    This is the ``snapshots`` directory next to this module. The directory is
    only located here, not created.
    """
    return _MODELZOO_PATH.joinpath("snapshots")
def list_available_models() -> list[str]:
    """Return the names of all models that have a config in ``model_configs/``.

    Each name is the file stem of a ``model_configs/*.yaml`` file. The names are
    sorted because ``Path.glob`` yields files in an arbitrary, filesystem-dependent
    order. This list is shown in error messages and used to build model/project
    combinations, so it should be deterministic.
    """
    return sorted(p.stem for p in _MODELZOO_PATH.glob("model_configs/*.yaml"))
def list_available_projects() -> list[str]:
    """Return the names of all projects that have a config in ``project_configs/``.

    Each name is the file stem of a ``project_configs/*.yaml`` file. The names
    are sorted because ``Path.glob`` yields files in an arbitrary,
    filesystem-dependent order. This list is shown in error messages and used
    to build model/project combinations, so it should be deterministic.
    """
    return sorted(p.stem for p in _MODELZOO_PATH.glob("project_configs/*.yaml"))
def list_available_combinations() -> list[str]:
    """Return every ``<project>_<model>`` name that can be formed from the zoo.

    Projects form the outer loop and models the inner loop. All combinations
    for the first project therefore come first.
    """
    model_names = list_available_models()
    return [
        f"{project}_{model}"
        for project in list_available_projects()
        for model in model_names
    ]
def read_config_as_dict(config_path: str | Path) -> dict:
    """Load a YAML configuration file into plain Python objects.

    The file is parsed with ruamel.yaml's safe, pure-Python loader. Only
    standard YAML types are produced (mappings, sequences, strings, numbers,
    ...); no arbitrary Python objects are constructed.

    Args:
        config_path: the path to the configuration file to load

    Returns:
        The configuration file with pure Python classes
    """
    # Decode explicitly as UTF-8. Without ``encoding`` the locale's default
    # (e.g. cp1252 on Windows) is used, which can fail on or mis-decode
    # non-ASCII content in the config files.
    with open(config_path, encoding="utf-8") as f:
        cfg = YAML(typ="safe", pure=True).load(f)
    return cfg
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
# from deeplabcut/pose_estimation_pytorch/config/make_pose_config.py
def add_metadata(
    project_config: dict,
    config: dict,
) -> dict:
    """Adds metadata to a pytorch pose configuration

    Args:
        project_config: the project configuration. It must contain
            ``project_path``. It must also contain ``bodyparts`` unless it has
            a non-empty ``multianimalbodyparts`` entry.
        config: the pytorch pose configuration. It is not modified; a deep copy
            is returned instead.

    Returns:
        a deep copy of ``config`` with a ``metadata`` key set (replacing any
        existing ``metadata`` entry)

    Raises:
        KeyError: if ``project_path`` is missing, or if ``bodyparts`` is needed
            (see above) but missing.
    """
    config = copy.deepcopy(config)
    config["metadata"] = {
        "project_path": project_config["project_path"],
        # No pose config file path is known here, so it is left empty.
        "pose_config_path": "",
        # Multi-animal projects list their body parts under
        # "multianimalbodyparts". An absent or empty entry falls back to
        # "bodyparts".
        "bodyparts": project_config.get("multianimalbodyparts")
        or project_config["bodyparts"],
        "unique_bodyparts": project_config.get("uniquebodyparts", []),
        "individuals": project_config.get("individuals", ["animal"]),
        "with_identity": project_config.get("identity", False),
    }
    return config
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
def load_super_animal_config(
    super_animal: str,
    model_name: str,
    detector_name: str | None = None,
    max_individuals: int = 30,
    device: str | None = None,
) -> dict:
    """Loads the model configuration file for a model, detector and SuperAnimal

    When no detector is given, the model is configured as bottom-up
    (``method="BU"``). Otherwise it is configured as top-down (``method="TD"``)
    and the detector's configuration is embedded under the ``detector`` key.
    ``superanimal_humanbody`` is always top-down. It requires one of the
    ``SUPPORTED_TORCHVISION_DETECTORS`` as detector, whose ``model`` section is
    replaced by a torchvision-adaptor definition using "COCO_V1" weights.

    Args:
        super_animal: The name of the SuperAnimal for which to create the model config.
        model_name: The name of the model for which to create the model config.
        detector_name: The name of the detector for which to create the model config.
        max_individuals: The maximum number of detections to make in an image
        device: The device to use to train/run inference on the model

    Returns:
        The model configuration for a SuperAnimal-pretrained model.

    Raises:
        FileNotFoundError: if no project, model or detector configuration file
            exists for the given names.
        ValueError: for ``superanimal_humanbody``, if ``detector_name`` is None
            or is not one of ``SUPPORTED_TORCHVISION_DETECTORS``.
    """
    is_humanbody = super_animal == "superanimal_humanbody"
    if is_humanbody:
        # Validate before any config I/O. Use real exceptions rather than
        # ``assert``, which is stripped when Python runs with ``-O``.
        if detector_name is None:
            raise ValueError(
                f"{super_animal} requires a detector; choose one of "
                f"{list(SUPPORTED_TORCHVISION_DETECTORS)}"
            )
        if detector_name not in SUPPORTED_TORCHVISION_DETECTORS:
            raise ValueError(
                f"Unsupported detector {detector_name!r} for {super_animal}; "
                f"choose one of {list(SUPPORTED_TORCHVISION_DETECTORS)}"
            )

    project_config = read_config_as_dict(
        get_super_animal_project_config_path(super_animal=super_animal)
    )
    model_config = read_config_as_dict(
        get_super_animal_model_config_path(model_name=model_name)
    )
    model_config = add_metadata(project_config, model_config)
    model_config = update_config(model_config, max_individuals, device)

    # superanimal_humanbody always has a detector at this point (validated
    # above), so "no detector" means a bottom-up model.
    if detector_name is None:
        model_config["method"] = "BU"
        return model_config

    model_config["method"] = "TD"
    model_config["detector"] = read_config_as_dict(
        get_super_animal_model_config_path(model_name=detector_name)
    )
    if is_humanbody:
        # Apply the updates required to run the torchvision detector with
        # pretrained weights.
        model_config["detector"]["model"] = {
            "type": "TorchvisionDetectorAdaptor",
            "model": detector_name,
            "weights": "COCO_V1",
            "num_classes": None,
            "box_score_thresh": 0.6,
        }
    return model_config
def download_super_animal_snapshot(dataset: str, model_name: str) -> Path:
    """Downloads a SuperAnimal snapshot

    The snapshot is stored as ``<dataset>_<model_name>.pt`` in the model-zoo
    snapshot folder. If that file already exists, nothing is downloaded.

    Args:
        dataset: The name of the SuperAnimal dataset for which to download a snapshot.
        model_name: The name of the model for which to download a snapshot.

    Returns:
        The path to the downloaded (or previously downloaded) snapshot.

    Raises:
        RuntimeError: if the download call returns but the snapshot file is
            not present afterwards.
        Exception: any error raised by ``download_huggingface_model`` is
            logged and re-raised unchanged.
    """
    snapshot_dir = get_snapshot_folder_path()
    # Keep the caller's ``model_name`` intact; the hub name is dataset-prefixed.
    full_model_name = f"{dataset}_{model_name}"
    model_filename = f"{full_model_name}.pt"
    model_path = snapshot_dir / model_filename

    if model_path.exists():
        logging.info("Snapshot %s already exists, skipping download", model_path)
        return model_path

    # Only the download call sits in the ``try``. This way the RuntimeError
    # below is not caught and logged a second time by this handler.
    try:
        download_huggingface_model(
            full_model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename
        )
    except Exception as e:
        logging.error(
            "Failed to download superanimal snapshot %s to %s: %s",
            full_model_name,
            model_path,
            e,
        )
        raise  # bare ``raise`` re-raises with the original traceback

    if not model_path.exists():
        message = f"Failed to download {full_model_name} to {model_path}"
        logging.error(message)
        raise RuntimeError(message)
    return model_path