3 changes: 3 additions & 0 deletions inference/core/env.py
@@ -196,6 +196,9 @@
os.getenv("CORE_MODEL_YOLO_WORLD_ENABLED", True)
)

# Enable the experimental inference_exp model backends (RF-DETR, YOLOv8); default is False
USE_INFERENCE_EXP_MODELS = str2bool(os.getenv("USE_INFERENCE_EXP_MODELS", "False"))

# ID of host device, default is None
DEVICE_ID = os.getenv("DEVICE_ID", None)

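To opt in, the flag must be set before inference.models.utils is imported, since the registry swap in that module happens at import time. A minimal sketch (the model ID and API key are placeholders):

import os

# Enable the experimental backends before the model registry is imported.
os.environ["USE_INFERENCE_EXP_MODELS"] = "True"

from inference.models.utils import get_model  # noqa: E402

model = get_model(model_id="my-project/1", api_key="<ROBOFLOW_API_KEY>")  # placeholders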
1 change: 1 addition & 0 deletions inference/core/models/base.py
@@ -134,6 +134,7 @@ def infer_from_request(
responses = self.infer(**request.dict(), return_image_dims=False)
for response in responses:
response.time = perf_counter() - t1
logger.debug(f"model infer time: {response.time * 1000.0} ms")
if request.id:
response.inference_id = request.id

129 changes: 129 additions & 0 deletions inference/core/models/exp_adapter.py
@@ -0,0 +1,129 @@
from threading import Lock
from time import perf_counter
from typing import Any, Generic, List, Optional, Tuple, Union

import numpy as np
from inference_exp.models.base.object_detection import Detections, ObjectDetectionModel
from inference_exp.models.base.types import (
PreprocessedInputs,
PreprocessingMetadata,
RawPrediction,
)

from inference.core.entities.responses.inference import (
InferenceResponseImage,
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
)
from inference.core.env import API_KEY
from inference.core.logger import logger
from inference.core.models.base import Model
from inference.core.utils.image_utils import load_image_rgb
from inference.models.aliases import resolve_roboflow_model_alias


class InferenceExpObjectDetectionModelAdapter(Model):
    def __init__(self, model_id: str, api_key: Optional[str] = None, **kwargs):
super().__init__()

self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

self.api_key = api_key if api_key else API_KEY
model_id = resolve_roboflow_model_alias(model_id=model_id)

self.task_type = "object-detection"

        # Lazy import to avoid a hard dependency when the flag is disabled
from inference_exp import AutoModel # type: ignore

self._exp_model: ObjectDetectionModel = AutoModel.from_pretrained(
model_id_or_path=model_id, api_key=self.api_key
)
if hasattr(self._exp_model, "optimize_for_inference"):
self._exp_model.optimize_for_inference()

self.class_names = list(self._exp_model.class_names)

def map_inference_kwargs(self, kwargs: dict) -> dict:
return kwargs

def preprocess(self, image: Any, **kwargs):
is_batch = isinstance(image, list)
images = image if is_batch else [image]
np_images: List[np.ndarray] = [
load_image_rgb(
v,
disable_preproc_auto_orient=kwargs.get(
"disable_preproc_auto_orient", False
),
)
for v in images
]
mapped_kwargs = self.map_inference_kwargs(kwargs)
return self._exp_model.pre_process(np_images, **mapped_kwargs)

def predict(self, img_in, **kwargs):
mapped_kwargs = self.map_inference_kwargs(kwargs)
return self._exp_model.forward(img_in, **mapped_kwargs)

def postprocess(
self,
predictions: Tuple[np.ndarray, ...],
preprocess_return_metadata: PreprocessingMetadata,
**kwargs,
) -> List[Detections]:
mapped_kwargs = self.map_inference_kwargs(kwargs)
detections_list = self._exp_model.post_process(
predictions, preprocess_return_metadata, **mapped_kwargs
Comment on lines +50 to +77

P1: Guard experimental AutoModel inference with a lock

The new experimental adapter calls self._exp_model.pre_process, forward, and post_process directly without any synchronization. Other torch-backed models in this repository protect inference with a Lock to avoid concurrent access to shared model state (for example, YOLOv8ObjectDetection.predict uses _session_lock). AutoModel instances from inference_exp are PyTorch models as well and are unlikely to be thread-safe. When the server runs with multiple workers or handles concurrent requests, unsynchronized access can trigger CUDA/torch runtime errors or corrupt intermediate buffers. The adapter already imports Lock, so wrapping the model calls in a mutex seems intended and would prevent these race conditions.

)

responses: List[ObjectDetectionInferenceResponse] = []
for preproc_metadata, det in zip(preprocess_return_metadata, detections_list):
H = preproc_metadata.original_size.height
W = preproc_metadata.original_size.width

xyxy = det.xyxy.detach().cpu().numpy()
confs = det.confidence.detach().cpu().numpy()
class_ids = det.class_id.detach().cpu().numpy()

predictions: List[ObjectDetectionPrediction] = []

for (x1, y1, x2, y2), conf, class_id in zip(xyxy, confs, class_ids):
cx = (float(x1) + float(x2)) / 2.0
cy = (float(y1) + float(y2)) / 2.0
w = float(x2) - float(x1)
h = float(y2) - float(y1)
class_id_int = int(class_id)
class_name = (
self.class_names[class_id_int]
if 0 <= class_id_int < len(self.class_names)
else str(class_id_int)
)
predictions.append(
ObjectDetectionPrediction(
x=cx,
y=cy,
width=w,
height=h,
confidence=float(conf),
**{"class": class_name},
class_id=class_id_int,
)
)

responses.append(
ObjectDetectionInferenceResponse(
predictions=predictions,
image=InferenceResponseImage(width=W, height=H),
)
)

return responses

def clear_cache(self, delete_from_disk: bool = True) -> None:
"""Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

Args:
delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
"""
pass
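A minimal sketch of the guard suggested in the review comment above, mirroring the _session_lock pattern the reviewer cites from YOLOv8ObjectDetection (illustrative only, not part of this diff):

from threading import Lock

from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter


class LockedExpAdapter(InferenceExpObjectDetectionModelAdapter):
    # Illustrative subclass: AutoModel instances are torch models and are
    # unlikely to be thread-safe, so serialize the forward pass.
    def __init__(self, model_id: str, api_key: str = None, **kwargs):
        super().__init__(model_id=model_id, api_key=api_key, **kwargs)
        self._session_lock = Lock()

    def predict(self, img_in, **kwargs):
        # Hold the lock for the duration of the forward pass so concurrent
        # requests cannot corrupt shared model state.
        with self._session_lock:
            return super().predict(img_in, **kwargs)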
10 changes: 7 additions & 3 deletions inference/models/aliases.py
@@ -17,11 +17,15 @@
"paligemma-3b-ft-docvqa-448": "paligemma-pretrains/18",
"paligemma-3b-ft-ocrvqa-448": "paligemma-pretrains/19",
}
# FLORENCE_ALIASES = {
# "florence-2-base": "florence-pretrains/1",
# "florence-2-large": "florence-pretrains/2",
# }
# Since transformers 4.53.3, a newer version of the Florence-2 weights is required
FLORENCE_ALIASES = {
"florence-2-base": "florence-pretrains/1",
"florence-2-large": "florence-pretrains/2",
"florence-2-base": "florence-pretrains/3",
"florence-2-large": "florence-pretrains/4",
}

QWEN_ALIASES = {
"qwen25-vl-7b": "qwen-pretrains/1",
}
20 changes: 15 additions & 5 deletions inference/models/florence2/florence2.py
@@ -71,6 +71,7 @@ def initialize_model(self, **kwargs):
lora_config = LoraConfig.from_pretrained(self.cache_dir, device_map=DEVICE)
model_id = lora_config.base_model_name_or_path
revision = lora_config.revision
original_revision_pre_mapping = revision
if revision is not None:
try:
self.dtype = getattr(torch, revision)
@@ -135,11 +136,19 @@ def initialize_model(self, **kwargs):
adapter_missing_keys.append(key)
load_result.missing_keys.clear()
load_result.missing_keys.extend(adapter_missing_keys)
if len(load_result.missing_keys) > 0:
raise RuntimeError(
"Could not load LoRA weights for the model - found missing checkpoint keys "
f"({len(load_result.missing_keys)}): {load_result.missing_keys}",
)
if original_revision_pre_mapping == "refs/pr/6":
if len(load_result.missing_keys) > 2:
raise RuntimeError(
"Could not load LoRA weights for the model - found missing checkpoint keys "
f"({len(load_result.missing_keys)}): {load_result.missing_keys}",
)

else:
if len(load_result.missing_keys) > 0:
raise RuntimeError(
"Could not load LoRA weights for the model - found missing checkpoint keys "
f"({len(load_result.missing_keys)}): {load_result.missing_keys}",
)

self.model = model
except ImportError:
@@ -166,6 +175,7 @@ def get_lora_base_from_roboflow(self, model_id, revision):
)

revision_mapping = {
("microsoft/Florence-2-base-ft", "refs/pr/6"): "refs/pr/29-converted",
("microsoft/Florence-2-base-ft", "refs/pr/22"): "refs/pr/29-converted",
("microsoft/Florence-2-large-ft", "refs/pr/20"): "refs/pr/38-converted",
}
12 changes: 12 additions & 0 deletions inference/models/florence2/utils.py
@@ -15,10 +15,15 @@ def import_class_from_file(file_path, class_name, alias_name=None):

sys.path.insert(0, parent_dir)

previous_module = sys.modules.get(module_name)
injected = False
try:
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)

sys.modules[module_name] = module
injected = True

# Manually set the __package__ attribute to the parent package
module.__package__ = os.path.basename(module_dir)

@@ -27,5 +32,12 @@
if alias_name:
globals()[alias_name] = cls
return cls
except Exception:
if injected:
if previous_module is not None:
sys.modules[module_name] = previous_module
else:
sys.modules.pop(module_name, None)
raise
finally:
sys.path.pop(0)
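For context, a hedged usage sketch of the helper after this change; on failure the pre-existing sys.modules entry is restored rather than left pointing at a half-initialized module (the path and class name below are placeholders):

from inference.models.florence2.utils import import_class_from_file

# Hypothetical usage: load a class from a standalone file by path.
cls = import_class_from_file(
    file_path="/tmp/custom_modeling.py",  # placeholder path
    class_name="CustomModel",             # placeholder class name
)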
14 changes: 14 additions & 0 deletions inference/models/rfdetr/rfdetr_exp.py
@@ -0,0 +1,14 @@
from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter


class RFDetrExperimentalModel(InferenceExpObjectDetectionModelAdapter):
"""Adapter for RF-DETR using inference_exp AutoModel backend.

This class wraps an inference_exp AutoModel to present the same interface
as legacy models in the inference server.
"""

def map_inference_kwargs(self, kwargs: dict) -> dict:
return {
"threshold": kwargs.get("confidence"),
}
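A hedged usage sketch, assuming the base Model.infer orchestration (preprocess, predict, postprocess); the model ID and API key are placeholders:

from inference.models.rfdetr.rfdetr_exp import RFDetrExperimentalModel

model = RFDetrExperimentalModel(model_id="my-rfdetr/1", api_key="<ROBOFLOW_API_KEY>")
# The server-side confidence kwarg reaches inference_exp's post_process as
# threshold via map_inference_kwargs above.
responses = model.infer(image="image.jpg", confidence=0.5)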
28 changes: 28 additions & 0 deletions inference/models/utils.py
Expand Up @@ -19,6 +19,7 @@
PALIGEMMA_ENABLED,
QWEN_2_5_ENABLED,
SMOLVLM2_ENABLED,
USE_INFERENCE_EXP_MODELS,
)
from inference.core.models.base import Model
from inference.core.models.stubs import (
@@ -550,3 +551,30 @@ def get_model(model_id, api_key=API_KEY, **kwargs) -> Model:

def get_roboflow_model(*args, **kwargs):
return get_model(*args, **kwargs)


# Prefer the inference_exp backend for RF-DETR and YOLOv8 variants when enabled and available
try:
if USE_INFERENCE_EXP_MODELS:
# Ensure experimental package is importable before swapping
__import__("inference_exp")
from inference.models.rfdetr.rfdetr_exp import RFDetrExperimentalModel
from inference.models.yolov8.yolov8_object_detection_exp import (
Yolo8ODExperimentalModel,
)

for task, variant in ROBOFLOW_MODEL_TYPES.keys():
if task == "object-detection" and variant.startswith("rfdetr-"):
ROBOFLOW_MODEL_TYPES[(task, variant)] = RFDetrExperimentalModel

        # Replace all object-detection entries whose variant starts with yolov8 with the experimental model
for task, variant in ROBOFLOW_MODEL_TYPES.keys():
if task == "object-detection" and variant.startswith("yolov8"):
ROBOFLOW_MODEL_TYPES[(task, variant)] = Yolo8ODExperimentalModel


except Exception:
    # Fall back to the legacy ONNX stack when the experimental stack is unavailable
warnings.warn(
"Inference experimental stack is unavailable, falling back to regular model inference stack"
)
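One way to verify which backend a variant resolved to after import (illustrative; assumes a yolov8n object-detection entry exists in the registry):

from inference.models.utils import ROBOFLOW_MODEL_TYPES

cls = ROBOFLOW_MODEL_TYPES[("object-detection", "yolov8n")]
print(cls.__name__)  # Yolo8ODExperimentalModel when the swap succeeded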
11 changes: 11 additions & 0 deletions inference/models/yolov8/yolov8_object_detection_exp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter


class Yolo8ODExperimentalModel(InferenceExpObjectDetectionModelAdapter):
def map_inference_kwargs(self, kwargs: dict) -> dict:
return {
"conf_thresh": kwargs.get("confidence"),
"iou_thresh": kwargs.get("iou_threshold"),
"max_detections": kwargs.get("max_detections"),
"class_agnostic": kwargs.get("class_agnostic"),
}