Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 105 additions & 45 deletions beit/pytorch/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
BEiT model loader implementation for image classification
"""
import torch
from transformers import BeitImageProcessor, BeitForImageClassification
from PIL import Image
from ...tools.utils import get_file
from transformers import BeitForImageClassification
from typing import Optional
from dataclasses import dataclass

from ...base import ForgeModel
from ...config import (
Expand All @@ -20,6 +19,14 @@
Framework,
StrEnum,
)
from ...tools.utils import VisionPreprocessor, VisionPostprocessor


@dataclass
class BeitConfig(ModelConfig):
"""Configuration specific to BEiT models"""

source: ModelSource


class ModelVariant(StrEnum):
Expand All @@ -32,13 +39,15 @@ class ModelVariant(StrEnum):
class ModelLoader(ForgeModel):
"""BEiT model loader implementation for image classification tasks."""

# Dictionary of available model variants
# Dictionary of available model variants using structured configs
_VARIANTS = {
ModelVariant.BASE: ModelConfig(
ModelVariant.BASE: BeitConfig(
pretrained_model_name="microsoft/beit-base-patch16-224",
source=ModelSource.HUGGING_FACE,
),
ModelVariant.LARGE: ModelConfig(
ModelVariant.LARGE: BeitConfig(
pretrained_model_name="microsoft/beit-large-patch16-224",
source=ModelSource.HUGGING_FACE,
),
}

Expand All @@ -53,11 +62,13 @@ def __init__(self, variant: Optional[ModelVariant] = None):
If None, DEFAULT_VARIANT is used.
"""
super().__init__(variant)
self.processor = None
self.model = None
self._preprocessor = None
self._postprocessor = None

@classmethod
def _get_model_info(cls, variant: Optional[ModelVariant] = None) -> ModelInfo:
"""Implementation method for getting model info with validated variant.
"""Get model information for dashboard and metrics reporting.

Args:
variant: Optional ModelVariant specifying which variant to use.
Expand All @@ -66,29 +77,21 @@ def _get_model_info(cls, variant: Optional[ModelVariant] = None) -> ModelInfo:
Returns:
ModelInfo: Information about the model and variant
"""
if variant is None:
variant = cls.DEFAULT_VARIANT

# Get source from variant config
source = cls._VARIANTS[variant].source

return ModelInfo(
model="beit",
variant=variant,
group=ModelGroup.GENERALITY,
task=ModelTask.CV_IMAGE_CLS,
source=ModelSource.HUGGING_FACE,
source=source,
framework=Framework.TORCH,
)

def _load_processor(self):
"""Load processor for the current variant.

Returns:
The loaded processor instance
"""

# Initialize processor
self.processor = BeitImageProcessor.from_pretrained(
self._variant_config.pretrained_model_name
)

return self.processor

def load_model(self, dtype_override=None):
"""Load and return the BEiT model instance for this instance's variant.

Expand All @@ -102,41 +105,98 @@ def load_model(self, dtype_override=None):
# Get the pretrained model name from the instance's variant config
pretrained_model_name = self._variant_config.pretrained_model_name

# Ensure processor is loaded
if self.processor is None:
self._load_processor()

# Load pre-trained model from HuggingFace
model = BeitForImageClassification.from_pretrained(pretrained_model_name)

model.eval()

# Store model for potential use in input preprocessing and postprocessing
self.model = model

# Update preprocessor with cached model (for TIMM models)
if self._preprocessor is not None:
self._preprocessor.set_cached_model(model)

# Update postprocessor with model instance (for HuggingFace models)
if self._postprocessor is not None:
self._postprocessor.set_model_instance(model)

# Only convert dtype if explicitly requested
if dtype_override is not None:
model = model.to(dtype_override)

return model

def load_inputs(self, dtype_override=None, batch_size=1):
"""Load and return sample inputs for the BEiT model with this instance's variant settings.
def input_preprocess(self, dtype_override=None, batch_size=1, image=None):
"""Preprocess input image(s) and return model-ready input tensor.

Args:
dtype_override: Optional torch.dtype to override the model inputs' default dtype.
batch_size: Optional batch size to override the default batch size of 1.
dtype_override: Optional torch.dtype override (default: float32).
batch_size: Batch size (ignored if image is a list).
image: PIL Image, URL string, tensor, list of images/URLs, or None (uses default COCO image).

Returns:
dict: Input tensors and attention masks that can be fed to the model.
torch.Tensor: Preprocessed input tensor.
"""
# Ensure processor is initialized
if self.processor is None:
self._load_processor()
if self._preprocessor is None:
model_name = self._variant_config.pretrained_model_name
source = self._variant_config.source

self._preprocessor = VisionPreprocessor(
model_source=source,
model_name=model_name,
image_processor_kwargs={"use_fast": True} if source == ModelSource.HUGGING_FACE else None,
)

if hasattr(self, "model") and self.model is not None:
self._preprocessor.set_cached_model(self.model)

model_for_config = None
if self._variant_config.source == ModelSource.TIMM:
if hasattr(self, "model") and self.model is not None:
model_for_config = self.model

return self._preprocessor.preprocess(
image=image,
dtype_override=dtype_override,
batch_size=batch_size,
model_for_config=model_for_config,
)

image_file = get_file("http://images.cocodataset.org/val2017/000000039769.jpg")
image = Image.open(str(image_file))
inputs = self.processor(images=image, return_tensors="pt")
def load_inputs(self, dtype_override=None, batch_size=1, image=None):
"""Load and return sample inputs for the model.

if dtype_override is not None:
inputs["pixel_values"] = inputs["pixel_values"].to(dtype_override)
Args:
dtype_override: Optional torch.dtype override.
batch_size: Batch size (default: 1).
image: Optional input image.

Returns:
torch.Tensor: Preprocessed input tensor.
"""
return self.input_preprocess(
image=image,
dtype_override=dtype_override,
batch_size=batch_size,
)

def output_postprocess(self, output):
"""Post-process model outputs.

Args:
output: Model output tensor.

Returns:
dict: Prediction dict with top predictions.
"""
if self._postprocessor is None:
model_name = self._variant_config.pretrained_model_name
source = self._variant_config.source

# Add batch dimension if batch_size
for key in inputs:
if torch.is_tensor(inputs[key]):
inputs[key] = inputs[key].repeat_interleave(batch_size, dim=0)
self._postprocessor = VisionPostprocessor(
model_source=source,
model_name=model_name,
model_instance=self.model,
)

return inputs
return self._postprocessor.postprocess(output, top_k=1, return_dict=True)
Loading