Skip to content

Commit 3d56f17

Browse files
authored
Merge pull request #135 from amosproj/93-reduce-inference-time
93: reduce inference time.

Signed-off-by: Felix Hilgers <felix.hilgers@fau.de>
2 parents 6a4c28a + 8a2ab0a commit 3d56f17

3 files changed

Lines changed: 86 additions & 15 deletions

File tree

src/backend/analyzer/routes.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MIT
44
import asyncio
55
import json
6+
import cv2
67
import logging
78

89
import numpy as np
@@ -37,6 +38,15 @@ def __init__(self) -> None:
3738
self._processing_task: asyncio.Task[None] | None = None
3839
self._intrinsics_logged: bool = False
3940

41+
self.max_consecutive_errors = 5
42+
# adaptive downscaling parameters
43+
self.target_scale_init = config.TARGET_SCALE_INIT
44+
self.smooth_factor = config.SMOOTH_FACTOR
45+
self.min_scale = config.MIN_SCALE
46+
self.max_scale = config.MAX_SCALE
47+
# adaptive frame dropping parameters
48+
self.fps_threshold = config.FPS_THRESHOLD
49+
4050
async def connect(self, websocket: WebSocket) -> None:
4151
"""Accept a new WebSocket connection."""
4252
await websocket.accept()
@@ -107,9 +117,10 @@ async def _process_frames(self, source_track: MediaStreamTrack) -> None:
107117
frame_id = 0
108118
last_fps_time = asyncio.get_event_loop().time()
109119
fps_counter = 0
110-
current_fps = 0.0
120+
current_fps = 30.0
111121
consecutive_errors = 0
112-
max_consecutive_errors = 5
122+
123+
target_scale = self.target_scale_init
113124

114125
try:
115126
while self.active_connections:
@@ -121,7 +132,8 @@ async def _process_frames(self, source_track: MediaStreamTrack) -> None:
121132
except asyncio.TimeoutError:
122133
logging.warning("Frame receive timeout, skipping...")
123134
consecutive_errors += 1
124-
if consecutive_errors >= max_consecutive_errors:
135+
136+
if consecutive_errors >= self.max_consecutive_errors:
125137
logging.error(
126138
"Too many consecutive timeouts, reconnecting..."
127139
)
@@ -133,7 +145,7 @@ async def _process_frames(self, source_track: MediaStreamTrack) -> None:
133145
)
134146
consecutive_errors += 1
135147

136-
if consecutive_errors >= max_consecutive_errors:
148+
if consecutive_errors >= self.max_consecutive_errors:
137149
# full reconnect
138150
if self._webcam_session is not None:
139151
await self._webcam_session.close()
@@ -170,21 +182,45 @@ async def _process_frames(self, source_track: MediaStreamTrack) -> None:
170182
fps_counter = 0
171183
last_fps_time = current_time
172184

185+
if current_fps < 10:
186+
target_scale -= self.smooth_factor
187+
elif current_fps < 18:
188+
target_scale -= self.smooth_factor * 0.5
189+
else:
190+
target_scale += self.smooth_factor * 0.8
191+
192+
target_scale = max(
193+
self.min_scale, min(self.max_scale, target_scale)
194+
)
195+
print(
196+
f"[Adaptive Res] Scale={target_scale:.2f} | FPS={current_fps:.1f}"
197+
)
198+
199+
# Resize frame for processing
200+
if target_scale < 0.98:
201+
new_w = int(frame_array.shape[1] * target_scale)
202+
new_h = int(frame_array.shape[0] * target_scale)
203+
frame_small = cv2.resize(frame_array, (new_w, new_h))
204+
else:
205+
frame_small = frame_array
206+
207+
sample_rate = 2 if current_fps < self.fps_threshold else 4
208+
173209
# Run ML inference every 3rd frame and collect detections
174-
if not self.active_connections or frame_id % 3 != 0:
210+
if not self.active_connections or frame_id % sample_rate != 0:
175211
continue
176212

177213
# YOLO detection
178-
detections = await detector.infer(frame_array)
214+
detections = await detector.infer(frame_small)
179215

180216
if not detections:
181217
continue
182218

183219
# Distance estimation
184-
distances = estimator.estimate_distance_m(frame_array, detections)
220+
distances = estimator.estimate_distance_m(frame_small, detections)
185221

