Skip to content

Commit ba1f44f

Browse files
committed
feat(utils): add prefetch to get_video_frames_generator
Adds an opt-in prefetch: int = 0 parameter. When > 0, frames are decoded in a background thread and buffered in a bounded queue, letting a CPU-bound consumer overlap with decode I/O. Default 0 keeps the original synchronous behaviour unchanged. The threaded path drives the existing sync generator on a daemon thread and pumps frames through a Queue(maxsize=prefetch). No new dependencies. Closes #1411.
1 parent fb2dec9 commit ba1f44f

2 files changed

Lines changed: 103 additions & 0 deletions

File tree

src/supervision/utils/video.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def get_video_frames_generator(
234234
start: int = 0,
235235
end: int | None = None,
236236
iterative_seek: bool = False,
237+
prefetch: int = 0,
237238
) -> Generator[npt.NDArray[np.uint8], None, None]:
238239
"""
239240
Get a generator that yields the frames of the video.
@@ -249,6 +250,10 @@ def get_video_frames_generator(
249250
iterative_seek: If True, the generator will seek to the
250251
`start` frame by grabbing each frame, which is much slower. This is a
251252
workaround for videos that don't open at all when you set the `start` value.
253+
prefetch: If > 0, decode frames in a background thread and buffer up to
254+
this many frames in a bounded queue. Useful when the consumer (e.g.
255+
CPU inference) is the bottleneck and can overlap with decode I/O.
256+
Default 0 keeps the original synchronous behaviour unchanged.
252257
253258
Returns:
254259
A generator that yields the
@@ -262,6 +267,17 @@ def get_video_frames_generator(
262267
...
263268
```
264269
"""
270+
if prefetch > 0:
271+
yield from _prefetched_frames_generator(
272+
source_path=source_path,
273+
stride=stride,
274+
start=start,
275+
end=end,
276+
iterative_seek=iterative_seek,
277+
prefetch=prefetch,
278+
)
279+
return
280+
265281
video, start, end = _validate_and_setup_video(
266282
source_path, start, end, iterative_seek
267283
)
@@ -280,6 +296,65 @@ def get_video_frames_generator(
280296
video.release()
281297

282298

299+
def _prefetched_frames_generator(
300+
source_path: str,
301+
stride: int,
302+
start: int,
303+
end: int | None,
304+
iterative_seek: bool,
305+
prefetch: int,
306+
) -> Generator[npt.NDArray[np.uint8], None, None]:
307+
frame_queue: Queue[npt.NDArray[np.uint8] | Exception | None] = Queue(
308+
maxsize=prefetch
309+
)
310+
stop_event = threading.Event()
311+
312+
def reader() -> None:
313+
sentinel: Exception | None = None
314+
try:
315+
for frame in get_video_frames_generator(
316+
source_path=source_path,
317+
stride=stride,
318+
start=start,
319+
end=end,
320+
iterative_seek=iterative_seek,
321+
prefetch=0,
322+
):
323+
if stop_event.is_set():
324+
return
325+
while True:
326+
try:
327+
frame_queue.put(frame, timeout=0.1)
328+
break
329+
except Full:
330+
if stop_event.is_set():
331+
return
332+
except Exception as exc:
333+
sentinel = exc
334+
# Push the terminating sentinel (None for normal end, exception for error),
335+
# respecting stop_event so we never block after the consumer has stopped.
336+
while True:
337+
try:
338+
frame_queue.put(sentinel, timeout=0.1)
339+
return
340+
except Full:
341+
if stop_event.is_set():
342+
return
343+
344+
thread = threading.Thread(target=reader, daemon=True)
345+
thread.start()
346+
try:
347+
while True:
348+
item = frame_queue.get()
349+
if isinstance(item, Exception):
350+
raise item
351+
if item is None:
352+
break
353+
yield item
354+
finally:
355+
stop_event.set()
356+
357+
283358
def process_video(
284359
source_path: str,
285360
target_path: str,

tests/utils/test_video.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,34 @@ def test_get_video_frames_generator(dummy_video_path):
200200
assert all(frame.shape == (480, 640, 3) for frame in frames)
201201

202202

203+
def test_get_video_frames_generator_prefetch_matches_sync(dummy_video_path):
204+
"""prefetch>0 must yield the same frames in the same order as the sync path."""
205+
sync_frames = list(get_video_frames_generator(dummy_video_path))
206+
prefetched_frames = list(get_video_frames_generator(dummy_video_path, prefetch=4))
207+
assert len(prefetched_frames) == len(sync_frames) == 10
208+
for a, b in zip(prefetched_frames, sync_frames):
209+
assert np.array_equal(a, b)
210+
211+
212+
def test_get_video_frames_generator_prefetch_propagates_decode_errors(tmp_path):
213+
"""Errors raised by the reader thread must reach the consumer, not get swallowed."""
214+
missing_path = str(tmp_path / "does_not_exist.mp4")
215+
with pytest.raises(Exception, match="Could not open video"):
216+
list(get_video_frames_generator(missing_path, prefetch=4))
217+
218+
219+
def test_get_video_frames_generator_prefetch_early_termination(dummy_video_path):
220+
"""Breaking out of the prefetched generator must not block subsequent iteration."""
221+
taken = []
222+
for frame in get_video_frames_generator(dummy_video_path, prefetch=4):
223+
taken.append(frame)
224+
if len(taken) >= 3:
225+
break
226+
assert len(taken) == 3
227+
# A fresh generator on the same file must still work normally.
228+
assert len(list(get_video_frames_generator(dummy_video_path, prefetch=4))) == 10
229+
230+
203231
def test_get_video_frames_generator_with_stride(dummy_video_path):
204232
"""
205233
Verify that get_video_frames_generator correctly handles the stride parameter.

0 commit comments

Comments
 (0)