
Commit 16de460

Merge pull request #246 from amosproj/feat/201-model-align
Feat/201 model align

Signed-off-by: Felix Hilgers <felix.hilgers@fau.de>

2 parents e2d3c87 + b483999

11 files changed: 283 additions & 19 deletions


README.md

Lines changed: 2 additions & 0 deletions
@@ -140,7 +140,9 @@ Optional environment variables:
 - `DEPTH_ANYTHING_MODEL` - Hugging Face model ID for Depth Anything V2 (default `depth-anything/Depth-Anything-V2-Small-hf`)
 - `DEPTH_ANYTHING_CACHE_DIR` - Depth Anything cache directory (default `models/depth_anything_cache`)
 - `MIDAS_ONNX_MODEL_PATH` - defaults to `models/midas_small.onnx`
+- `MIDAS_ONNX_INPUT_SIZE` - input size for MiDaS ONNX preprocessing (default: `384`)
 - `MIDAS_ONNX_PROVIDERS` - comma separated ONNX Runtime providers for depth (falls back to `ONNX_PROVIDERS`)
+- `ONNX_SHARED_PREPROCESSING` - reuse one resize step for ONNX detector + depth when sizes align (default: `true`)
 - `DETECTOR_BACKEND` - `torch` (default) or `onnx`
 - `TORCH_DEVICE` - force PyTorch to use `cuda:0`, `cpu`, etc. (defaults to best available)
 - `TORCH_HALF_PRECISION` - `auto` (default), `true`, or `false`
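As a quick illustration, a `.env` that opts into the shared resize path could look like the snippet below. Values are illustrative, and `DEPTH_BACKEND` is assumed to be the environment variable behind the depth backend selection, which this diff does not show:

```
# Illustrative values; with the new defaults the sizes already match.
DETECTOR_BACKEND=onnx
DEPTH_BACKEND=onnx
DETECTOR_IMAGE_SIZE=384
MIDAS_ONNX_INPUT_SIZE=384
ONNX_SHARED_PREPROCESSING=true
```

Shared preprocessing only kicks in when both backends are ONNX and the two sizes match, as the `_should_share_preprocess` check in `manager.py` below shows.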

scripts/download_models.py

Lines changed: 2 additions & 0 deletions
@@ -163,6 +163,7 @@ def main() -> None:
         yolo_path=yolo_final_path,
         output_path=yolo_onnx_target,
         opset=args.onnx_opset,
+        imgsz=config.DETECTOR_IMAGE_SIZE,
         simplify=args.onnx_simplify,
         half=config.ONNX_HALF_PRECISION,
     )
@@ -198,6 +199,7 @@ def main() -> None:
         model_type=args.midas_type,
         model_repo=args.midas_repo,
         opset=args.onnx_opset,
+        input_size=config.MIDAS_ONNX_INPUT_SIZE,
         half=config.ONNX_HALF_PRECISION,
     )
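Threading `imgsz` and `input_size` into the export helpers pins the exported ONNX graphs to the same resolution the runtime preprocessing uses. The script's own exporter functions are not shown in this diff; as a rough sketch of the YOLO side, an Ultralytics export with a fixed input size looks like this (checkpoint path and opset are illustrative placeholders):

```python
# Sketch of a fixed-size YOLO ONNX export via Ultralytics, roughly what the
# script's export helper is expected to do under the hood.
from ultralytics import YOLO

model = YOLO("models/yolo.pt")  # hypothetical checkpoint path
model.export(
    format="onnx",
    imgsz=384,      # pin the graph to DETECTOR_IMAGE_SIZE so runtime inputs match
    opset=17,       # illustrative; the script takes this from --onnx-opset
    simplify=True,
    half=False,
)
```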

src/backend/analyzer/manager.py

Lines changed: 53 additions & 9 deletions
@@ -36,6 +36,7 @@
 )
 from common.utils.transforms import (
     calculate_adaptive_scale,
+    letterbox,
     resize_frame,
 )

@@ -203,6 +204,19 @@ def _get_compute_intrinsics(
         )
         return self._intrinsics_cache[cache_key]

+    def _should_share_preprocess(
+        self, detector: ObjectDetector, estimator: DepthEstimator
+    ) -> bool:
+        """Compute whether shared ONNX preprocessing can be used."""
+        return (
+            config.ONNX_SHARED_PREPROCESSING
+            and config.DETECTOR_BACKEND == "onnx"
+            and config.DEPTH_BACKEND == "onnx"
+            and config.DETECTOR_IMAGE_SIZE == config.MIDAS_ONNX_INPUT_SIZE
+            and hasattr(detector, "infer_preprocessed")
+            and hasattr(estimator, "estimate_distance_m_preprocessed")
+        )
+
     async def shutdown(self) -> None:
         """Cleanup on service shutdown."""
         await self._stop_processing()
@@ -254,6 +268,7 @@ async def _process_frames(self, source_track: MediaStreamTrack) -> None:
         """Process frames from webcam and send metadata to all clients."""
         detector = get_detector()
         estimator = get_depth_estimator()
+        shared_preprocess = self._should_share_preprocess(detector, estimator)

         state = ProcessingState(
             target_scale=self.target_scale_init, source_track=source_track
@@ -301,7 +316,12 @@ async def frame_receiver() -> None:

             self._inference_task = asyncio.create_task(
                 self._run_inference_pipeline(
-                    frame_small, state, detector, estimator, current_time
+                    frame_small,
+                    state,
+                    detector,
+                    estimator,
+                    current_time,
+                    shared_preprocess,
                 )
             )

@@ -425,12 +445,25 @@ async def _process_detection(
         state: ProcessingState,
         detector: ObjectDetector,
         estimator: DepthEstimator,
+        shared_preprocess: bool,
     ) -> tuple[list[Detection], list[float], list[bool]]:
-        with self._measure_time(
-            self._detection_duration, labels={"backend": config.DETECTOR_BACKEND}
-        ):
-            # YOLO detection (async)
-            raw_detections = await detector.infer(frame_small)
+        if shared_preprocess:
+            resized, ratio, dwdh = letterbox(frame_small, config.DETECTOR_IMAGE_SIZE)
+            with self._measure_time(
+                self._detection_duration, labels={"backend": config.DETECTOR_BACKEND}
+            ):
+                raw_detections = await detector.infer_preprocessed(
+                    resized,
+                    ratio,
+                    dwdh,
+                    (frame_small.shape[0], frame_small.shape[1]),
+                )
+        else:
+            with self._measure_time(
+                self._detection_duration, labels={"backend": config.DETECTOR_BACKEND}
+            ):
+                # YOLO detection (async)
+                raw_detections = await detector.infer(frame_small)

         if not raw_detections:
             return [], [], []
@@ -440,9 +473,18 @@ async def _process_detection(
             labels={"model_type": estimator.model_type},
         ):
             # Distance estimation (sync) -> run in executor
-            raw_distances = await asyncio.get_running_loop().run_in_executor(
-                None, estimator.estimate_distance_m, frame_small, raw_detections
-            )
+            if shared_preprocess:
+                raw_distances = await asyncio.get_running_loop().run_in_executor(
+                    None,
+                    estimator.estimate_distance_m_preprocessed,
+                    resized,
+                    raw_detections,
+                    (frame_small.shape[0], frame_small.shape[1]),
+                )
+            else:
+                raw_distances = await asyncio.get_running_loop().run_in_executor(
+                    None, estimator.estimate_distance_m, frame_small, raw_detections
+                )

         # Tracking logic
         updated_track_ids, track_assignments = (
@@ -493,6 +535,7 @@ async def _run_inference_pipeline(
         detector: ObjectDetector,
         estimator: DepthEstimator,
         current_time: float,
+        shared_preprocess: bool,
     ) -> None:
         """Run ML inference detection and tracking pipeline in background."""

@@ -517,6 +560,7 @@ async def _run_inference_pipeline(
             state=state,
             detector=detector,
             estimator=estimator,
+            shared_preprocess=shared_preprocess,
         )

         if all_detections:
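With shared preprocessing on, the frame is letterboxed once and the same buffer feeds both the detector and the depth estimator. The repo's `letterbox` helper in `common.utils.transforms` is not shown in this diff; a minimal sketch that matches the `(resized, ratio, dwdh)` contract used above might look like this (padding color and rounding are assumptions):

```python
import cv2
import numpy as np


def letterbox(
    frame: np.ndarray, size: int
) -> tuple[np.ndarray, float, tuple[float, float]]:
    """YOLO-style letterbox: scale to fit, pad the remainder.

    Returns the padded image, the scale ratio, and the (dw, dh) padding
    offsets needed to map boxes back onto the original frame.
    """
    h, w = frame.shape[:2]
    ratio = min(size / h, size / w)
    new_w, new_h = round(w * ratio), round(h * ratio)
    dw, dh = (size - new_w) / 2, (size - new_h) / 2  # padding split over both sides
    resized = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    top, left = int(dh), int(dw)
    bottom, right = size - new_h - top, size - new_w - left
    padded = cv2.copyMakeBorder(
        resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
    )
    return padded, ratio, (dw, dh)
```

Boxes predicted on the padded image would map back via `(x - dw) / ratio`, which is presumably what `infer_preprocessed` uses `ratio` and `dwdh` for.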

src/backend/common/config.py

Lines changed: 5 additions & 1 deletion
@@ -61,6 +61,7 @@ class Config:
     MIDAS_ONNX_MODEL_PATH: Path = Path(
         os.getenv("MIDAS_ONNX_MODEL_PATH", "models/midas_small.onnx")
     ).resolve()
+    MIDAS_ONNX_INPUT_SIZE: int = int(os.getenv("MIDAS_ONNX_INPUT_SIZE", "384"))
     MIDAS_ONNX_PROVIDERS: list[str] = [
         provider.strip()
         for provider in os.getenv("MIDAS_ONNX_PROVIDERS", "").split(",")
@@ -103,7 +104,7 @@ class Config:
         os.getenv("ONNX_MODEL_PATH", str(MODEL_PATH.with_suffix(".onnx")))
     ).resolve()
     DETECTOR_BACKEND: str = os.getenv("DETECTOR_BACKEND", "torch").lower()
-    DETECTOR_IMAGE_SIZE: int = int(os.getenv("DETECTOR_IMAGE_SIZE", "640"))
+    DETECTOR_IMAGE_SIZE: int = int(os.getenv("DETECTOR_IMAGE_SIZE", "384"))
     DETECTOR_CONF_THRESHOLD: float = float(os.getenv("DETECTOR_CONF_THRESHOLD", "0.25"))
     DETECTOR_IOU_THRESHOLD: float = float(os.getenv("DETECTOR_IOU_THRESHOLD", "0.7"))
     DETECTOR_MAX_DETECTIONS: int = int(os.getenv("DETECTOR_MAX_DETECTIONS", "100"))
@@ -120,6 +121,9 @@ class Config:
         for provider in os.getenv("ONNX_PROVIDERS", "").split(",")
         if provider.strip()
     ]
+    ONNX_SHARED_PREPROCESSING: bool = os.getenv(
+        "ONNX_SHARED_PREPROCESSING", "true"
+    ).lower() in ("1", "true", "yes")
     ONNX_IO_BINDING: bool = os.getenv("ONNX_IO_BINDING", "false").lower() in (
         "1",
         "true",

src/backend/common/core/depth.py

Lines changed: 115 additions & 2 deletions
@@ -8,11 +8,9 @@
 import numpy as np
 import torch
 from PIL import Image
-
 from common.config import config
 from common.typing import Detection
 from common.protocols import DepthEstimator
-
 from common.utils.depth import calculate_distances, resize_to_frame


@@ -34,6 +32,58 @@
     AutoModelForDepthEstimation = None  # type: ignore


+def _build_midas_small_transform(
+    midas_transforms: object,
+    input_size: int,
+) -> Callable[[np.ndarray], torch.Tensor]:
+    """Create a MiDaS-small transform with a custom input size."""
+    import cv2
+    from torchvision.transforms import Compose  # type: ignore[import-untyped]
+
+    resize = getattr(midas_transforms, "Resize")
+    normalize = getattr(midas_transforms, "NormalizeImage")
+    prepare = getattr(midas_transforms, "PrepareForNet")
+
+    return Compose(
+        [
+            lambda img: {"image": img / 255.0},
+            resize(
+                input_size,
+                input_size,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method="upper_bound",
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            prepare(),
+            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+        ]
+    )
+
+
+def _build_midas_no_resize_transform(
+    midas_transforms: object,
+    mean: list[float],
+    std: list[float],
+) -> Callable[[np.ndarray], torch.Tensor]:
+    """Create a MiDaS transform for inputs that are already resized,
+    skipping the resize step inside MiDaS."""
+    from torchvision.transforms import Compose  # type: ignore[import-untyped]
+
+    normalize = getattr(midas_transforms, "NormalizeImage")
+    prepare = getattr(midas_transforms, "PrepareForNet")
+
+    return Compose(
+        [
+            lambda img: {"image": img / 255.0},
+            normalize(mean=mean, std=std),
+            prepare(),
+            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+        ]
+    )
+
+
 # Factories let us swap depth estimation backends without changing call sites.
 DepthEstimatorFactory = Callable[[Optional[Path]], DepthEstimator]

@@ -227,13 +277,47 @@ def __init__(
         )
         self._input_name = self._session.get_inputs()[0].name
         self._output_name = self._session.get_outputs()[0].name
+        self._no_resize_transform: Optional[Callable[[np.ndarray], torch.Tensor]] = None

         super().__init__(
             midas_cache_directory=midas_cache_directory,
             model_type=model_type,
             midas_model=midas_model,
         )

+    def _load_transform(self) -> Callable[[np.ndarray], torch.Tensor]:
+        """Load MiDaS transform, aligned to the ONNX input size when needed."""
+        torch.hub.set_dir(str(self.midas_cache_directory))
+        midas_transforms = torch.hub.load(
+            self.midas_model, "transforms", trust_repo=True
+        )
+        if self.model_type in {"DPT_Large", "DPT_Hybrid"}:
+            if config.ONNX_SHARED_PREPROCESSING and all(
+                hasattr(midas_transforms, attr)
+                for attr in ("NormalizeImage", "PrepareForNet")
+            ):
+                self._no_resize_transform = _build_midas_no_resize_transform(
+                    midas_transforms, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
+                )
+            return midas_transforms.dpt_transform
+        if self.model_type == "MiDaS_small":
+            if config.ONNX_SHARED_PREPROCESSING and all(
+                hasattr(midas_transforms, attr)
+                for attr in ("NormalizeImage", "PrepareForNet")
+            ):
+                self._no_resize_transform = _build_midas_no_resize_transform(
+                    midas_transforms,
+                    mean=[0.485, 0.456, 0.406],
+                    std=[0.229, 0.224, 0.225],
+                )
+            if config.MIDAS_ONNX_INPUT_SIZE != 256 and hasattr(
+                midas_transforms, "Resize"
+            ):
+                return _build_midas_small_transform(
+                    midas_transforms, config.MIDAS_ONNX_INPUT_SIZE
+                )
+            return midas_transforms.small_transform
+
     def _resolve_providers(self) -> list[str]:
         configured = config.MIDAS_ONNX_PROVIDERS or config.ONNX_PROVIDERS
         if configured:
@@ -255,6 +339,35 @@ def _predict_depth_map(
         self, frame_rgb: np.ndarray, output_shape: tuple[int, int]
     ) -> np.ndarray:
         input_batch = self.transform(frame_rgb)
+        return self._run_onnx_inference(input_batch, output_shape)
+
+    def estimate_distance_m_preprocessed(
+        self,
+        resized_rgb: np.ndarray,
+        dets: list[Detection],
+        output_shape: tuple[int, int],
+    ) -> list[float]:
+        """Estimate distances using a pre-resized ONNX input."""
+        self.update_id += 1
+        if self.update_id % self.update_freq != 0 and len(self.last_depths) == len(
+            dets
+        ):
+            return self.last_depths
+        depth_map = self._predict_depth_map_preprocessed(resized_rgb, output_shape)
+        distances = self._distances_from_depth_map(depth_map, dets)
+        self.last_depths = distances
+        return distances
+
+    def _predict_depth_map_preprocessed(
+        self, resized_rgb: np.ndarray, output_shape: tuple[int, int]
+    ) -> np.ndarray:
+        transform = self._no_resize_transform or self.transform
+        input_batch = transform(resized_rgb)
+        return self._run_onnx_inference(input_batch, output_shape)
+
+    def _run_onnx_inference(
+        self, input_batch: torch.Tensor, output_shape: tuple[int, int]
+    ) -> np.ndarray:
         _, _, h, w = input_batch.shape
         size = max(w, h)
         input_batch = torch.nn.functional.pad(input_batch, (0, size - w, 0, size - h))
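`_run_onnx_inference` pads the batch to a square before the session runs. `torch.nn.functional.pad` takes pads for the last dimension first, so `(0, size - w, 0, size - h)` adds columns on the right and rows on the bottom. A small self-contained check (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# F.pad pads the last dimension first: (0, size - w) is 0 left / (size - w)
# right on width, then (0, size - h) is 0 top / (size - h) bottom on height.
batch = torch.randn(1, 3, 384, 512)  # NCHW; h=384, w=512
_, _, h, w = batch.shape
size = max(w, h)
square = F.pad(batch, (0, size - w, 0, size - h))
assert square.shape == (1, 3, 512, 512)
```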