186222
metadata = self._build_metadata_message(
187-
frame_rgb=frame_array,
223+
frame_rgb=frame_small,
188224
detections=detections,
189225
distances=distances,
190226
timestamp=current_time,
@@ -196,7 +232,7 @@ async def _process_frames(self, source_track: MediaStreamTrack) -> None:
196232
await self._send_metadata(metadata)
197233

198234
# Small delay to prevent overwhelming
199-
await asyncio.sleep(0.033) # ~30 FPS processing
235+
# await asyncio.sleep(0.033) # ~30 FPS processing
200236

201237
except Exception as e:
202238
logging.warning(f"Frame processing error: {e}")

src/backend/common/config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,31 @@ class Config:
1212
# Camera settings
1313
CAMERA_INDEX: int = int(os.getenv("CAMERA_INDEX", "0"))
1414

15+
REGION_SIZE = int(
16+
os.getenv("REGION_SIZE", "5")
17+
) # size of square region for depth median
18+
SCALE_FACTOR = float(
19+
os.getenv("SCALE_FACTOR", "432.0")
20+
) # empirical calibration factor
21+
UPDATE_FREQ = int(
22+
os.getenv("UPDATE_FREQ", "2")
23+
) # number of frames between depth updates
24+
25+
# adaptive downsampling settings
26+
TARGET_SCALE_INIT: float = float(
27+
os.getenv("TARGET_SCALE_INIT", "0.8")
28+
) # initial downscale factor for images
29+
SMOOTH_FACTOR: float = float(
30+
os.getenv("SMOOTH_FACTOR", "0.15")
31+
) # smoothing factor for scale updates
32+
MIN_SCALE: float = float(os.getenv("MIN_SCALE", "0.2")) # minimum allowed scale
33+
MAX_SCALE: float = float(os.getenv("MAX_SCALE", "1.0")) # maximum allowed scale
34+
35+
# adaptive frame dropping
36+
FPS_THRESHOLD: float = float(
37+
os.getenv("FPS_THRESHOLD", "15.0")
38+
) # threshold FPS for skipping more frames
39+
1540
# Depth estimation settings
1641
REGION_SIZE = int(os.getenv("REGION_SIZE", "5"))
1742
SCALE_FACTOR = float(os.getenv("SCALE_FACTOR", "432.0"))

src/backend/common/core/depth.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Callable, Literal, Optional
1010

1111
from common.config import config
12-
from common.core.contracts import DepthEstimator, Detection
12+
from common.core.contracts import DepthEstimator
1313

1414
# Factories let us swap depth estimation backends without changing call sites.
1515
DepthEstimatorFactory = Callable[[Optional[Path]], DepthEstimator]
@@ -49,10 +49,14 @@ def __init__(
4949
midas_cache_directory: Custom directory for PyTorch Hub cache.
5050
If None, uses PyTorch's default cache location.
5151
"""
52-
53-
self.region_size = config.REGION_SIZE # size of region around bbox center
52+
self.region_size = (
53+
config.REGION_SIZE
54+
) # size of region around bbox center to sample depth
5455
self.scale_factor = config.SCALE_FACTOR # empirical calibration factor
56+
self.update_freq = config.UPDATE_FREQ # frames between depth updates
5557

58+
self.update_id = -1
59+
self.last_depths: list[float] = []
5660
self.model_type = model_type
5761
self.midas_model = midas_model
5862
self.device = (
@@ -67,19 +71,24 @@ def __init__(
6771
.to(self.device)
6872
.eval()
6973
)
70-
# MiDaS transforms
74+
# get MiDaS transforms
7175
midas_transforms = torch.hub.load(midas_model, "transforms", trust_repo=True)
7276
if model_type in {"DPT_Large", "DPT_Hybrid"}:
7377
self.transform = midas_transforms.dpt_transform
7478
else:
7579
self.transform = midas_transforms.small_transform
7680

7781
def estimate_distance_m(
78-
self, frame_rgb: np.ndarray, detections: list[Detection]
82+
self, frame_rgb: np.ndarray, dets: list[tuple[int, int, int, int, int, float]]
7983
) -> list[float]:
8084
"""Estimate distance in meters for each detection based on depth map.
8185
8286
Returns list of distances in meters."""
87+
self.update_id += 1
88+
if self.update_id % self.update_freq != 0 and len(self.last_depths) == len(
89+
dets
90+
):
91+
return self.last_depths
8392
h, w, _ = frame_rgb.shape
8493

8594
input_batch = self.transform(frame_rgb).to(self.device)
@@ -93,7 +102,7 @@ def estimate_distance_m(
93102
).squeeze()
94103
depth_map = prediction.cpu().numpy()
95104
distances = []
96-
for x1, y1, x2, y2, _cls_id, _conf in detections:
105+
for x1, y1, x2, y2, _cls_id, _conf in dets:
97106
# extract 5x5 central region of bbox and clip to image bounds
98107
cx = int((x1 + x2) / 2)
99108
cy = int((y1 + y2) / 2)
@@ -109,6 +118,7 @@ def estimate_distance_m(
109118
depth_value = max(np.mean(region), 1e-6) # avoid div by zero
110119

111120
distances.append(float(self.scale_factor / depth_value))
121+
self.last_depths = distances
112122
return distances
113123

114124

0 commit comments

Comments (0)