diff --git a/.github/workflows/library.yml b/.github/workflows/library.yml index 4f5fc3c01..7e032e9a3 100644 --- a/.github/workflows/library.yml +++ b/.github/workflows/library.yml @@ -34,4 +34,4 @@ jobs: - name: Run library tests working-directory: library run: | - just tests || exit 0 + just tests diff --git a/library/src/instantlearn/components/encoders/__init__.py b/library/src/instantlearn/components/encoders/__init__.py index 66cde1674..651fc2591 100644 --- a/library/src/instantlearn/components/encoders/__init__.py +++ b/library/src/instantlearn/components/encoders/__init__.py @@ -9,8 +9,6 @@ from .timm import TimmImageEncoder __all__ = [ - "AVAILABLE_IMAGE_ENCODERS", - "TIMM_AVAILABLE_IMAGE_ENCODERS", "HuggingFaceImageEncoder", "ImageEncoder", "TimmImageEncoder", diff --git a/library/src/instantlearn/components/encoders/base.py b/library/src/instantlearn/components/encoders/base.py index 970a51d5d..020e57a1b 100644 --- a/library/src/instantlearn/components/encoders/base.py +++ b/library/src/instantlearn/components/encoders/base.py @@ -18,7 +18,7 @@ def load_image_encoder( - model_id: str = "dinov3_large", + model_id: str = "dinov3-large", device: str = "cuda", backend: Backend = Backend.TIMM, model_path: Path | None = None, @@ -33,9 +33,7 @@ def load_image_encoder( and flexibility, while OpenVINO provides optimized inference. Args: - model_id: The DINO model variant to use. Options: - - "dinov2_small", "dinov2_base", "dinov2_large", "dinov2_giant" (HuggingFace only) - - "dinov3_small", "dinov3_small_plus", "dinov3_base", "dinov3_large", "dinov3_huge" + model_id: The model ID (e.g., "dinov3-large", "dinov3-small"). device: Device to run inference on. For HuggingFace/TIMM: "cuda" or "cpu". For OpenVINO: "CPU", "GPU", or "AUTO". backend: Which backend to use: Backend.HUGGINGFACE, Backend.TIMM, or Backend.OPENVINO. @@ -53,29 +51,21 @@ def load_image_encoder( Raises: ValueError: If backend is not valid. - FileNotFoundError: If OpenVINO model_path doesn't exist. 
Examples: - >>> # HuggingFace backend (DINOv2 models) - >>> encoder = load_image_encoder( - ... model_id="dinov2_large", - ... device="cuda", - ... backend=Backend.HUGGINGFACE - ... ) - >>> >>> # TIMM backend (DINOv3 models) >>> encoder = load_image_encoder( - ... model_id="dinov3_large", + ... model_id="dinov3-large", ... device="cuda", ... backend=Backend.TIMM ... ) >>> >>> # Export to OpenVINO - >>> ov_path = encoder.export(Path("./exported/dinov2_large")) + >>> ov_path = encoder.export(Path("./exported/dinov3-large")) >>> >>> # OpenVINO backend (optimized inference) >>> ov_encoder = load_image_encoder( - ... model_id="dinov2_large", + ... model_id="dinov3-large", ... device="CPU", ... backend=Backend.OPENVINO, ... model_path=ov_path @@ -114,7 +104,7 @@ class ImageEncoder(nn.Module): >>> import torch >>> >>> # Create encoder with TIMM backend - >>> encoder = ImageEncoder(model_id="dinov3_large", backend=Backend.TIMM) + >>> encoder = ImageEncoder(model_id="dinov3-large", backend=Backend.TIMM) >>> sample_image = torch.zeros((3, 518, 518)) >>> features = encoder([sample_image]) >>> features.shape @@ -125,7 +115,7 @@ class ImageEncoder(nn.Module): >>> >>> # Load with OpenVINO backend >>> ov_encoder = ImageEncoder( - ... model_id="dinov2_large", + ... model_id="dinov3-large", ... backend=Backend.OPENVINO, ... model_path=ov_path ... ) @@ -133,7 +123,7 @@ class ImageEncoder(nn.Module): def __init__( self, - model_id: str = "dinov3_large", + model_id: str = "dinov3-large", backend: Backend = Backend.TIMM, device: str = "cuda", precision: str = "bf16", @@ -144,9 +134,7 @@ def __init__( """Initialize the image encoder. Args: - model_id: The DINO model variant to use. Options: - - "dinov2_small", "dinov2_base", "dinov2_large", "dinov2_giant" (HuggingFace) - - "dinov3_small", "dinov3_small_plus", "dinov3_base", "dinov3_large", "dinov3_huge" (TIMM) + model_id: The model ID (e.g., "dinov3-large", "dinov3-small"). 
backend: Which backend to use: Backend.HUGGINGFACE, Backend.TIMM, or Backend.OPENVINO. device: Device to run inference on. For HuggingFace/TIMM: "cuda" or "cpu". For OpenVINO: "CPU", "GPU", or "AUTO". diff --git a/library/src/instantlearn/components/encoders/huggingface.py b/library/src/instantlearn/components/encoders/huggingface.py index f851df4e6..b5e8396dc 100644 --- a/library/src/instantlearn/components/encoders/huggingface.py +++ b/library/src/instantlearn/components/encoders/huggingface.py @@ -15,18 +15,6 @@ logger = getLogger("Geti Instant Learn") -AVAILABLE_IMAGE_ENCODERS = { - "dinov2_small": ("facebook/dinov2-with-registers-small", "0d9846e56b43a21fa46d7f3f5070f0506a5795a9"), - "dinov2_base": ("facebook/dinov2-with-registers-base", "a1d738ccfa7ae170945f210395d99dde8adb1805"), - "dinov2_large": ("facebook/dinov2-with-registers-large", "e4c89a4e05589de9b3e188688a303d0f3c04d0f3"), - "dinov2_giant": ("facebook/dinov2-with-registers-giant", "8d0d49f77fb8b5dd78842496ff14afe7dd4d85cb"), - "dinov3_small": ("facebook/dinov3-vits16-pretrain-lvd1689m", "114c1379950215c8b35dfcd4e90a5c251dde0d32"), - "dinov3_small_plus": ("facebook/dinov3-vits16plus-pretrain-lvd1689m", "c93d816fc9e567563bc068f01475bec89cc634a6"), - "dinov3_base": ("facebook/dinov3-vitb16-pretrain-lvd1689m", "5931719e67bbdb9737e363e781fb0c67687896bc"), - "dinov3_large": ("facebook/dinov3-vitl16-pretrain-lvd1689m", "ea8dc2863c51be0a264bab82070e3e8836b02d51"), - "dinov3_huge": ("facebook/dinov3-vith16plus-pretrain-lvd1689m", "c807c9eeea853df70aec4069e6f56b28ddc82acc"), -} - class HuggingFaceImageEncoder(nn.Module): """HuggingFace backend for DINO image encoder. 
@@ -41,7 +29,7 @@ class HuggingFaceImageEncoder(nn.Module): >>> >>> # Create a sample image >>> sample_image = torch.zeros((3, 518, 518)) - >>> encoder = HuggingFaceImageEncoder(model_id="dinov2_large") + >>> encoder = HuggingFaceImageEncoder(model_id="dinov2-large") >>> features = encoder(images=[sample_image]) >>> features.shape torch.Size([1369, 1024]) @@ -49,7 +37,7 @@ def __init__( self, - model_id: str = "dinov3_large", + model_id: str = "dinov2-large", device: str = "cuda", precision: str = "bf16", compile_models: bool = False, @@ -58,28 +46,39 @@ def __init__( """Initialize the encoder. Args: - model_id: The model id to use. + model_id: The model ID (e.g., "dinov2-large"). device: The device to use. precision: The precision to use. compile_models: Whether to compile the models. input_size: The input size to use. Raises: - ValueError: If the model ID is invalid. + ValueError: If the model ID is invalid or not an encoder type. """ from instantlearn.utils.optimization import optimize_model + from instantlearn.registry import ModelType, get_model, get_models_by_type super().__init__() - if model_id not in AVAILABLE_IMAGE_ENCODERS: - msg = f"Invalid model ID: {model_id}. Valid model IDs: {list(AVAILABLE_IMAGE_ENCODERS.keys())}" + # Validate model exists in registry + model_meta = get_model(model_id) + if model_meta is None: + valid = [m.id for m in get_models_by_type(ModelType.ENCODER)] + msg = f"Invalid model ID: '{model_id}'. 
Valid encoders: {valid}" + raise ValueError(msg) + if model_meta.type != ModelType.ENCODER: + msg = f"Model '{model_id}' is not an encoder (type: {model_meta.type})" + raise ValueError(msg) + if model_meta.hf_model_id is None: + msg = f"Model '{model_id}' has no hf_model_id configured" + raise ValueError(msg) raise ValueError(msg) self.model_id = model_id self.input_size = input_size self.device = device - hf_model_id, revision = AVAILABLE_IMAGE_ENCODERS[model_id] + hf_model_id = model_meta.hf_model_id + revision = model_meta.hf_revision msg = f"Loading DINO model {hf_model_id} with revision {revision}" logger.info(msg) diff --git a/library/src/instantlearn/components/linear_sum_assignment.py b/library/src/instantlearn/components/linear_sum_assignment.py index 01af7d070..2a7aec231 100644 --- a/library/src/instantlearn/components/linear_sum_assignment.py +++ b/library/src/instantlearn/components/linear_sum_assignment.py @@ -37,7 +37,7 @@ class LinearSumAssignment(nn.Module): - "auto" (recommended): Uses fast scipy during normal execution, automatically switches to greedy during ONNX/TorchScript export. Best of both worlds - fast dev, exportable deployment. - - "greedy": Fast O(n² × min(n,m)) approximation, ~95-100% optimal. + - "greedy": Fast O(n² x min(n,m)) approximation, ~95-100% optimal. Always exportable. Achieves 99%+ for rectangular matrices. Default: "auto". diff --git a/library/src/instantlearn/components/sam/predictor.py b/library/src/instantlearn/components/sam/predictor.py index 2d57144c0..c0db3b69b 100644 --- a/library/src/instantlearn/components/sam/predictor.py +++ b/library/src/instantlearn/components/sam/predictor.py @@ -5,6 +5,8 @@ from logging import getLogger from pathlib import Path +from typing import ClassVar +from instantlearn.registry import ModelType, get_local_filename, get_model, get_models_by_type import numpy as np import torch @@ -87,35 +88,32 @@ def load_sam_model( return predictor -def check_model_weights(model_name: SAMModelName) -> None: +def check_model_weights(model_id: str) -> None: """Check if model weights exist locally, download if necessary. 
Args: - model_name: The name of the model. + model_id: The model ID (e.g., "sam-hq-tiny", "sam2-base"). Raises: - ValueError: If the model is not found in MODEL_MAP. - ValueError: If the model weights are missing. + ValueError: If the model is not found in registry or has no weights URL. """ - if model_name not in MODEL_MAP: - msg = f"Model '{model_name.value}' not found in MODEL_MAP for weight checking." + model = get_model(model_id) + if model is None: + valid = [m.id for m in get_models_by_type(ModelType.SEGMENTER)] + msg = f"Model '{model_id}' not found. Valid segmenters: {valid}" raise ValueError(msg) - model_info = MODEL_MAP[model_name] - local_filename = model_info["local_filename"] - download_url = model_info["download_url"] - sha_sum = model_info["sha_sum"] - - if not local_filename or not download_url: - msg = f"Missing 'local_filename' or 'download_url' for {model_name.value} in MODEL_MAP." + if model.weights_url is None: + msg = f"Model '{model_id}' has no weights_url configured." raise ValueError(msg) + local_filename = get_local_filename(model_id) target_path = DATA_PATH.joinpath(local_filename) if not target_path.exists(): - msg = f"Model weights for {model_name.value} not found at {target_path}, downloading..." + msg = f"Model weights for {model_id} not found at {target_path}, downloading..." 
logger.info(msg) - download_file(download_url, target_path, sha_sum) + download_file(model.weights_url, target_path, model.sha_sum) class PositionEmbeddingRandom(_PositionEmbeddingRandom): @@ -136,7 +134,7 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: # Convert coords to match the gaussian matrix dtype coords = coords.to(self.positional_encoding_gaussian_matrix.dtype) coords = 2 * coords - 1 - coords = coords @ self.positional_encoding_gaussian_matrix + coords @= self.positional_encoding_gaussian_matrix coords = 2 * np.pi * coords return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) @@ -313,9 +311,15 @@ class SAMPredictor(nn.Module): "no mask input". The prompt encoder detects these and uses the default no_mask_embed. """ + # SAM-HQ registry name mapping (model_id -> segment_anything_hq registry key) + SAM_HQ_REGISTRY_MAP: ClassVar[dict[str, str]] = { + "sam-hq": "vit_h", + "sam-hq-tiny": "vit_tiny", + } + def __init__( self, - sam_model_name: SAMModelName, + model_id: str, device: str, model_path: Path | None = None, target_length: int = 1024, @@ -329,7 +333,7 @@ def __init__( target_length: Target length for the longest side of the image during transformation. Defaults to 1024. Raises: - NotImplementedError: If the model type is not supported. + ValueError: If the model ID is not found in registry. 
""" super().__init__() self.device = device @@ -339,28 +343,25 @@ def __init__( # Determine checkpoint path if model_path is None: - check_model_weights(sam_model_name) - model_info = MODEL_MAP[sam_model_name] - checkpoint_path = DATA_PATH.joinpath(model_info["local_filename"]) + check_model_weights(model_id) + local_filename = get_local_filename(model_id) + checkpoint_path = DATA_PATH.joinpath(local_filename) else: checkpoint_path = model_path - msg = f"Loading PyTorch SAM: {sam_model_name} from {checkpoint_path}" + msg = f"Loading PyTorch SAM: {model_id} from {checkpoint_path}" logger.info(msg) - # Load model based on type - if sam_model_name in { - SAMModelName.SAM2_TINY, - SAMModelName.SAM2_SMALL, - SAMModelName.SAM2_BASE, - SAMModelName.SAM2_LARGE, - }: - model_info = MODEL_MAP[sam_model_name] - config_path = "configs/sam2.1/" + model_info["config_filename"] + # Load model based on family + if model_meta.family == "SAM2": + config_path = "configs/sam2.1/" + model_meta.config_filename sam_model = build_sam2(config_path, str(checkpoint_path)) self._predictor = SAM2ImagePredictor(sam_model) - elif sam_model_name in {SAMModelName.SAM_HQ, SAMModelName.SAM_HQ_TINY}: - registry_name = MODEL_MAP[sam_model_name]["registry_name"] + elif model_meta.family == "SAM-HQ": + registry_name = self.SAM_HQ_REGISTRY_MAP.get(model_id) + if registry_name is None: + msg = f"SAM-HQ model '{model_id}' not in SAM_HQ_REGISTRY_MAP" + raise ValueError(msg) sam_model = sam_model_registry[registry_name]().to(device) # suppress - loading the snapshot from the local path # nosemgrep trailofbits.python.pickles-in-pytorch.pickles-in-pytorch @@ -383,7 +384,7 @@ def __init__( self._predictor.model.image_encoder, ]) else: - msg = f"Model {sam_model_name} not implemented" + msg = f"Model family '{model_meta.family}' not implemented" raise NotImplementedError(msg) def _patch_prompt_encoder(self, device: str) -> None: diff --git a/library/src/instantlearn/models/grounded_sam.py 
b/library/src/instantlearn/models/grounded_sam.py index bb3b425e0..407f7ba64 100644 --- a/library/src/instantlearn/models/grounded_sam.py +++ b/library/src/instantlearn/models/grounded_sam.py @@ -18,9 +18,11 @@ class GroundedSAM(Model): """This model uses a zero-shot object detector (from Huggingface) to generate boxes for SAM.""" + DEFAULT_SAM = "sam-hq-tiny" + def __init__( self, - sam: SAMModelName = SAMModelName.SAM_HQ_TINY, + sam: str = DEFAULT_SAM, grounding_model: GroundingModel = GroundingModel.LLMDET_TINY, precision: str = "bf16", compile_models: bool = False, @@ -32,7 +34,7 @@ def __init__( """Initialize the model. Args: - sam: The SAM model name. + sam: Model ID for the segmenter (e.g., "sam-hq-tiny", "sam-hq"). grounding_model: The grounding model to use. precision: The precision to use for the model. compile_models: Whether to compile the models. diff --git a/library/src/instantlearn/models/soft_matcher.py b/library/src/instantlearn/models/soft_matcher.py index 4c7b0485e..57b058b42 100644 --- a/library/src/instantlearn/models/soft_matcher.py +++ b/library/src/instantlearn/models/soft_matcher.py @@ -70,7 +70,7 @@ class SoftMatcher(Matcher): def __init__( self, - sam: SAMModelName = SAMModelName.SAM_HQ_TINY, + sam: str = Matcher.DEFAULT_SAM, num_foreground_points: int = 40, num_background_points: int = 2, confidence_threshold: float | None = 0.42, @@ -88,7 +88,7 @@ def __init__( """Initialize the SoftMatcher model. Args: - sam: The name of the SAM model to use. + sam: Model ID for the segmenter (e.g., "sam-hq-tiny", "sam-hq"). num_foreground_points: The number of foreground points to use. num_background_points: The number of background points to use. confidence_threshold: Minimum confidence score for keeping predicted masks @@ -98,7 +98,7 @@ def __init__( approximate_matching: Whether to use approximate matching. softmatching_score_threshold: The score threshold for the soft matching. 
softmatching_bidirectional: Whether to use bidirectional soft matching. - encoder_model: The encoder model to use. + encoder_model: Model ID for the encoder (e.g., "dinov3-large"). precision: The precision to use for the model. compile_models: Whether to compile the models. device: The device to use for the model. diff --git a/library/src/instantlearn/registry.py b/library/src/instantlearn/registry.py new file mode 100644 index 000000000..948c47096 --- /dev/null +++ b/library/src/instantlearn/registry.py @@ -0,0 +1,334 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Centralized model registry - single source of truth for available models.""" + +from dataclasses import dataclass +from enum import StrEnum + + +class ModelType(StrEnum): + """Model type classification.""" + + ENCODER = "encoder" + SEGMENTER = "segmenter" + TRACKER = "tracker" + + +class Modality(StrEnum): + """Input/output modalities.""" + + IMAGE = "image" + VIDEO = "video" + TEXT = "text" + + +class PromptType(StrEnum): + """Supported prompt types.""" + + POINT = "point" + BOX = "box" + MASK = "mask" + TEXT = "text" + IMAGE = "image" + + +class Capability(StrEnum): + """Model capabilities.""" + + ENCODING = "encoding" + SEGMENTATION = "segmentation" + TRACKING = "tracking" + DESCRIPTION = "description" + + +@dataclass(frozen=True) +class ModelMetadata: + """Metadata describing a single model in the registry.""" + + id: str + type: ModelType + family: str + size: str + modalities: tuple[Modality, ...] + prompts: tuple[PromptType, ...] + capabilities: tuple[Capability, ...] 
+ # Internal details (not exposed to API by default) + weights_url: str | None = None + hf_model_id: str | None = None + hf_revision: str | None = None # HuggingFace model revision/commit hash + config_filename: str | None = None + sha_sum: str | None = None # SHA256 checksum for download verification + + +# ============================================================================= +# MODEL REGISTRY - SINGLE SOURCE OF TRUTH +# ============================================================================= + +MODEL_REGISTRY: tuple[ModelMetadata, ...] = ( + # ------------------------------------------------------------------------- + # Segmenters (SAM family) + # ------------------------------------------------------------------------- + ModelMetadata( + id="sam-hq", + type=ModelType.SEGMENTER, + family="SAM-HQ", + size="base", + modalities=(Modality.IMAGE,), + prompts=(PromptType.POINT, PromptType.BOX, PromptType.MASK), + capabilities=(Capability.SEGMENTATION,), + weights_url="https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", + sha_sum="a7ac14a085326d9fa6199c8c698c4f0e7280afdbb974d2c4660ec60877b45e35", + ), + ModelMetadata( + id="sam-hq-tiny", + type=ModelType.SEGMENTER, + family="SAM-HQ", + size="tiny", + modalities=(Modality.IMAGE,), + prompts=(PromptType.POINT, PromptType.BOX, PromptType.MASK), + capabilities=(Capability.SEGMENTATION,), + weights_url="https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth", + sha_sum="0f32c075ccdd870ae54db2f7630e7a0878ede5a2b06d05d6fe02c65a82fb7196", + ), + # SAM2 family + ModelMetadata( + id="sam2-tiny", + type=ModelType.SEGMENTER, + family="SAM2", + size="tiny", + modalities=(Modality.IMAGE,), + prompts=(PromptType.POINT, PromptType.BOX, PromptType.MASK), + capabilities=(Capability.SEGMENTATION,), + weights_url="https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + config_filename="sam2.1_hiera_t.yaml", + 
sha_sum="7402e0d864fa82708a20fbd15bc84245c2f26dff0eb43a4b5b93452deb34be69", + ), + ModelMetadata( + id="sam2-small", + type=ModelType.SEGMENTER, + family="SAM2", + size="small", + modalities=(Modality.IMAGE,), + prompts=(PromptType.POINT, PromptType.BOX, PromptType.MASK), + capabilities=(Capability.SEGMENTATION,), + weights_url="https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + config_filename="sam2.1_hiera_s.yaml", + sha_sum="6d1aa6f30de5c92224f8172114de081d104bbd23dd9dc5c58996f0cad5dc4d38", + ), + ModelMetadata( + id="sam2-base", + type=ModelType.SEGMENTER, + family="SAM2", + size="base", + modalities=(Modality.IMAGE,), + prompts=(PromptType.POINT, PromptType.BOX, PromptType.MASK), + capabilities=(Capability.SEGMENTATION,), + weights_url="https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + config_filename="sam2.1_hiera_b+.yaml", + sha_sum="a2345aede8715ab1d5d31b4a509fb160c5a4af1970f199d9054ccfb746c004c5", + ), + ModelMetadata( + id="sam2-large", + type=ModelType.SEGMENTER, + family="SAM2", + size="large", + modalities=(Modality.IMAGE,), + prompts=(PromptType.POINT, PromptType.BOX, PromptType.MASK), + capabilities=(Capability.SEGMENTATION,), + weights_url="https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + config_filename="sam2.1_hiera_l.yaml", + sha_sum="2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318", + ), + # ------------------------------------------------------------------------- + # Encoders (DINOv2 family - HuggingFace) + # ------------------------------------------------------------------------- + ModelMetadata( + id="dinov2-small", + type=ModelType.ENCODER, + family="DINOv2", + size="small", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="facebook/dinov2-with-registers-small", + hf_revision="0d9846e56b43a21fa46d7f3f5070f0506a5795a9", + ), + ModelMetadata( + id="dinov2-base", + 
type=ModelType.ENCODER, + family="DINOv2", + size="base", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="facebook/dinov2-with-registers-base", + hf_revision="a1d738ccfa7ae170945f210395d99dde8adb1805", + ), + ModelMetadata( + id="dinov2-large", + type=ModelType.ENCODER, + family="DINOv2", + size="large", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="facebook/dinov2-with-registers-large", + hf_revision="e4c89a4e05589de9b3e188688a303d0f3c04d0f3", + ), + ModelMetadata( + id="dinov2-giant", + type=ModelType.ENCODER, + family="DINOv2", + size="giant", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="facebook/dinov2-with-registers-giant", + hf_revision="8d0d49f77fb8b5dd78842496ff14afe7dd4d85cb", + ), + # ------------------------------------------------------------------------- + # Encoders (DINOv3 family - TIMM) + # ------------------------------------------------------------------------- + ModelMetadata( + id="dinov3-small", + type=ModelType.ENCODER, + family="DINOv3", + size="small", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="timm/vit_small_patch16_dinov3.lvd1689m", + ), + ModelMetadata( + id="dinov3-small-plus", + type=ModelType.ENCODER, + family="DINOv3", + size="small-plus", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="timm/vit_small_plus_patch16_dinov3.lvd1689m", + ), + ModelMetadata( + id="dinov3-base", + type=ModelType.ENCODER, + family="DINOv3", + size="base", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="timm/vit_base_patch16_dinov3.lvd1689m", + ), + ModelMetadata( + id="dinov3-large", + type=ModelType.ENCODER, + family="DINOv3", + size="large", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + 
hf_model_id="timm/vit_large_patch16_dinov3.lvd1689m", + ), + ModelMetadata( + id="dinov3-huge", + type=ModelType.ENCODER, + family="DINOv3", + size="huge", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + hf_model_id="timm/vit_huge_plus_patch16_dinov3.lvd1689m", + ), + # ------------------------------------------------------------------------- + # Future: Trackers (SAM3 family) + # ------------------------------------------------------------------------- + # ModelMetadata( + # id="sam3-large", + # type=ModelType.TRACKER, + # family="SAM3", + # size="large", + # modalities=(Modality.IMAGE, Modality.VIDEO), + # prompts=(PromptType.POINT, PromptType.BOX, PromptType.MASK), + # capabilities=(Capability.SEGMENTATION, Capability.TRACKING), + # ), +) + + +# ============================================================================= +# HELPER FUNCTIONS +# ============================================================================= + + +def get_model(model_id: str) -> ModelMetadata | None: + """Get a model by ID.""" + return next((m for m in MODEL_REGISTRY if m.id == model_id), None) + + +def get_models_by_type(model_type: ModelType) -> list[ModelMetadata]: + """Get all models of a specific type.""" + return [m for m in MODEL_REGISTRY if m.type == model_type] + + +def get_models_by_capability(capability: Capability) -> list[ModelMetadata]: + """Get all models with a specific capability.""" + return [m for m in MODEL_REGISTRY if capability in m.capabilities] + + +def get_models_by_family(family: str) -> list[ModelMetadata]: + """Get all models from a specific family.""" + return [m for m in MODEL_REGISTRY if m.family == family] + + +def get_available_types() -> list[str]: + """Get list of available model types.""" + return list({m.type.value for m in MODEL_REGISTRY}) + + +def get_available_families() -> list[str]: + """Get list of available model families.""" + return list({m.family for m in MODEL_REGISTRY}) + + +def 
get_available_sizes() -> list[str]: + """Get list of available model sizes.""" + return list({m.size for m in MODEL_REGISTRY}) + + +def get_available_capabilities() -> list[str]: + """Get list of available capabilities.""" + return list({cap.value for m in MODEL_REGISTRY for cap in m.capabilities}) + + +def is_valid_model(model_id: str, model_type: ModelType | None = None) -> bool: + """Check if a model ID is valid, optionally filtered by type.""" + model = get_model(model_id) + if model is None: + return False + if model_type is not None and model.type != model_type: + return False + return True + + +def get_local_filename(model_id: str) -> str: + """Get the local filename for a model's weights. + + Derives filename from the weights URL. + + Args: + model_id: The model ID. + + Returns: + The local filename for storing model weights. + + Raises: + ValueError: If model not found or has no weights_url. + """ + model = get_model(model_id) + if model is None: + msg = f"Model '{model_id}' not found in registry" + raise ValueError(msg) + if model.weights_url is None: + msg = f"Model '{model_id}' has no weights_url" + raise ValueError(msg) + return model.weights_url.split("/")[-1] diff --git a/library/src/instantlearn/scripts/benchmark.py b/library/src/instantlearn/scripts/benchmark.py index 0167e072e..cdf4951a3 100644 --- a/library/src/instantlearn/scripts/benchmark.py +++ b/library/src/instantlearn/scripts/benchmark.py @@ -150,7 +150,7 @@ def predict_on_dataset( output_path: Path, dataset_name: str, model_name: str, - backbone_name: str, + sam_model_id: str, number_of_priors_tests: int, device: torch.device, ) -> pl.DataFrame: @@ -163,7 +163,7 @@ def predict_on_dataset( output_path: Output path dataset_name: The dataset name model_name: The algorithm name - backbone_name: The model name + sam_model_id: The SAM model ID (e.g., "sam-hq-tiny") number_of_priors_tests: The number of priors to try device: The device to use. 
@@ -260,7 +260,7 @@ def predict_on_dataset( metrics["inference_time"] = [time_sum / time_count if time_count > 0 else 0] * ln metrics["dataset_name"] = [dataset_name] * ln metrics["model_name"] = [model_name] * ln - metrics["backbone_name"] = [backbone_name] * ln + metrics["sam_model_id"] = [sam_model_id] * ln if all_metrics is None: all_metrics = metrics @@ -400,16 +400,13 @@ def perform_benchmark_experiment(args: Namespace | None = None) -> None: final_results_path.mkdir(parents=True, exist_ok=True) # Get experiment lists and generate a plan - datasets_to_run, models_to_run, backbones_to_run = parse_experiment_args(args) - experiments = list(itertools.product(datasets_to_run, models_to_run, backbones_to_run)) + datasets_to_run, models_to_run, sam_models_to_run = parse_experiment_args(args) + experiments = list(itertools.product(datasets_to_run, models_to_run, sam_models_to_run)) # Execute experiments all_results = [] - for dataset_enum, model_enum, backbone_enum in experiments: - msg = ( - f"Starting experiment with Dataset={dataset_enum.value}, " - f"Model={model_enum.value}, Backbone={backbone_enum.value}", - ) + for dataset_enum, model_enum, sam_model in experiments: + msg = (f"Starting experiment with Dataset={dataset_enum.value}, Model={model_enum.value}, SAM={sam_model}",) logger.info(msg) # Parse categories from CLI argument @@ -439,7 +436,7 @@ def perform_benchmark_experiment(args: Namespace | None = None) -> None: args.experiment_name, dataset_enum, model_enum, - backbone_enum, + sam_model, ) all_metrics_df = predict_on_dataset( @@ -449,7 +446,7 @@ def perform_benchmark_experiment(args: Namespace | None = None) -> None: output_path=output_path, dataset_name=dataset_enum.value, model_name=model_enum.value, - backbone_name=backbone_enum.value, + sam_model_id=sam_model, number_of_priors_tests=args.num_priors, device=args.device, ) diff --git a/library/src/instantlearn/scripts/run.py b/library/src/instantlearn/scripts/run.py index aaf30d79b..fe36f64f8 100644 
--- a/library/src/instantlearn/scripts/run.py +++ b/library/src/instantlearn/scripts/run.py @@ -65,7 +65,6 @@ def run_model( Raises: ValueError: If the dataset is not found or invalid, or if required parameters are missing. - FileNotFoundError: If the dataset is not found. """ # Check if model is a Grounding model is_grounding_model = isinstance(model, GroundedSAM) diff --git a/library/src/instantlearn/utils/args.py b/library/src/instantlearn/utils/args.py index d91b99da2..083c0fbb4 100644 --- a/library/src/instantlearn/utils/args.py +++ b/library/src/instantlearn/utils/args.py @@ -11,14 +11,15 @@ from instantlearn.components.prompt_generators import GroundingModel -from instantlearn.utils.constants import DatasetName, ModelName, SAMModelName +from instantlearn.registry import ModelType, get_models_by_type +from instantlearn.utils.constants import DatasetName, ModelName -# Generate help strings with choices -AVAILABLE_SAM_MODELS = ", ".join([model.value for model in SAMModelName]) +# Generate help strings with choices from registry +AVAILABLE_SAM_MODELS = ", ".join([m.id for m in get_models_by_type(ModelType.SEGMENTER)]) +AVAILABLE_ENCODER_MODELS = ", ".join([m.id for m in get_models_by_type(ModelType.ENCODER)]) AVAILABLE_MODELS = ", ".join([p.value for p in ModelName]) AVAILABLE_DATASETS = ", ".join([d.value for d in DatasetName]) HELP_SAM_ARG_MSG = ( - f"Backbone segmentation model name or " - f"comma-separated list. Use 'all' to run all. Available: [{AVAILABLE_SAM_MODELS}]" + f"Backbone segmentation model ID or comma-separated list. Use 'all' to run all. 
Available: [{AVAILABLE_SAM_MODELS}]" ) @@ -32,8 +32,8 @@ def populate_benchmark_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--sam", type=str, - default="SAM-HQ-tiny", - choices=["all"] + [model.value for model in SAMModelName], + default="sam-hq-tiny", + choices=["all"] + [m.id for m in get_models_by_type(ModelType.SEGMENTER)], help=HELP_SAM_ARG_MSG, ) parser.add_argument( @@ -175,9 +175,9 @@ def populate_benchmark_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--encoder_model", type=str, - default="dinov3_large", - choices=list(AVAILABLE_IMAGE_ENCODERS), - help="ImageEncoder model id", + default="dinov3-large", + choices=[m.id for m in get_models_by_type(ModelType.ENCODER)], + help="ImageEncoder model ID", ) parser.add_argument( "--grounding_model", @@ -271,7 +271,34 @@ def _parse_enum_list(arg_str: str, enum_cls: type[TEnum], arg_name: str) -> list return [enum_cls(item) for item in items_to_run] -def parse_experiment_args(args: argparse.Namespace) -> tuple[list[DatasetName], list[ModelName], list[SAMModelName]]: +def _parse_sam_model_list(arg_str: str) -> list[str]: + """Parse a comma-separated string of SAM model IDs. + + Args: + arg_str: Comma-separated string of model IDs or "all" + + Returns: + List of valid model IDs + + Raises: + ValueError: If any provided value is not a valid model ID + """ + valid_ids = {m.id for m in get_models_by_type(ModelType.SEGMENTER)} + + if arg_str == "all": + return list(valid_ids) + + items_to_run = [p.strip() for p in arg_str.split(",")] + + invalid_items = [item for item in items_to_run if item not in valid_ids] + if invalid_items: + msg = f"Invalid SAM model(s): {invalid_items}. Available: {sorted(valid_ids)}" + raise ValueError(msg) + + return items_to_run + + +def parse_experiment_args(args: argparse.Namespace) -> tuple[list[DatasetName], list[ModelName], list[str]]: """Parse experiment arguments. 
Args: @@ -281,14 +308,14 @@ def parse_experiment_args(args: argparse.Namespace) -> tuple[list[DatasetName], tuple containing: - datasets_to_run: List of dataset enums to run - models_to_run: List of model enums to run - - backbones_to_run: List of SAM model enums to run + - backbones_to_run: List of SAM model IDs (strings) to run Raises: ValueError: If any invalid arguments are provided or if no valid arguments remain after filtering """ valid_datasets = _parse_enum_list(args.dataset_name, DatasetName, "dataset") valid_models = _parse_enum_list(args.model, ModelName, "model") - valid_backbones = _parse_enum_list(args.sam, SAMModelName, "SAM model") + valid_backbones = _parse_sam_model_list(args.sam) if not valid_datasets: msg = f"No valid datasets found from '{args.dataset_name}'. Available: {[d.value for d in DatasetName]}" @@ -297,7 +324,8 @@ def parse_experiment_args(args: argparse.Namespace) -> tuple[list[DatasetName], msg = f"No valid models found from '{args.model}'. Available: {[m.value for m in ModelName]}" raise ValueError(msg) if not valid_backbones: - msg = f"No valid SAM models found from '{args.sam}'. Available: {[m.value for m in SAMModelName]}" + valid_ids = [m.id for m in get_models_by_type(ModelType.SEGMENTER)] + msg = f"No valid SAM models found from '{args.sam}'. Available: {valid_ids}" raise ValueError(msg) return valid_datasets, valid_models, valid_backbones diff --git a/library/src/instantlearn/utils/benchmark.py b/library/src/instantlearn/utils/benchmark.py index d31962794..6b053354a 100644 --- a/library/src/instantlearn/utils/benchmark.py +++ b/library/src/instantlearn/utils/benchmark.py @@ -52,7 +52,7 @@ def _get_output_path_for_experiment( experiment_name: str | None, dataset: DatasetName, model: ModelName, - backbone: SAMModelName, + sam_model_id: str, ) -> Path: """Construct a unique output path for an experiment. 
@@ -61,12 +61,12 @@ def _get_output_path_for_experiment( experiment_name: The name of the experiment dataset: The dataset to run model: The model to run - backbone: The backbone to run + sam_model_id: The SAM model ID (e.g., "sam-hq-tiny") Returns: The path to save the results """ - combo_str = f"{dataset.value}_{backbone.value}_{model.value}" + combo_str = f"{dataset.value}_{sam_model_id}_{model.value}" if experiment_name: return output_path / experiment_name / combo_str @@ -95,7 +95,7 @@ def _save_results(all_results: list[pl.DataFrame], output_path: Path) -> None: avg_results_dataframe_filename = output_path / "avg_results.csv" avg_results_dataframe_filename.parent.mkdir(parents=True, exist_ok=True) avg_result_dataframe = all_result_dataframe.group_by( - ["dataset_name", "model_name", "backbone_name"], + ["dataset_name", "model_name", "sam_model_id"], ).mean() avg_result_dataframe.write_csv(str(avg_results_dataframe_filename)) msg = f"Saved average results to: {avg_results_dataframe_filename}" @@ -160,11 +160,11 @@ def convert_masks_to_one_hot_tensor( return batch_pred_tensors, batch_gt_tensors -def load_model(sam: SAMModelName, model_name: ModelName, args: Namespace) -> Model: +def load_model(sam: str, model_name: ModelName, args: Namespace) -> Model: """Instantiate and return the requested model. Args: - sam: The name of the SAM model. + sam: The SAM model ID (e.g., "sam-hq-tiny"). model_name: The name of the model. args: The arguments to the model. 
diff --git a/library/src/instantlearn/utils/constants.py b/library/src/instantlearn/utils/constants.py index 24d554261..c39effcd1 100644 --- a/library/src/instantlearn/utils/constants.py +++ b/library/src/instantlearn/utils/constants.py @@ -19,17 +19,6 @@ class Backend(Enum): TIMM = "timm" -class SAMModelName(StrEnum): - """Enum for SAM model types.""" - - SAM_HQ = "SAM-HQ" - SAM_HQ_TINY = "SAM-HQ-tiny" - SAM2_TINY = "SAM2-tiny" - SAM2_SMALL = "SAM2-small" - SAM2_BASE = "SAM2-base" - SAM2_LARGE = "SAM2-large" - - class ModelName(Enum): """Enum for model types.""" @@ -70,49 +59,6 @@ class DINOv3BackboneSize(Enum): DINOv3BackboneSize.HUGE.value: "dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth", } -MODEL_MAP = { - SAMModelName.SAM2_TINY: { # 1024x1024 input resolution - "registry_name": "vit_t", - "local_filename": "sam2.1_hiera_tiny.pt", - "config_filename": "sam2.1_hiera_t.yaml", - "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", - "sha_sum": "7402e0d864fa82708a20fbd15bc84245c2f26dff0eb43a4b5b93452deb34be69", - }, - SAMModelName.SAM2_SMALL: { # 1024x1024 input resolution - "registry_name": "vit_s", - "local_filename": "sam2.1_hiera_small.pt", - "config_filename": "sam2.1_hiera_s.yaml", - "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", - "sha_sum": "6d1aa6f30de5c92224f8172114de081d104bbd23dd9dc5c58996f0cad5dc4d38", - }, - SAMModelName.SAM2_BASE: { # 1024x1024 input resolution - "registry_name": "vit_b", - "local_filename": "sam2.1_hiera_base_plus.pt", - "config_filename": "sam2.1_hiera_b+.yaml", - "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", - "sha_sum": "a2345aede8715ab1d5d31b4a509fb160c5a4af1970f199d9054ccfb746c004c5", - }, - SAMModelName.SAM2_LARGE: { # 1024x1024 input resolution - "registry_name": "vit_l", - "local_filename": "sam2.1_hiera_large.pt", - "config_filename": "sam2.1_hiera_l.yaml", - 
"download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", - "sha_sum": "2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318", - }, - SAMModelName.SAM_HQ: { # 1024x1024 input resolution - "registry_name": "vit_h", - "local_filename": "sam_hq_vit_h.pth", - "download_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", - "sha_sum": "a7ac14a085326d9fa6199c8c698c4f0e7280afdbb974d2c4660ec60877b45e35", - }, - SAMModelName.SAM_HQ_TINY: { # 1024x1024 input resolution - "registry_name": "vit_tiny", - "local_filename": "sam_hq_vit_tiny.pth", - "download_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth", - "sha_sum": "0f32c075ccdd870ae54db2f7630e7a0878ede5a2b06d05d6fe02c65a82fb7196", - }, -} - IMAGE_EXTENSIONS = ("*.jpg", "*.jpeg", "*.png", "*.webp") diff --git a/library/src/instantlearn/utils/utils.py b/library/src/instantlearn/utils/utils.py index 33e31b8a7..5e807521c 100644 --- a/library/src/instantlearn/utils/utils.py +++ b/library/src/instantlearn/utils/utils.py @@ -101,7 +101,6 @@ def precision_to_openvino_type(precision: str) -> ov.Type: OpenVINO Type (ov.Type.f32, ov.Type.f16, etc.) 
Raises: - ImportError: If openvino is not installed ValueError: If precision is not supported """ precision_map = { diff --git a/library/tests/integration/models/test_model_integration.py b/library/tests/integration/models/test_model_integration.py index 088109232..dc20ca9fc 100644 --- a/library/tests/integration/models/test_model_integration.py +++ b/library/tests/integration/models/test_model_integration.py @@ -69,8 +69,8 @@ def aerial_maritime_root() -> Path: ModelName.SOFT_MATCHER: SoftMatcher, } -# SAM models to test (SAM3 doesn't use SAM backend, will be handled separately) -SAM_MODELS = [SAMModelName.SAM_HQ_TINY, SAMModelName.SAM2_TINY] +# SAM models to test (string model IDs from registry) +SAM_MODELS = ["sam-hq-tiny", "sam2-tiny"] # Models that support n-shots (all except GroundedSAM and SAM3) N_SHOT_SUPPORTED_MODELS = [ModelName.MATCHER, ModelName.PER_DINO, ModelName.SOFT_MATCHER] @@ -83,13 +83,13 @@ class TestModelIntegration: @pytest.mark.parametrize("model_name", ModelName) def test_model_initialization( self, - sam_model: SAMModelName, + sam_model: str, model_name: ModelName, ) -> None: """Test that models can be initialized with different SAM backends. Args: - sam_model: The SAM model to use. + sam_model: The SAM model ID (e.g., "sam-hq-tiny"). model_name: The model type to test. """ # Skip SAM3 as it doesn't use SAM backend @@ -97,8 +97,8 @@ def test_model_initialization( pytest.skip("SAM3 doesn't use SAM backend, tested separately") # TODO(Eugene): SAM2 is currently not supported due to a bug in the SAM2 model. 
- # https://github.com/open-edge-platform/instant-learn/issues/367 - if sam_model == SAMModelName.SAM2_TINY: + # https://github.com/open-edge-platform/instant.learn/issues/367 + if sam_model == "sam2-tiny": pytest.skip("Skipping test_model_initialization for SAM2-tiny") model_class = MODEL_CLASSES[model_name] @@ -107,7 +107,7 @@ def test_model_initialization( if model_name == ModelName.GROUNDED_SAM: model = model_class(sam=sam_model, device="cpu", precision="fp32") else: - model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3_small") + model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3-small") assert model is not None assert hasattr(model, "fit") @@ -119,7 +119,7 @@ def test_model_initialization( @pytest.mark.parametrize("model_name", ModelName) def test_model_fit_predict( self, - sam_model: SAMModelName, + sam_model: str, model_name: ModelName, reference_batch: Batch, target_batch: Batch, @@ -127,7 +127,7 @@ def test_model_fit_predict( """Test that models can learn from reference data and infer on target data. Args: - sam_model: The SAM model to use. + sam_model: The SAM model ID (e.g., "sam-hq-tiny"). model_name: The model type to test. reference_batch: Batch of reference samples. target_batch: Batch of target samples. @@ -137,8 +137,8 @@ def test_model_fit_predict( pytest.skip("SAM3 doesn't use SAM backend, tested separately") # TODO(Eugene): SAM2 is currently not supported due to a bug in the SAM2 model. 
- # https://github.com/open-edge-platform/instant-learn/issues/367 - if sam_model == SAMModelName.SAM2_TINY: + # https://github.com/open-edge-platform/instant.learn/issues/367 + if sam_model == "sam2-tiny": pytest.skip("Skipping test_model_learn_infer for SAM2-tiny") model_class = MODEL_CLASSES[model_name] @@ -147,7 +147,7 @@ def test_model_fit_predict( if model_name == ModelName.GROUNDED_SAM: model = model_class(sam=sam_model, device="cpu", precision="fp32") else: - model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3_small") + model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3-small") # Test fit method model.fit(reference_batch) @@ -169,7 +169,7 @@ def test_model_fit_predict( @pytest.mark.parametrize("model_name", N_SHOT_SUPPORTED_MODELS) def test_n_shots_capability( self, - sam_model: SAMModelName, + sam_model: str, model_name: ModelName, fss1000_root: Path, ) -> None: @@ -179,14 +179,14 @@ def test_n_shots_capability( (n-shots > 1) and that the number of reference samples affects the results. Args: - sam_model: The SAM model to use. + sam_model: The SAM model ID (e.g., "sam-hq-tiny"). model_name: The model type to test (must support n-shots). fss1000_root: Path to fss-1000 dataset. """ # TODO(Eugene): SAM2 is currently not supported due to a bug in the SAM2 model. 
- # https://github.com/open-edge-platform/instant-learn/issues/367 - if sam_model == SAMModelName.SAM2_TINY or model_name == ModelName.SAM3: - pytest.skip("Skipping test_n_shots_capability for SAM2-tiny or SAM3") + # https://github.com/open-edge-platform/instant.learn/issues/367 + if sam_model == "sam2-tiny": + pytest.skip("Skipping test_n_shots_capability for SAM2-tiny") if not fss1000_root.exists(): pytest.skip("fss-1000 dataset not found") @@ -206,7 +206,7 @@ def test_n_shots_capability( sam=sam_model, device="cpu", precision="fp32", - encoder_model="dinov3_small", + encoder_model="dinov3-small", ) model_1shot.fit(ref_batch_1shot) predictions_1shot = model_1shot.predict(target_batch) @@ -226,7 +226,7 @@ def test_n_shots_capability( sam=sam_model, device="cpu", precision="fp32", - encoder_model="dinov3_small", + encoder_model="dinov3-small", ) model_2shot.fit(ref_batch_2shot) predictions_2shot = model_2shot.predict(target_batch_2shot) @@ -244,7 +244,7 @@ def test_n_shots_capability( @pytest.mark.parametrize("sam_model", SAM_MODELS) def test_grounded_sam_no_n_shots( self, - sam_model: SAMModelName, + sam_model: str, reference_batch: Batch, target_batch: Batch, ) -> None: @@ -254,13 +254,13 @@ def test_grounded_sam_no_n_shots( in the same way as other models. It only needs category mapping. Args: - sam_model: The SAM model to use. + sam_model: The SAM model ID (e.g., "sam-hq-tiny"). reference_batch: Batch of reference samples (for category mapping). target_batch: Batch of target samples. """ # TODO(Eugene): SAM2 is currently not supported due to a bug in the SAM2 model. 
- # https://github.com/open-edge-platform/instant-learn/issues/367 - if sam_model == SAMModelName.SAM2_TINY: + # https://github.com/open-edge-platform/instant.learn/issues/367 + if sam_model == "sam2-tiny": pytest.skip("Skipping test_model_input_validation for SAM2-tiny") model = GroundedSAM(sam=sam_model, device="cpu", precision="fp32") @@ -279,7 +279,7 @@ def test_grounded_sam_no_n_shots( @pytest.mark.parametrize("model_name", ModelName) def test_model_input_validation( self, - sam_model: SAMModelName, + sam_model: str, model_name: ModelName, reference_batch: Batch, target_batch: Batch, @@ -287,7 +287,7 @@ def test_model_input_validation( """Test that models validate inputs correctly. Args: - sam_model: The SAM model to use. + sam_model: The SAM model ID (e.g., "sam-hq-tiny"). model_name: The model type to test. reference_batch: Batch of reference samples. target_batch: Batch of target samples. @@ -297,8 +297,8 @@ def test_model_input_validation( pytest.skip("SAM3 doesn't use SAM backend, tested separately") # TODO(Eugene): SAM2 is currently not supported due to a bug in the SAM2 model. 
- # https://github.com/open-edge-platform/instant-learn/issues/367 - if sam_model == SAMModelName.SAM2_TINY: + # https://github.com/open-edge-platform/instant.learn/issues/367 + if sam_model == "sam2-tiny": pytest.skip("Skipping test_model_input_validation for SAM2-tiny") model_class = MODEL_CLASSES[model_name] @@ -307,7 +307,7 @@ def test_model_input_validation( if model_name == ModelName.GROUNDED_SAM: model = model_class(sam=sam_model, device="cpu", precision="fp32") else: - model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3_small") + model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3-small") # Validate that reference batch has required fields assert len(reference_batch) > 0 @@ -335,7 +335,7 @@ def test_model_input_validation( @pytest.mark.parametrize("model_name", ModelName) def test_model_metrics_calculation( self, - sam_model: SAMModelName, + sam_model: str, model_name: ModelName, dataset: FolderDataset, ) -> None: @@ -347,7 +347,7 @@ def test_model_metrics_calculation( 3. Metrics have valid values (within expected ranges) Args: - sam_model: The SAM model to use. + sam_model: The SAM model ID (e.g., "sam-hq-tiny"). model_name: The model type to test. dataset: The dataset to use for testing. """ @@ -356,8 +356,8 @@ def test_model_metrics_calculation( pytest.skip("SAM3 doesn't use SAM backend, tested separately") # TODO(Eugene): SAM2 is currently not supported due to a bug in the SAM2 model. 
- # https://github.com/open-edge-platform/instant-learn/issues/367 - if sam_model == SAMModelName.SAM2_TINY: + # https://github.com/open-edge-platform/instant.learn/issues/367 + if sam_model == "sam2-tiny": pytest.skip("Skipping test_model_metrics_calculation for SAM2-tiny") model_class = MODEL_CLASSES[model_name] @@ -366,7 +366,7 @@ def test_model_metrics_calculation( if model_name == ModelName.GROUNDED_SAM: model = model_class(sam=sam_model, device="cpu", precision="fp32") else: - model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3_small") + model = model_class(sam=sam_model, device="cpu", precision="fp32", encoder_model="dinov3-small") # Get reference and target samples for first category categories = dataset.categories diff --git a/library/tests/unit/components/encoders/test_encoder.py b/library/tests/unit/components/encoders/test_encoder.py index fda4a4c29..33e653964 100644 --- a/library/tests/unit/components/encoders/test_encoder.py +++ b/library/tests/unit/components/encoders/test_encoder.py @@ -17,6 +17,10 @@ ) from instantlearn.utils.constants import Backend +# Get encoder model IDs from registry, filtered by family for HuggingFace vs TIMM +HUGGINGFACE_ENCODER_IDS = [m.id for m in get_models_by_type(ModelType.ENCODER) if m.family == "DINOv2"] +TIMM_ENCODER_IDS = [m.id for m in get_models_by_type(ModelType.ENCODER) if m.family == "DINOv3"] + class TestEncoder: """Test the ImageEncoder class.""" @@ -115,11 +119,11 @@ def test_encoder_initialization( if backend == Backend.HUGGINGFACE: mock_model_instance = self._setup_mock_hf_model(mock_model, mock_processor) mock_optimize.return_value = mock_model_instance - model_id = "dinov2_small" + model_id = "dinov2-small" else: # TIMM mock_model_instance = self._setup_mock_timm_model(mock_timm_create, mock_timm_data_config) mock_optimize.return_value = mock_model_instance - model_id = "dinov3_small" + model_id = "dinov3-small" # Create encoder encoder = 
ImageEncoder(model_id=model_id, backend=backend, device="cpu", input_size=expected_input_size) @@ -148,11 +152,11 @@ def test_call_without_priors( if backend == Backend.HUGGINGFACE: mock_model_instance = self._setup_mock_hf_model(mock_model, mock_processor) mock_optimize.return_value = mock_model_instance - model_id = "dinov2_small" + model_id = "dinov2-small" else: # TIMM mock_model_instance = self._setup_mock_timm_model(mock_timm_create, mock_timm_data_config) mock_optimize.return_value = mock_model_instance - model_id = "dinov3_small" + model_id = "dinov3-small" # Create encoder encoder = ImageEncoder(model_id=model_id, backend=backend, device="cpu", input_size=224) @@ -178,7 +182,7 @@ def test_model_id_validation(backend: Backend) -> None: @staticmethod def test_valid_model_ids_huggingface() -> None: """Test that all valid HuggingFace model IDs are accepted.""" - for model_id in AVAILABLE_IMAGE_ENCODERS: + for model_id in HUGGINGFACE_ENCODER_IDS: with ( patch("instantlearn.utils.optimization.optimize_model") as mock_optimize, patch("instantlearn.components.encoders.huggingface.AutoModel") as mock_model, @@ -202,7 +206,7 @@ def test_valid_model_ids_huggingface() -> None: @staticmethod def test_valid_model_ids_timm() -> None: """Test that all valid TIMM model IDs are accepted.""" - for model_id in TIMM_AVAILABLE_IMAGE_ENCODERS: + for model_id in TIMM_ENCODER_IDS: with ( patch("instantlearn.utils.optimization.optimize_model") as mock_optimize, patch("instantlearn.components.encoders.timm.timm.create_model") as mock_timm_create, @@ -248,11 +252,11 @@ def test_encoder_with_different_input_sizes( if backend == Backend.HUGGINGFACE: mock_model_instance = self._setup_mock_hf_model(mock_model, mock_processor) mock_optimize.return_value = mock_model_instance - model_id = "dinov2_small" + model_id = "dinov2-small" else: # TIMM mock_model_instance = self._setup_mock_timm_model(mock_timm_create, mock_timm_data_config) mock_optimize.return_value = mock_model_instance - 
model_id = "dinov3_small" + model_id = "dinov3-small" # Test with different input sizes for input_size in [224, 384, 512]: @@ -278,11 +282,11 @@ def test_encoder_with_different_precisions( if backend == Backend.HUGGINGFACE: mock_model_instance = self._setup_mock_hf_model(mock_model, mock_processor) mock_optimize.return_value = mock_model_instance - model_id = "dinov2_small" + model_id = "dinov2-small" else: # TIMM mock_model_instance = self._setup_mock_timm_model(mock_timm_create, mock_timm_data_config) mock_optimize.return_value = mock_model_instance - model_id = "dinov3_small" + model_id = "dinov3-small" # Test with different precision settings precision_mapping = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} @@ -315,11 +319,11 @@ def test_encoder_with_compile_models( if backend == Backend.HUGGINGFACE: mock_model_instance = self._setup_mock_hf_model(mock_model, mock_processor) mock_optimize.return_value = mock_model_instance - model_id = "dinov2_small" + model_id = "dinov2-small" else: # TIMM mock_model_instance = self._setup_mock_timm_model(mock_timm_create, mock_timm_data_config) mock_optimize.return_value = mock_model_instance - model_id = "dinov3_small" + model_id = "dinov3-small" # Test with compile_models=True encoder = ImageEncoder(model_id=model_id, backend=backend, device="cpu", compile_models=True, input_size=224) @@ -348,11 +352,11 @@ def test_encoder_device_handling( if backend == Backend.HUGGINGFACE: mock_model_instance = self._setup_mock_hf_model(mock_model, mock_processor) mock_optimize.return_value = mock_model_instance - model_id = "dinov2_small" + model_id = "dinov2-small" else: # TIMM mock_model_instance = self._setup_mock_timm_model(mock_timm_create, mock_timm_data_config) mock_optimize.return_value = mock_model_instance - model_id = "dinov3_small" + model_id = "dinov3-small" # Test with different devices for device in ["cpu", "cuda"]: @@ -377,7 +381,7 @@ def test_huggingface_access_error(mock_model: Mock) -> None: 
mock_model.from_pretrained.side_effect = OSError("gated repo access denied") with pytest.raises(ValueError, match="User does not have access"): - ImageEncoder(model_id="dinov2_small", backend=Backend.HUGGINGFACE, device="cpu") + ImageEncoder(model_id="dinov2-small", backend=Backend.HUGGINGFACE, device="cpu") class TestEncoderIntegration: @@ -409,7 +413,7 @@ def test_forward_with_real_model_comprehensive(backend: Backend) -> None: pytest.assume(embeddings.shape[1] == expected_patches) # Embedding dimension depends on model if backend == Backend.HUGGINGFACE: - pytest.assume(embeddings.shape[2] == 384) # dinov2_small has 384 dims + pytest.assume(embeddings.shape[2] == 384) # dinov2-small has 384 dims else: # TIMM pytest.assume(embeddings.shape[2] > 0) # Just check it's positive @@ -426,10 +430,10 @@ def test_model_configuration_validation(backend: Backend) -> None: """Test that real model configuration is properly loaded for both backends.""" expected_ignore_token_length = 5 # CLS token only if backend == Backend.HUGGINGFACE: - model_id = "dinov2_small" + model_id = "dinov2-small" expected_patch_size = 14 # DINOv2 small uses 14x14 patches else: # TIMM - model_id = "dinov3_small" + model_id = "dinov3-small" expected_patch_size = 16 # DINOv3 uses 16x16 patches encoder = ImageEncoder(model_id=model_id, backend=backend, device="cpu", input_size=224) diff --git a/library/tests/unit/data/test_folder_dataset.py b/library/tests/unit/data/test_folder_dataset.py index 149cc4e56..d1dddae2a 100644 --- a/library/tests/unit/data/test_folder_dataset.py +++ b/library/tests/unit/data/test_folder_dataset.py @@ -358,14 +358,14 @@ def test_get_reference_dataset(self, multi_category_dataset: FolderDataset) -> N ref_dataset = multi_category_dataset.get_reference_dataset() assert len(ref_dataset) == 3 # All should have at least one reference - assert ref_dataset.df["is_reference"].list.contains(item=True).all() + assert all(any(refs) for refs in ref_dataset.df["is_reference"].to_list()) def 
test_get_target_dataset(self, multi_category_dataset: FolderDataset) -> None: """Test getting target dataset.""" target_dataset = multi_category_dataset.get_target_dataset() assert len(target_dataset) == 6 - # All should have no reference instances - assert not target_dataset.df["is_reference"].list.contains(item=True).any() + # All should have no reference instances (all False) + assert all(not any(refs) for refs in target_dataset.df["is_reference"].to_list()) class TestFolderDatasetBatch: diff --git a/library/tests/unit/models/test_dinotxt.py b/library/tests/unit/models/test_dinotxt.py index 7f6e07d63..8897bd4c7 100644 --- a/library/tests/unit/models/test_dinotxt.py +++ b/library/tests/unit/models/test_dinotxt.py @@ -111,8 +111,7 @@ def test_pipeline_initialization_with_custom_params(mock_encoder_class: MagicMoc def test_learn_with_empty_reference_batch(model_instance: DinoTxtZeroShotClassification) -> None: """Test that fit raises ValueError when no reference samples are provided.""" with pytest.raises(ValueError, match="Cannot collate empty list of samples"): - empty_batch = Batch.collate([]) - model_instance.fit(empty_batch) + model_instance.fit(Batch.collate([])) @staticmethod def test_infer_without_learning( diff --git a/library/tests/unit/scripts/test_benchmark.py b/library/tests/unit/scripts/test_benchmark.py index 566dffa1e..15558fe97 100644 --- a/library/tests/unit/scripts/test_benchmark.py +++ b/library/tests/unit/scripts/test_benchmark.py @@ -144,7 +144,7 @@ def test_predict_on_dataset_single_model(self) -> None: output_path=Path(tempfile.mkdtemp()), dataset_name="lvis", model_name="Matcher", - backbone_name="SAM-HQ-tiny", + sam_model_id="sam-hq-tiny", number_of_priors_tests=1, device=torch.device("cpu"), ) @@ -168,7 +168,7 @@ def test_predict_on_dataset_error_handling(self) -> None: output_path=Path(tempfile.mkdtemp()), dataset_name="lvis", model_name="Matcher", - backbone_name="SAM-HQ-tiny", + sam_model_id="sam-hq-tiny", number_of_priors_tests=1, 
device=torch.device("cpu"), ) diff --git a/library/tests/unit/utils/test_model_registry.py b/library/tests/unit/utils/test_model_registry.py new file mode 100644 index 000000000..e26f746fe --- /dev/null +++ b/library/tests/unit/utils/test_model_registry.py @@ -0,0 +1,284 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for model_registry module.""" + +import pytest + +from instantlearn.registry import ( + MODEL_REGISTRY, + Capability, + Modality, + ModelMetadata, + ModelType, + PromptType, + get_available_capabilities, + get_available_families, + get_available_sizes, + get_available_types, + get_local_filename, + get_model, + get_models_by_capability, + get_models_by_family, + get_models_by_type, + is_valid_model, +) + + +class TestStrEnums: + """Tests for StrEnum classes.""" + + def test_model_type_is_str_enum(self) -> None: + """ModelType values can be compared with strings.""" + assert ModelType.ENCODER == "encoder" + assert ModelType.SEGMENTER == "segmenter" + assert ModelType.TRACKER == "tracker" + + def test_modality_is_str_enum(self) -> None: + """Modality values can be compared with strings.""" + assert Modality.IMAGE == "image" + assert Modality.VIDEO == "video" + assert Modality.TEXT == "text" + + def test_prompt_type_is_str_enum(self) -> None: + """PromptType values can be compared with strings.""" + assert PromptType.POINT == "point" + assert PromptType.BOX == "box" + assert PromptType.MASK == "mask" + assert PromptType.TEXT == "text" + assert PromptType.IMAGE == "image" + + def test_capability_is_str_enum(self) -> None: + """Capability values can be compared with strings.""" + assert Capability.ENCODING == "encoding" + assert Capability.SEGMENTATION == "segmentation" + assert Capability.TRACKING == "tracking" + assert Capability.DESCRIPTION == "description" + + def test_str_enum_conversion(self) -> None: + """StrEnum can be constructed from string.""" + assert ModelType("encoder") == ModelType.ENCODER + 
assert Capability("segmentation") == Capability.SEGMENTATION + + +class TestModelMetadata: + """Tests for ModelMetadata dataclass.""" + + def test_model_metadata_is_frozen(self) -> None: + """ModelMetadata instances are immutable.""" + model = get_model("sam-hq") + assert model is not None + with pytest.raises(AttributeError): + model.id = "new-id" # type: ignore[misc] + + def test_model_metadata_required_fields(self) -> None: + """ModelMetadata requires all non-optional fields.""" + metadata = ModelMetadata( + id="test-model", + type=ModelType.ENCODER, + family="Test", + size="small", + modalities=(Modality.IMAGE,), + prompts=(), + capabilities=(Capability.ENCODING,), + ) + assert metadata.id == "test-model" + assert metadata.weights_url is None + assert metadata.hf_model_id is None + + +class TestModelRegistry: + """Tests for MODEL_REGISTRY tuple.""" + + def test_registry_is_not_empty(self) -> None: + """Registry contains models.""" + assert len(MODEL_REGISTRY) > 0 + + def test_registry_is_tuple(self) -> None: + """Registry is immutable tuple.""" + assert isinstance(MODEL_REGISTRY, tuple) + + def test_all_entries_are_model_metadata(self) -> None: + """All registry entries are ModelMetadata instances.""" + for model in MODEL_REGISTRY: + assert isinstance(model, ModelMetadata) + + def test_model_ids_are_unique(self) -> None: + """All model IDs are unique.""" + ids = [m.id for m in MODEL_REGISTRY] + assert len(ids) == len(set(ids)), "Duplicate model IDs found" + + def test_registry_contains_expected_models(self) -> None: + """Registry contains expected model families.""" + families = {m.family for m in MODEL_REGISTRY} + assert "SAM-HQ" in families + assert "SAM2" in families + assert "DINOv2" in families + + +class TestGetModel: + """Tests for get_model function.""" + + def test_get_existing_model(self) -> None: + """get_model returns metadata for existing model.""" + model = get_model("sam-hq") + assert model is not None + assert model.id == "sam-hq" + assert 
model.family == "SAM-HQ" + + def test_get_nonexistent_model(self) -> None: + """get_model returns None for nonexistent model.""" + model = get_model("nonexistent-model") + assert model is None + + def test_get_model_case_sensitive(self) -> None: + """get_model is case-sensitive.""" + assert get_model("SAM-HQ") is None + assert get_model("sam-hq") is not None + + +class TestGetModelsByType: + """Tests for get_models_by_type function.""" + + def test_get_segmenters(self) -> None: + """get_models_by_type returns all segmenters.""" + segmenters = get_models_by_type(ModelType.SEGMENTER) + assert len(segmenters) > 0 + assert all(m.type == ModelType.SEGMENTER for m in segmenters) + + def test_get_encoders(self) -> None: + """get_models_by_type returns all encoders.""" + encoders = get_models_by_type(ModelType.ENCODER) + assert len(encoders) > 0 + assert all(m.type == ModelType.ENCODER for m in encoders) + + def test_get_models_by_type_with_string(self) -> None: + """get_models_by_type works with string argument (StrEnum).""" + # This works because ModelType is a StrEnum + segmenters_enum = get_models_by_type(ModelType.SEGMENTER) + segmenters_str = get_models_by_type("segmenter") # type: ignore[arg-type] + assert segmenters_enum == segmenters_str + + def test_get_trackers_empty(self) -> None: + """get_models_by_type returns empty list for type with no models.""" + trackers = get_models_by_type(ModelType.TRACKER) + assert trackers == [] + + +class TestGetModelsByCapability: + """Tests for get_models_by_capability function.""" + + def test_get_segmentation_capable(self) -> None: + """get_models_by_capability returns models with segmentation.""" + models = get_models_by_capability(Capability.SEGMENTATION) + assert len(models) > 0 + assert all(Capability.SEGMENTATION in m.capabilities for m in models) + + def test_get_encoding_capable(self) -> None: + """get_models_by_capability returns models with encoding.""" + models = get_models_by_capability(Capability.ENCODING) + assert 
len(models) > 0 + assert all(Capability.ENCODING in m.capabilities for m in models) + + +class TestGetModelsByFamily: + """Tests for get_models_by_family function.""" + + def test_get_sam2_family(self) -> None: + """get_models_by_family returns all SAM2 models.""" + models = get_models_by_family("SAM2") + assert len(models) > 0 + assert all(m.family == "SAM2" for m in models) + + def test_get_dinov2_family(self) -> None: + """get_models_by_family returns all DINOv2 models.""" + models = get_models_by_family("DINOv2") + assert len(models) > 0 + assert all(m.family == "DINOv2" for m in models) + + def test_get_nonexistent_family(self) -> None: + """get_models_by_family returns empty list for unknown family.""" + models = get_models_by_family("NonexistentFamily") + assert models == [] + + +class TestGetAvailableFunctions: + """Tests for get_available_* functions.""" + + def test_get_available_types(self) -> None: + """get_available_types returns list of type strings.""" + types = get_available_types() + assert isinstance(types, list) + assert "encoder" in types + assert "segmenter" in types + + def test_get_available_families(self) -> None: + """get_available_families returns list of family strings.""" + families = get_available_families() + assert isinstance(families, list) + assert "SAM-HQ" in families + assert "DINOv2" in families + + def test_get_available_sizes(self) -> None: + """get_available_sizes returns list of size strings.""" + sizes = get_available_sizes() + assert isinstance(sizes, list) + assert "tiny" in sizes + assert "small" in sizes + assert "base" in sizes + assert "large" in sizes + + def test_get_available_capabilities(self) -> None: + """get_available_capabilities returns list of capability strings.""" + caps = get_available_capabilities() + assert isinstance(caps, list) + assert "encoding" in caps + assert "segmentation" in caps + + +class TestIsValidModel: + """Tests for is_valid_model function.""" + + def test_valid_model(self) -> None: + 
"""is_valid_model returns True for existing model.""" + assert is_valid_model("sam-hq") is True + assert is_valid_model("dinov2-base") is True + + def test_invalid_model(self) -> None: + """is_valid_model returns False for nonexistent model.""" + assert is_valid_model("nonexistent") is False + + def test_valid_model_with_matching_type(self) -> None: + """is_valid_model returns True when type matches.""" + assert is_valid_model("sam-hq", ModelType.SEGMENTER) is True + assert is_valid_model("dinov2-base", ModelType.ENCODER) is True + + def test_valid_model_with_wrong_type(self) -> None: + """is_valid_model returns False when type doesn't match.""" + assert is_valid_model("sam-hq", ModelType.ENCODER) is False + assert is_valid_model("dinov2-base", ModelType.SEGMENTER) is False + + +class TestGetLocalFilename: + """Tests for get_local_filename function.""" + + def test_get_filename_from_url(self) -> None: + """get_local_filename extracts filename from weights_url.""" + filename = get_local_filename("sam-hq") + assert filename == "sam_hq_vit_h.pth" + + def test_get_filename_sam2(self) -> None: + """get_local_filename works for SAM2 models.""" + filename = get_local_filename("sam2-tiny") + assert filename == "sam2.1_hiera_tiny.pt" + + def test_get_filename_nonexistent_model(self) -> None: + """get_local_filename raises ValueError for nonexistent model.""" + with pytest.raises(ValueError, match="not found in registry"): + get_local_filename("nonexistent") + + def test_get_filename_no_weights_url(self) -> None: + """get_local_filename raises ValueError for model without weights_url.""" + # DINOv2 models use hf_model_id, not weights_url + with pytest.raises(ValueError, match="has no weights_url"): + get_local_filename("dinov2-base")