Skip to content

Commit 3a791aa

Browse files
committed
Add specific export config for torchvision detectors
1 parent 333f714 commit 3a791aa

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

dlclive/modelzoo/pytorch_model_zoo_export.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ def _load_model_weights(model_name: str, super_animal: str = super_animal) -> Or
3232
checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name)
3333
return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"]
3434

35+
# Skip downloading the detector weights for humanbody models, as they are not on huggingface
36+
skip_detector_download = (detector_name is None) or (super_animal == "superanimal_humanbody")
3537
export_dict = {
3638
"config": model_cfg,
3739
"pose": _load_model_weights(model_name),
38-
"detector": _load_model_weights(detector_name) if detector_name is not None else None,
40+
"detector": None if skip_detector_download else _load_model_weights(detector_name),
3941
}
4042
torch.save(export_dict, export_path)
4143

dlclive/modelzoo/utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ruamel.yaml import YAML
1313

1414
from dlclive.modelzoo.resolve_config import update_config
15+
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import SUPPORTED_TORCHVISION_DETECTORS
1516

1617
_MODELZOO_PATH = Path(__file__).parent
1718

@@ -131,12 +132,21 @@ def load_super_animal_config(
131132
model_config["method"] = "BU"
132133
else:
133134
model_config["method"] = "TD"
134-
if super_animal != "superanimal_humanbody":
135-
detector_cfg_path = get_super_animal_model_config_path(
136-
model_name=detector_name
137-
)
138-
detector_cfg = read_config_as_dict(detector_cfg_path)
139-
model_config["detector"] = detector_cfg
135+
detector_cfg_path = get_super_animal_model_config_path(
136+
model_name=detector_name
137+
)
138+
detector_cfg = read_config_as_dict(detector_cfg_path)
139+
model_config["detector"] = detector_cfg
140+
if super_animal == "superanimal_humanbody":
141+
# Apply specific updates required to run the torchvision detector with pretrained weights
142+
assert detector_name in SUPPORTED_TORCHVISION_DETECTORS
143+
model_config["detector"]['model']= {
144+
"type": "TorchvisionDetectorAdaptor",
145+
"model": detector_name,
146+
"weights": "COCO_V1",
147+
"num_classes": None,
148+
"box_score_thresh": 0.6,
149+
}
140150
return model_config
141151

142152

dlclive/pose_estimation_pytorch/models/detectors/torchvision.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS, BaseDetector
1818

19+
SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"]
20+
1921

2022
@DETECTORS.register_module
2123
class TorchvisionDetectorAdaptor(BaseDetector):

0 commit comments

Comments
 (0)