Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/library.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ jobs:
- name: Run library tests
working-directory: library
run: |
just tests || exit 0
just tests
2 changes: 0 additions & 2 deletions library/src/instantlearn/components/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from .timm import TimmImageEncoder

__all__ = [
"AVAILABLE_IMAGE_ENCODERS",
"TIMM_AVAILABLE_IMAGE_ENCODERS",
"HuggingFaceImageEncoder",
"ImageEncoder",
"TimmImageEncoder",
Expand Down
30 changes: 9 additions & 21 deletions library/src/instantlearn/components/encoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def load_image_encoder(
model_id: str = "dinov3_large",
model_id: str = "dinov3-large",
device: str = "cuda",
backend: Backend = Backend.TIMM,
model_path: Path | None = None,
Expand All @@ -33,9 +33,7 @@ def load_image_encoder(
and flexibility, while OpenVINO provides optimized inference.

Args:
model_id: The DINO model variant to use. Options:
- "dinov2_small", "dinov2_base", "dinov2_large", "dinov2_giant" (HuggingFace only)
- "dinov3_small", "dinov3_small_plus", "dinov3_base", "dinov3_large", "dinov3_huge"
model_id: The model ID (e.g., "dinov3-large", "dinov3-small").
device: Device to run inference on. For HuggingFace/TIMM: "cuda" or "cpu".
For OpenVINO: "CPU", "GPU", or "AUTO".
backend: Which backend to use: Backend.HUGGINGFACE, Backend.TIMM, or Backend.OPENVINO.
Expand All @@ -53,29 +51,21 @@ def load_image_encoder(

Raises:
ValueError: If backend is not valid.
FileNotFoundError: If OpenVINO model_path doesn't exist.

Examples:
>>> # HuggingFace backend (DINOv2 models)
>>> encoder = load_image_encoder(
... model_id="dinov2_large",
... device="cuda",
... backend=Backend.HUGGINGFACE
... )
>>>
>>> # TIMM backend (DINOv3 models)
>>> encoder = load_image_encoder(
... model_id="dinov3_large",
... model_id="dinov3-large",
... device="cuda",
... backend=Backend.TIMM
... )
>>>
>>> # Export to OpenVINO
>>> ov_path = encoder.export(Path("./exported/dinov2_large"))
>>> ov_path = encoder.export(Path("./exported/dinov3-large"))
>>>
>>> # OpenVINO backend (optimized inference)
>>> ov_encoder = load_image_encoder(
... model_id="dinov2_large",
... model_id="dinov3-large",
... device="CPU",
... backend=Backend.OPENVINO,
... model_path=ov_path
Expand Down Expand Up @@ -114,7 +104,7 @@ class ImageEncoder(nn.Module):
>>> import torch
>>>
>>> # Create encoder with TIMM backend
>>> encoder = ImageEncoder(model_id="dinov3_large", backend=Backend.TIMM)
>>> encoder = ImageEncoder(model_id="dinov3-large", backend=Backend.TIMM)
>>> sample_image = torch.zeros((3, 518, 518))
>>> features = encoder([sample_image])
>>> features.shape
Expand All @@ -125,15 +115,15 @@ class ImageEncoder(nn.Module):
>>>
>>> # Load with OpenVINO backend
>>> ov_encoder = ImageEncoder(
... model_id="dinov2_large",
... model_id="dinov3-large",
... backend=Backend.OPENVINO,
... model_path=ov_path
... )
"""

def __init__(
self,
model_id: str = "dinov3_large",
model_id: str = "dinov3-large",
backend: Backend = Backend.TIMM,
device: str = "cuda",
precision: str = "bf16",
Expand All @@ -144,9 +134,7 @@ def __init__(
"""Initialize the image encoder.

Args:
model_id: The DINO model variant to use. Options:
- "dinov2_small", "dinov2_base", "dinov2_large", "dinov2_giant" (HuggingFace)
- "dinov3_small", "dinov3_small_plus", "dinov3_base", "dinov3_large", "dinov3_huge" (TIMM)
model_id: The model ID (e.g., "dinov3-large", "dinov3-small").
backend: Which backend to use: Backend.HUGGINGFACE, Backend.TIMM, or Backend.OPENVINO.
device: Device to run inference on. For HuggingFace/TIMM: "cuda" or "cpu".
For OpenVINO: "CPU", "GPU", or "AUTO".
Expand Down
36 changes: 17 additions & 19 deletions library/src/instantlearn/components/encoders/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,6 @@

logger = getLogger("Geti Instant Learn")

AVAILABLE_IMAGE_ENCODERS = {
"dinov2_small": ("facebook/dinov2-with-registers-small", "0d9846e56b43a21fa46d7f3f5070f0506a5795a9"),
"dinov2_base": ("facebook/dinov2-with-registers-base", "a1d738ccfa7ae170945f210395d99dde8adb1805"),
"dinov2_large": ("facebook/dinov2-with-registers-large", "e4c89a4e05589de9b3e188688a303d0f3c04d0f3"),
"dinov2_giant": ("facebook/dinov2-with-registers-giant", "8d0d49f77fb8b5dd78842496ff14afe7dd4d85cb"),
"dinov3_small": ("facebook/dinov3-vits16-pretrain-lvd1689m", "114c1379950215c8b35dfcd4e90a5c251dde0d32"),
"dinov3_small_plus": ("facebook/dinov3-vits16plus-pretrain-lvd1689m", "c93d816fc9e567563bc068f01475bec89cc634a6"),
"dinov3_base": ("facebook/dinov3-vitb16-pretrain-lvd1689m", "5931719e67bbdb9737e363e781fb0c67687896bc"),
"dinov3_large": ("facebook/dinov3-vitl16-pretrain-lvd1689m", "ea8dc2863c51be0a264bab82070e3e8836b02d51"),
"dinov3_huge": ("facebook/dinov3-vith16plus-pretrain-lvd1689m", "c807c9eeea853df70aec4069e6f56b28ddc82acc"),
}


class HuggingFaceImageEncoder(nn.Module):
"""HuggingFace backend for DINO image encoder.
Expand All @@ -41,15 +29,15 @@ class HuggingFaceImageEncoder(nn.Module):
>>>
>>> # Create a sample image
>>> sample_image = torch.zeros((3, 518, 518))
>>> encoder = HuggingFaceImageEncoder(model_id="dinov2_large")
>>> encoder = HuggingFaceImageEncoder(model_id="dinov2-large")
>>> features = encoder(images=[sample_image])
>>> features.shape
torch.Size([1369, 1024])
"""

def __init__(
self,
model_id: str = "dinov3_large",
model_id: str = "dinov2-large",
device: str = "cuda",
precision: str = "bf16",
compile_models: bool = False,
Expand All @@ -58,28 +46,38 @@ def __init__(
"""Initialize the encoder.

Args:
model_id: The model id to use.
model_id: The model ID (e.g., "dinov2-large").
device: The device to use.
precision: The precision to use.
compile_models: Whether to compile the models.
input_size: The input size to use.

Raises:
ValueError: If the model ID is invalid.
ValueError: If the model ID is invalid or not an encoder type.
"""
from instantlearn.utils.optimization import optimize_model

super().__init__()

if model_id not in AVAILABLE_IMAGE_ENCODERS:
msg = f"Invalid model ID: {model_id}. Valid model IDs: {list(AVAILABLE_IMAGE_ENCODERS.keys())}"
# Validate model exists in registry
model_meta = get_model(model_id)
if model_meta is None:
valid = [m.id for m in get_models_by_type(ModelType.ENCODER)]
msg = f"Invalid model ID: '{model_id}'. Valid encoders: {valid}"
raise ValueError(msg)
if model_meta.type != ModelType.ENCODER:
msg = f"Model '{model_id}' is not an encoder (type: {model_meta.type})"
raise ValueError(msg)
if model_meta.hf_model_id is None:
msg = f"Model '{model_id}' has no hf_model_id configured"
raise ValueError(msg)

self.model_id = model_id
self.input_size = input_size
self.device = device

hf_model_id, revision = AVAILABLE_IMAGE_ENCODERS[model_id]
hf_model_id = model_meta.hf_model_id
revision = model_meta.hf_revision

msg = f"Loading DINO model {hf_model_id} with revision {revision}"
logger.info(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class LinearSumAssignment(nn.Module):
- "auto" (recommended): Uses fast scipy during normal execution,
automatically switches to greedy during ONNX/TorchScript export.
Best of both worlds - fast dev, exportable deployment.
- "greedy": Fast O(n² × min(n,m)) approximation, ~95-100% optimal.
- "greedy": Fast O(n² x min(n,m)) approximation, ~95-100% optimal.
Always exportable. Achieves 99%+ for rectangular matrices.
Default: "auto".

Expand Down
69 changes: 35 additions & 34 deletions library/src/instantlearn/components/sam/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from logging import getLogger
from pathlib import Path
from typing import ClassVar

import numpy as np
import torch
Expand Down Expand Up @@ -87,35 +88,32 @@ def load_sam_model(
return predictor


def check_model_weights(model_name: SAMModelName) -> None:
def check_model_weights(model_id: str) -> None:
"""Check if model weights exist locally, download if necessary.

Args:
model_name: The name of the model.
model_id: The model ID (e.g., "sam-hq-tiny", "sam2-base").

Raises:
ValueError: If the model is not found in MODEL_MAP.
ValueError: If the model weights are missing.
ValueError: If the model is not found in registry or has no weights URL.
"""
if model_name not in MODEL_MAP:
msg = f"Model '{model_name.value}' not found in MODEL_MAP for weight checking."
model = get_model(model_id)
if model is None:
valid = [m.id for m in get_models_by_type(ModelType.SEGMENTER)]
msg = f"Model '{model_id}' not found. Valid segmenters: {valid}"
raise ValueError(msg)

model_info = MODEL_MAP[model_name]
local_filename = model_info["local_filename"]
download_url = model_info["download_url"]
sha_sum = model_info["sha_sum"]

if not local_filename or not download_url:
msg = f"Missing 'local_filename' or 'download_url' for {model_name.value} in MODEL_MAP."
if model.weights_url is None:
msg = f"Model '{model_id}' has no weights_url configured."
raise ValueError(msg)

local_filename = get_local_filename(model_id)
target_path = DATA_PATH.joinpath(local_filename)

if not target_path.exists():
msg = f"Model weights for {model_name.value} not found at {target_path}, downloading..."
msg = f"Model weights for {model_id} not found at {target_path}, downloading..."
logger.info(msg)
download_file(download_url, target_path, sha_sum)
download_file(model.weights_url, target_path, model.sha_sum)


class PositionEmbeddingRandom(_PositionEmbeddingRandom):
Expand All @@ -136,7 +134,7 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
# Convert coords to match the gaussian matrix dtype
coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords @= self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

Expand Down Expand Up @@ -313,9 +311,15 @@ class SAMPredictor(nn.Module):
"no mask input". The prompt encoder detects these and uses the default no_mask_embed.
"""

# SAM-HQ registry name mapping (model_id -> segment_anything_hq registry key)
SAM_HQ_REGISTRY_MAP: ClassVar[dict[str, str]] = {
"sam-hq": "vit_h",
"sam-hq-tiny": "vit_tiny",
}

def __init__(
self,
sam_model_name: SAMModelName,
model_id: str,
device: str,
model_path: Path | None = None,
target_length: int = 1024,
Expand All @@ -329,7 +333,7 @@ def __init__(
target_length: Target length for the longest side of the image during transformation. Defaults to 1024.

Raises:
NotImplementedError: If the model type is not supported.
ValueError: If the model ID is not found in registry.
"""
super().__init__()
self.device = device
Expand All @@ -339,28 +343,25 @@ def __init__(

# Determine checkpoint path
if model_path is None:
check_model_weights(sam_model_name)
model_info = MODEL_MAP[sam_model_name]
checkpoint_path = DATA_PATH.joinpath(model_info["local_filename"])
check_model_weights(model_id)
local_filename = get_local_filename(model_id)
checkpoint_path = DATA_PATH.joinpath(local_filename)
else:
checkpoint_path = model_path

msg = f"Loading PyTorch SAM: {sam_model_name} from {checkpoint_path}"
msg = f"Loading PyTorch SAM: {model_id} from {checkpoint_path}"
logger.info(msg)

# Load model based on type
if sam_model_name in {
SAMModelName.SAM2_TINY,
SAMModelName.SAM2_SMALL,
SAMModelName.SAM2_BASE,
SAMModelName.SAM2_LARGE,
}:
model_info = MODEL_MAP[sam_model_name]
config_path = "configs/sam2.1/" + model_info["config_filename"]
# Load model based on family
if model_meta.family == "SAM2":
config_path = "configs/sam2.1/" + model_meta.config_filename
sam_model = build_sam2(config_path, str(checkpoint_path))
self._predictor = SAM2ImagePredictor(sam_model)
elif sam_model_name in {SAMModelName.SAM_HQ, SAMModelName.SAM_HQ_TINY}:
registry_name = MODEL_MAP[sam_model_name]["registry_name"]
elif model_meta.family == "SAM-HQ":
registry_name = self.SAM_HQ_REGISTRY_MAP.get(model_id)
if registry_name is None:
msg = f"SAM-HQ model '{model_id}' not in SAM_HQ_REGISTRY_MAP"
raise ValueError(msg)
sam_model = sam_model_registry[registry_name]().to(device)
# suppress - loading the snapshot from the local path
# nosemgrep trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
Expand All @@ -383,7 +384,7 @@ def __init__(
self._predictor.model.image_encoder,
])
else:
msg = f"Model {sam_model_name} not implemented"
msg = f"Model family '{model_meta.family}' not implemented"
raise NotImplementedError(msg)

def _patch_prompt_encoder(self, device: str) -> None:
Expand Down
6 changes: 4 additions & 2 deletions library/src/instantlearn/models/grounded_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
class GroundedSAM(Model):
"""This model uses a zero-shot object detector (from Huggingface) to generate boxes for SAM."""

DEFAULT_SAM = "sam-hq-tiny"

def __init__(
self,
sam: SAMModelName = SAMModelName.SAM_HQ_TINY,
sam: str = DEFAULT_SAM,
grounding_model: GroundingModel = GroundingModel.LLMDET_TINY,
precision: str = "bf16",
compile_models: bool = False,
Expand All @@ -32,7 +34,7 @@ def __init__(
"""Initialize the model.

Args:
sam: The SAM model name.
sam: Model ID for the segmenter (e.g., "sam-hq-tiny", "sam-hq").
grounding_model: The grounding model to use.
precision: The precision to use for the model.
compile_models: Whether to compile the models.
Expand Down
6 changes: 3 additions & 3 deletions library/src/instantlearn/models/soft_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SoftMatcher(Matcher):

def __init__(
self,
sam: SAMModelName = SAMModelName.SAM_HQ_TINY,
sam: str = Matcher.DEFAULT_SAM,
num_foreground_points: int = 40,
num_background_points: int = 2,
confidence_threshold: float | None = 0.42,
Expand All @@ -88,7 +88,7 @@ def __init__(
"""Initialize the SoftMatcher model.

Args:
sam: The name of the SAM model to use.
sam: Model ID for the segmenter (e.g., "sam-hq-tiny", "sam-hq").
num_foreground_points: The number of foreground points to use.
num_background_points: The number of background points to use.
confidence_threshold: Minimum confidence score for keeping predicted masks
Expand All @@ -98,7 +98,7 @@ def __init__(
approximate_matching: Whether to use approximate matching.
softmatching_score_threshold: The score threshold for the soft matching.
softmatching_bidirectional: Whether to use bidirectional soft matching.
encoder_model: The encoder model to use.
encoder_model: Model ID for the encoder (e.g., "dinov3-large").
precision: The precision to use for the model.
compile_models: Whether to compile the models.
device: The device to use for the model.
Expand Down
Loading
Loading