88import numpy as np
99import torch
1010from PIL import Image
11-
1211from common .config import config
1312from common .typing import Detection
1413from common .protocols import DepthEstimator
15-
1614from common .utils .depth import calculate_distances , resize_to_frame
1715
1816
3432 AutoModelForDepthEstimation = None # type: ignore
3533
3634
def _build_midas_small_transform(
    midas_transforms: object,
    input_size: int,
) -> Callable[[np.ndarray], torch.Tensor]:
    """Build a MiDaS-small preprocessing pipeline for a custom input size.

    Mirrors the stock ``small_transform`` from the MiDaS hub module but lets
    the caller pick the target resolution instead of the default.
    """
    import cv2
    from torchvision.transforms import Compose  # type: ignore[import-untyped]

    resize_cls = getattr(midas_transforms, "Resize")
    normalize_cls = getattr(midas_transforms, "NormalizeImage")
    prepare_cls = getattr(midas_transforms, "PrepareForNet")

    steps = [
        # Scale pixels from [0, 255] to [0, 1] and wrap in the dict format
        # the MiDaS transform classes operate on.
        lambda img: {"image": img / 255.0},
        resize_cls(
            input_size,
            input_size,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method="upper_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        # ImageNet normalization constants, as used by MiDaS-small.
        normalize_cls(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        prepare_cls(),
        # Convert the prepared numpy image to a batched torch tensor.
        lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
    ]
    return Compose(steps)
65+
def _build_midas_no_resize_transform(
    midas_transforms: object,
    mean: list[float],
    std: list[float],
) -> Callable[[np.ndarray], torch.Tensor]:
    """Build a MiDaS preprocessing pipeline without the resize step.

    Intended for inputs that were already resized upstream, so the
    MiDaS-internal resize can be skipped entirely.
    """
    from torchvision.transforms import Compose  # type: ignore[import-untyped]

    normalize_cls = getattr(midas_transforms, "NormalizeImage")
    prepare_cls = getattr(midas_transforms, "PrepareForNet")

    pipeline = [
        # [0, 255] -> [0, 1], wrapped in the dict format MiDaS expects.
        lambda img: {"image": img / 255.0},
        normalize_cls(mean=mean, std=std),
        prepare_cls(),
        # Prepared numpy image -> batched torch tensor.
        lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
    ]
    return Compose(pipeline)
85+
86+
# Factories let us swap depth estimation backends without changing call sites.
# Signature: callable taking an optional filesystem path (presumably a model
# cache/weights directory — confirm against the backend constructors) and
# returning a ready-to-use DepthEstimator.
DepthEstimatorFactory = Callable [[Optional [Path ]], DepthEstimator ]
3989
@@ -227,13 +277,47 @@ def __init__(
227277 )
228278 self ._input_name = self ._session .get_inputs ()[0 ].name
229279 self ._output_name = self ._session .get_outputs ()[0 ].name
280+ self ._no_resize_transform : Optional [Callable [[np .ndarray ], torch .Tensor ]] = None
230281
231282 super ().__init__ (
232283 midas_cache_directory = midas_cache_directory ,
233284 model_type = model_type ,
234285 midas_model = midas_model ,
235286 )
236287
288+ def _load_transform (self ) -> Callable [[np .ndarray ], torch .Tensor ]:
289+ """Load MiDaS transform, aligned to the ONNX input size when needed."""
290+ torch .hub .set_dir (str (self .midas_cache_directory ))
291+ midas_transforms = torch .hub .load (
292+ self .midas_model , "transforms" , trust_repo = True
293+ )
294+ if self .model_type in {"DPT_Large" , "DPT_Hybrid" }:
295+ if config .ONNX_SHARED_PREPROCESSING and all (
296+ hasattr (midas_transforms , attr )
297+ for attr in ("NormalizeImage" , "PrepareForNet" )
298+ ):
299+ self ._no_resize_transform = _build_midas_no_resize_transform (
300+ midas_transforms , mean = [0.5 , 0.5 , 0.5 ], std = [0.5 , 0.5 , 0.5 ]
301+ )
302+ return midas_transforms .dpt_transform
303+ if self .model_type == "MiDaS_small" :
304+ if config .ONNX_SHARED_PREPROCESSING and all (
305+ hasattr (midas_transforms , attr )
306+ for attr in ("NormalizeImage" , "PrepareForNet" )
307+ ):
308+ self ._no_resize_transform = _build_midas_no_resize_transform (
309+ midas_transforms ,
310+ mean = [0.485 , 0.456 , 0.406 ],
311+ std = [0.229 , 0.224 , 0.225 ],
312+ )
313+ if config .MIDAS_ONNX_INPUT_SIZE != 256 and hasattr (
314+ midas_transforms , "Resize"
315+ ):
316+ return _build_midas_small_transform (
317+ midas_transforms , config .MIDAS_ONNX_INPUT_SIZE
318+ )
319+ return midas_transforms .small_transform
320+
237321 def _resolve_providers (self ) -> list [str ]:
238322 configured = config .MIDAS_ONNX_PROVIDERS or config .ONNX_PROVIDERS
239323 if configured :
@@ -255,6 +339,35 @@ def _predict_depth_map(
255339 self , frame_rgb : np .ndarray , output_shape : tuple [int , int ]
256340 ) -> np .ndarray :
257341 input_batch = self .transform (frame_rgb )
342+ return self ._run_onnx_inference (input_batch , output_shape )
343+
344+ def estimate_distance_m_preprocessed (
345+ self ,
346+ resized_rgb : np .ndarray ,
347+ dets : list [Detection ],
348+ output_shape : tuple [int , int ],
349+ ) -> list [float ]:
350+ """Estimate distances using a pre-resized ONNX input."""
351+ self .update_id += 1
352+ if self .update_id % self .update_freq != 0 and len (self .last_depths ) == len (
353+ dets
354+ ):
355+ return self .last_depths
356+ depth_map = self ._predict_depth_map_preprocessed (resized_rgb , output_shape )
357+ distances = self ._distances_from_depth_map (depth_map , dets )
358+ self .last_depths = distances
359+ return distances
360+
361+ def _predict_depth_map_preprocessed (
362+ self , resized_rgb : np .ndarray , output_shape : tuple [int , int ]
363+ ) -> np .ndarray :
364+ transform = self ._no_resize_transform or self .transform
365+ input_batch = transform (resized_rgb )
366+ return self ._run_onnx_inference (input_batch , output_shape )
367+
368+ def _run_onnx_inference (
369+ self , input_batch : torch .Tensor , output_shape : tuple [int , int ]
370+ ) -> np .ndarray :
258371 _ , _ , h , w = input_batch .shape
259372 size = max (w , h )
260373 input_batch = torch .nn .functional .pad (input_batch , (0 , size - w , 0 , size - h ))
0 commit comments