3 changes: 3 additions & 0 deletions inference/core/env.py
@@ -196,6 +196,9 @@
os.getenv("CORE_MODEL_YOLO_WORLD_ENABLED", True)
)

# Enable experimental RFDETR backend (inference_exp) rollout, default is False
USE_INFERENCE_EXP_MODELS = str2bool(os.getenv("USE_INFERENCE_EXP_MODELS", "False"))

# ID of host device, default is None
DEVICE_ID = os.getenv("DEVICE_ID", None)

1 change: 1 addition & 0 deletions inference/core/models/base.py
@@ -134,6 +134,7 @@ def infer_from_request(
responses = self.infer(**request.dict(), return_image_dims=False)
for response in responses:
response.time = perf_counter() - t1
logger.debug(f"model infer time: {response.time * 1000.0} ms")
if request.id:
response.inference_id = request.id

129 changes: 129 additions & 0 deletions inference/core/models/exp_adapter.py
@@ -0,0 +1,129 @@
from threading import Lock
from time import perf_counter
from typing import Any, Generic, List, Optional, Tuple, Union

import numpy as np
from inference_exp.models.base.object_detection import Detections, ObjectDetectionModel
from inference_exp.models.base.types import (
PreprocessedInputs,
PreprocessingMetadata,
RawPrediction,
)

from inference.core.entities.responses.inference import (
InferenceResponseImage,
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
)
from inference.core.env import API_KEY
from inference.core.logger import logger
from inference.core.models.base import Model
from inference.core.utils.image_utils import load_image_rgb
from inference.models.aliases import resolve_roboflow_model_alias


class InferenceExpObjectDetectionModelAdapter(Model):
def __init__(self, model_id: str, api_key: str = None, **kwargs):
super().__init__()

self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

self.api_key = api_key if api_key else API_KEY
model_id = resolve_roboflow_model_alias(model_id=model_id)

self.task_type = "object-detection"

# Lazy import to avoid hard dependency if flag disabled
from inference_exp import AutoModel # type: ignore

self._exp_model: ObjectDetectionModel = AutoModel.from_pretrained(
model_id_or_path=model_id, api_key=self.api_key
)
if hasattr(self._exp_model, "optimize_for_inference"):
self._exp_model.optimize_for_inference()

self.class_names = list(self._exp_model.class_names)

def map_inference_kwargs(self, kwargs: dict) -> dict:
return kwargs

def preprocess(self, image: Any, **kwargs):
is_batch = isinstance(image, list)
images = image if is_batch else [image]
np_images: List[np.ndarray] = [
load_image_rgb(
v,
disable_preproc_auto_orient=kwargs.get(
"disable_preproc_auto_orient", False
),
)
for v in images
]
mapped_kwargs = self.map_inference_kwargs(kwargs)
return self._exp_model.pre_process(np_images, **mapped_kwargs)

def predict(self, img_in, **kwargs):
mapped_kwargs = self.map_inference_kwargs(kwargs)
return self._exp_model.forward(img_in, **mapped_kwargs)

def postprocess(
self,
predictions: Tuple[np.ndarray, ...],
preprocess_return_metadata: PreprocessingMetadata,
**kwargs,
) -> List[Detections]:
mapped_kwargs = self.map_inference_kwargs(kwargs)
detections_list = self._exp_model.post_process(
predictions, preprocess_return_metadata, **mapped_kwargs
Comment on lines +50 to +77

P1: Guard experimental AutoModel inference with a lock

The new experimental adapter calls self._exp_model.pre_process, forward, and post_process directly without any synchronization. Other torch-backed models in this repository protect inference with a Lock to avoid concurrent access to shared model state (for example, YOLOv8ObjectDetection.predict uses _session_lock). AutoModel instances from inference_exp are PyTorch models as well and are unlikely to be thread-safe. When the server runs with multiple workers or handles concurrent requests, unsynchronized access can trigger CUDA/torch runtime errors or corrupt intermediate buffers. The adapter already imports Lock, so wrapping the model calls in a mutex seems intended and would prevent these race conditions. (A sketch of such a guard appears after this file's diff.)

)

responses: List[ObjectDetectionInferenceResponse] = []
for preproc_metadata, det in zip(preprocess_return_metadata, detections_list):
H = preproc_metadata.original_size.height
W = preproc_metadata.original_size.width

xyxy = det.xyxy.detach().cpu().numpy()
confs = det.confidence.detach().cpu().numpy()
class_ids = det.class_id.detach().cpu().numpy()

predictions: List[ObjectDetectionPrediction] = []

for (x1, y1, x2, y2), conf, class_id in zip(xyxy, confs, class_ids):
cx = (float(x1) + float(x2)) / 2.0
cy = (float(y1) + float(y2)) / 2.0
w = float(x2) - float(x1)
h = float(y2) - float(y1)
class_id_int = int(class_id)
class_name = (
self.class_names[class_id_int]
if 0 <= class_id_int < len(self.class_names)
else str(class_id_int)
)
predictions.append(
ObjectDetectionPrediction(
x=cx,
y=cy,
width=w,
height=h,
confidence=float(conf),
**{"class": class_name},
class_id=class_id_int,
)
)

responses.append(
ObjectDetectionInferenceResponse(
predictions=predictions,
image=InferenceResponseImage(width=W, height=H),
)
)

return responses

def clear_cache(self, delete_from_disk: bool = True) -> None:
"""Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

Args:
delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
"""
pass
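
To make the review comment above concrete, here is a minimal sketch of the suggested guard, assuming the inference_exp model is not thread-safe. The class and lock names are hypothetical; the lock mirrors the _session_lock convention cited in the comment.

from threading import Lock


class LockGuardedExpModel:
    """Sketch only: serialize access to a shared inference_exp model."""

    def __init__(self, exp_model):
        self._exp_model = exp_model
        self._session_lock = Lock()  # single mutex guarding the model's shared state

    def forward(self, img_in, **kwargs):
        # Hold the lock across the forward pass so concurrent requests cannot
        # interleave on shared torch/CUDA buffers.
        with self._session_lock:
            return self._exp_model.forward(img_in, **kwargs)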
12 changes: 12 additions & 0 deletions inference/models/florence2/utils.py
@@ -15,10 +15,15 @@ def import_class_from_file(file_path, class_name, alias_name=None):

sys.path.insert(0, parent_dir)

previous_module = sys.modules.get(module_name)
injected = False
try:
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)

sys.modules[module_name] = module
injected = True

# Manually set the __package__ attribute to the parent package
module.__package__ = os.path.basename(module_dir)

@@ -27,5 +27,12 @@ def import_class_from_file(file_path, class_name, alias_name=None):
if alias_name:
globals()[alias_name] = cls
return cls
except Exception:
if injected:
if previous_module is not None:
sys.modules[module_name] = previous_module
else:
sys.modules.pop(module_name, None)
raise
finally:
sys.path.pop(0)
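
The change above introduces a save/inject/restore pattern around sys.modules: remember any previous entry, inject the new module before execution, and roll back if execution fails. A self-contained sketch of the same pattern, with hypothetical names:

import importlib.util
import sys


def load_module_from_file(module_name, file_path):
    """Sketch: load a module from a file, rolling back sys.modules on failure."""
    previous_module = sys.modules.get(module_name)
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # inject before exec so self-imports resolve
    try:
        spec.loader.exec_module(module)
        return module
    except Exception:
        # Restore the previous entry, or drop the half-initialized module.
        if previous_module is not None:
            sys.modules[module_name] = previous_module
        else:
            sys.modules.pop(module_name, None)
        raise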
14 changes: 14 additions & 0 deletions inference/models/rfdetr/rfdetr_exp.py
@@ -0,0 +1,14 @@
from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter


class RFDetrExperimentalModel(InferenceExpObjectDetectionModelAdapter):
"""Adapter for RF-DETR using inference_exp AutoModel backend.

This class wraps an inference_exp AutoModel to present the same interface
as legacy models in the inference server.
"""

def map_inference_kwargs(self, kwargs: dict) -> dict:
return {
"threshold": kwargs.get("confidence"),
}
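
As a small, hedged illustration with placeholder values: when a legacy caller passes confidence, this mapping hands it to the experimental RF-DETR backend as threshold.

# Illustrative only: the legacy kwarg "confidence" becomes "threshold".
legacy_kwargs = {"confidence": 0.4}
exp_kwargs = {"threshold": legacy_kwargs.get("confidence")}
assert exp_kwargs == {"threshold": 0.4}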
28 changes: 28 additions & 0 deletions inference/models/utils.py
@@ -19,6 +19,7 @@
PALIGEMMA_ENABLED,
QWEN_2_5_ENABLED,
SMOLVLM2_ENABLED,
USE_INFERENCE_EXP_MODELS,
)
from inference.core.models.base import Model
from inference.core.models.stubs import (
@@ -550,3 +551,30 @@ def get_model(model_id, api_key=API_KEY, **kwargs) -> Model:

def get_roboflow_model(*args, **kwargs):
return get_model(*args, **kwargs)


# Prefer the inference_exp backend for RF-DETR and YOLOv8 variants when enabled and available
try:
if USE_INFERENCE_EXP_MODELS:
# Ensure experimental package is importable before swapping
__import__("inference_exp")
from inference.models.rfdetr.rfdetr_exp import RFDetrExperimentalModel
from inference.models.yolov8.yolov8_object_detection_exp import (
Yolo8ODExperimentalModel,
)

for task, variant in ROBOFLOW_MODEL_TYPES.keys():
if task == "object-detection" and variant.startswith("rfdetr-"):
ROBOFLOW_MODEL_TYPES[(task, variant)] = RFDetrExperimentalModel

# Iterate over ROBOFLOW_MODEL_TYPES and replace every object-detection entry whose variant starts with yolov8 with the experimental model
for task, variant in ROBOFLOW_MODEL_TYPES.keys():
if task == "object-detection" and variant.startswith("yolov8"):
ROBOFLOW_MODEL_TYPES[(task, variant)] = Yolo8ODExperimentalModel


except Exception:
# Fall back to the legacy ONNX stack when the experimental stack is unavailable
warnings.warn(
"Inference experimental stack is unavailable, falling back to the regular model inference stack"
)
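
A hedged usage sketch of this rollout path. The model ID and API key are placeholders, and the flag must be set before inference.models.utils is imported, because the registry swap runs at import time:

import os

os.environ["USE_INFERENCE_EXP_MODELS"] = "True"  # opt in before the registry import

from inference.models.utils import get_model

# Placeholder model ID: any ("object-detection", "rfdetr-*") or
# ("object-detection", "yolov8*") entry now resolves to the experimental adapter.
model = get_model("my-workspace/my-rfdetr-model/1", api_key="<ROBOFLOW_API_KEY>")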
11 changes: 11 additions & 0 deletions inference/models/yolov8/yolov8_object_detection_exp.py
@@ -0,0 +1,11 @@
from inference.core.models.exp_adapter import InferenceExpObjectDetectionModelAdapter


class Yolo8ODExperimentalModel(InferenceExpObjectDetectionModelAdapter):
def map_inference_kwargs(self, kwargs: dict) -> dict:
return {
"conf_thresh": kwargs.get("confidence"),
"iou_thresh": kwargs.get("iou_threshold"),
"max_detections": kwargs.get("max_detections"),
"class_agnostic": kwargs.get("class_agnostic"),
}
50 changes: 35 additions & 15 deletions inference_experimental/pyproject.toml
@@ -1,9 +1,9 @@
[project]
name = "inference-exp"
version = "0.15.2"
version = "0.16.2"
description = "Experimental version of the inference package, which is supposed to evolve into inference 1.0"
readme = "README.md"
requires-python = ">=3.10,<3.13"
requires-python = ">=3.9,<3.13"
dependencies = [
"numpy",
"torch>=2.0.0,<3.0.0",
@@ -28,7 +28,7 @@ dependencies = [
"filelock>=3.12.0,<4.0.0",
"rich>=14.1.0,<15.0.0",
"segmentation-models-pytorch>=0.5.0,<1.0.0",
"scikit-image~=0.25.0"
"scikit-image>=0.24.0,<0.26.0"
]

[project.optional-dependencies]
@@ -39,28 +39,28 @@ torch-cpu = [
torch-cu118 = [
"torch>=2.0.0,<3.0.0",
"torchvision",
"pycuda>=2025.0.0,<2026.0.0",
"pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'",
]
torch-cu124 = [
"torch>=2.0.0,<3.0.0",
"torchvision",
"pycuda>=2025.0.0,<2026.0.0",
"pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'",
]
torch-cu126 = [
"torch>=2.0.0,<3.0.0",
"torchvision",
"pycuda>=2025.0.0,<2026.0.0",
"pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'",
]
torch-cu128 = [
"torch>=2.0.0,<3.0.0",
"torchvision",
"pycuda>=2025.0.0,<2026.0.0",
"pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'",
]
torch-jp6-cu126 = [
"numpy<2.0.0",
"torch>=2.0.0,<3.0.0",
"torchvision",
"pycuda>=2025.0.0,<2026.0.0",
"numpy<2.0.0; platform_system == 'Linux' and platform_machine == 'aarch64' and python_version >= '3.10'",
"torch>=2.0.0,<3.0.0; platform_system == 'Linux' and platform_machine == 'aarch64' and python_version >= '3.10'",
"torchvision; platform_system == 'Linux' and platform_machine == 'aarch64' and python_version >= '3.10'",
"pycuda>=2025.0.0,<2026.0.0; platform_system == 'Linux' and platform_machine == 'aarch64' and python_version >= '3.10'",
]
onnx-cpu = [
"onnxruntime>=1.15.1,<1.23.0"
@@ -74,9 +74,9 @@ onnx-cu12 = [
"pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin'",
]
onnx-jp6-cu126 = [
"numpy<2.0.0",
"onnxruntime-gpu>=1.17.0,<1.24.0; platform_system != 'darwin'",
"pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin'",
"numpy<2.0.0; platform_system == 'Linux' and platform_machine == 'aarch64' and python_version >= '3.10'",
"onnxruntime-gpu>=1.17.0,<1.24.0; platform_system == 'Linux' and platform_machine == 'aarch64' and python_version >= '3.10'",
"pycuda>=2025.0.0,<2026.0.0; platform_system == 'Linux' and platform_machine == 'aarch64' and python_version >= '3.10'",
]
mediapipe = [
"mediapipe>=0.9,<0.11"
@@ -89,7 +89,7 @@ trt10 = [
"tensorrt-cu12>=10.0.0,<11.0.0; platform_system == 'Linux' or platform_system == 'Windows'",
"tensorrt-lean>=10.0.0,<11.0.0; platform_system == 'Linux' or platform_system == 'Windows'",
"tensorrt-lean-cu12>=10.0.0,<11.0.0; platform_system == 'Linux' or platform_system == 'Windows'",
"pycuda>=2025.0.0,<2026.0.0",
"pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'",
]
test = [
"pytest>=8.0.0",
@@ -179,6 +179,26 @@ name = "tensorrt-cu12-bindings"
[[tool.uv.dependency-metadata]]
name = "tensorrt-lean-cu12-bindings"

[[tool.uv.dependency-metadata]]
name = "tensorrt"
requires-dist = [
"tensorrt-cu12; platform_system == 'Linux' or platform_system == 'Windows'",
"tensorrt-cu13; platform_system == 'Linux' or platform_system == 'Windows'",
]

[[tool.uv.dependency-metadata]]
name = "tensorrt-cu13"

[[tool.uv.dependency-metadata]]
name = "tensorrt-lean"
requires-dist = [
"tensorrt-lean-cu12; platform_system == 'Linux' or platform_system == 'Windows'",
"tensorrt-lean-cu13; platform_system == 'Linux' or platform_system == 'Windows'",
]

[[tool.uv.dependency-metadata]]
name = "tensorrt-lean-cu13"


[tool.uv]
conflicts = [