Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dlclive/modelzoo/pytorch_model_zoo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ def _load_model_weights(model_name: str, super_animal: str = super_animal) -> Or
checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name)
return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"]

# Skip downloading detector weights when no detector name is given, or for superanimal_humanbody models, whose detector weights are not hosted on Hugging Face
skip_detector_download = (detector_name is None) or (super_animal == "superanimal_humanbody")
export_dict = {
"config": model_cfg,
"pose": _load_model_weights(model_name),
"detector": _load_model_weights(detector_name) if detector_name is not None else None,
"detector": None if skip_detector_download else _load_model_weights(detector_name),
}
torch.save(export_dict, export_path)

Expand Down
22 changes: 16 additions & 6 deletions dlclive/modelzoo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ruamel.yaml import YAML

from dlclive.modelzoo.resolve_config import update_config
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import SUPPORTED_TORCHVISION_DETECTORS

_MODELZOO_PATH = Path(__file__).parent

Expand Down Expand Up @@ -131,12 +132,21 @@ def load_super_animal_config(
model_config["method"] = "BU"
else:
model_config["method"] = "TD"
if super_animal != "superanimal_humanbody":
detector_cfg_path = get_super_animal_model_config_path(
model_name=detector_name
)
detector_cfg = read_config_as_dict(detector_cfg_path)
model_config["detector"] = detector_cfg
detector_cfg_path = get_super_animal_model_config_path(
model_name=detector_name
)
detector_cfg = read_config_as_dict(detector_cfg_path)
model_config["detector"] = detector_cfg
if super_animal == "superanimal_humanbody":
# Apply specific updates required to run the torchvision detector with pretrained weights
assert detector_name in SUPPORTED_TORCHVISION_DETECTORS
model_config["detector"]['model']= {
"type": "TorchvisionDetectorAdaptor",
"model": detector_name,
"weights": "COCO_V1",
"num_classes": None,
"box_score_thresh": 0.6,
}
return model_config


Expand Down
3 changes: 3 additions & 0 deletions dlclive/pose_estimation_pytorch/models/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@
DETECTORS,
BaseDetector,
)
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import (
TorchvisionDetectorAdaptor,
)
from dlclive.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN
from dlclive.pose_estimation_pytorch.models.detectors.ssd import SSDLite
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import torch
import torchvision.models.detection as detection

from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector
from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS, BaseDetector

SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"]


@DETECTORS.register_module
class TorchvisionDetectorAdaptor(BaseDetector):
"""An adaptor for torchvision detectors

Expand All @@ -26,8 +29,8 @@ class TorchvisionDetectorAdaptor(BaseDetector):
- fasterrcnn_mobilenet_v3_large_fpn
- fasterrcnn_resnet50_fpn_v2

This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or
SSDLite) should be used instead.
This class can be used directly (e.g. with pre-trained COCO weights) or through its
subclasses (FasterRCNN or SSDLite) which adapt the model for DLC's 2-class detection.

The torchvision implementation does not allow getting both predictions and losses
with a single forward pass. Therefore, during evaluation only bounding box metrics
Expand Down
21 changes: 18 additions & 3 deletions dlclive/pose_estimation_pytorch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,24 @@ def load_model(self) -> None:
self.model = self.model.half()

self.detector = None
if self.dynamic is None and raw_data.get("detector") is not None:
detector_cfg = self.cfg.get("detector")
has_detector_weights = raw_data.get("detector") is not None
if detector_cfg is not None:
detector_model_cfg = detector_cfg["model"]
uses_pretrained = (
detector_model_cfg.get("pretrained", False)
or detector_model_cfg.get("weights") is not None
)
else:
uses_pretrained = False

if self.dynamic is None and (has_detector_weights or uses_pretrained):
self.detector = models.DETECTORS.build(self.cfg["detector"]["model"])
self.detector.to(self.device)
self.detector.load_state_dict(raw_data["detector"])

if has_detector_weights:
self.detector.load_state_dict(raw_data["detector"])

self.detector.eval()
if self.precision == "FP16":
self.detector = self.detector.half()
Expand All @@ -281,7 +295,8 @@ def load_model(self) -> None:
self.top_down_config.read_config(self.cfg)

detector_transforms = [v2.ToDtype(torch.float32, scale=True)]
if self.cfg["detector"]["data"]["inference"].get("normalize_images", False):
detector_data_cfg = detector_cfg.get("data", {}).get("inference", {})
if detector_data_cfg.get("normalize_images", False):
detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
self.detector_transform = v2.Compose(detector_transforms)

Expand Down