Skip to content

Commit 1e0e297

Browse files
committed
feat(utils): add prefetch to get_video_frames_generator
Adds an opt-in prefetch: int = 0 parameter. When > 0, frames are decoded in a background thread and buffered in a bounded queue, letting a CPU-bound consumer overlap with decode I/O. Default 0 keeps the original synchronous behaviour unchanged. The threaded path drives the existing sync generator on a daemon thread and pumps frames through a Queue(maxsize=prefetch). No new dependencies. Closes #1411.
1 parent 80b0181 commit 1e0e297

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

src/supervision/utils/video.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def get_video_frames_generator(
234234
start: int = 0,
235235
end: int | None = None,
236236
iterative_seek: bool = False,
237+
prefetch: int = 0,
237238
) -> Generator[npt.NDArray[np.uint8], None, None]:
238239
"""
239240
Get a generator that yields the frames of the video.
@@ -249,6 +250,10 @@ def get_video_frames_generator(
249250
iterative_seek: If True, the generator will seek to the
250251
`start` frame by grabbing each frame, which is much slower. This is a
251252
workaround for videos that don't open at all when you set the `start` value.
253+
prefetch: If > 0, decode frames in a background thread and buffer up to
254+
this many frames in a bounded queue. Useful when the consumer (e.g.
255+
CPU inference) is the bottleneck and can overlap with decode I/O.
256+
Default 0 keeps the original synchronous behaviour unchanged.
252257
253258
Returns:
254259
A generator that yields the
@@ -262,6 +267,17 @@ def get_video_frames_generator(
262267
...
263268
```
264269
"""
270+
if prefetch > 0:
271+
yield from _prefetched_frames_generator(
272+
source_path=source_path,
273+
stride=stride,
274+
start=start,
275+
end=end,
276+
iterative_seek=iterative_seek,
277+
prefetch=prefetch,
278+
)
279+
return
280+
265281
video, start, end = _validate_and_setup_video(
266282
source_path, start, end, iterative_seek
267283
)
@@ -280,6 +296,39 @@ def get_video_frames_generator(
280296
video.release()
281297

282298

299+
def _prefetched_frames_generator(
300+
source_path: str,
301+
stride: int,
302+
start: int,
303+
end: int | None,
304+
iterative_seek: bool,
305+
prefetch: int,
306+
) -> Generator[npt.NDArray[np.uint8], None, None]:
307+
frame_queue: Queue[npt.NDArray[np.uint8] | None] = Queue(maxsize=prefetch)
308+
309+
def reader() -> None:
310+
try:
311+
for frame in get_video_frames_generator(
312+
source_path=source_path,
313+
stride=stride,
314+
start=start,
315+
end=end,
316+
iterative_seek=iterative_seek,
317+
prefetch=0,
318+
):
319+
frame_queue.put(frame)
320+
finally:
321+
frame_queue.put(None)
322+
323+
thread = threading.Thread(target=reader, daemon=True)
324+
thread.start()
325+
while True:
326+
frame = frame_queue.get()
327+
if frame is None:
328+
break
329+
yield frame
330+
331+
283332
def process_video(
284333
source_path: str,
285334
target_path: str,

tests/utils/test_video.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,15 @@ def test_get_video_frames_generator(dummy_video_path):
200200
assert all(frame.shape == (480, 640, 3) for frame in frames)
201201

202202

203+
def test_get_video_frames_generator_prefetch_matches_sync(dummy_video_path):
204+
"""prefetch>0 must yield the same frames in the same order as the sync path."""
205+
sync_frames = list(get_video_frames_generator(dummy_video_path))
206+
prefetched_frames = list(get_video_frames_generator(dummy_video_path, prefetch=4))
207+
assert len(prefetched_frames) == len(sync_frames) == 10
208+
for a, b in zip(prefetched_frames, sync_frames):
209+
assert np.array_equal(a, b)
210+
211+
203212
def test_get_video_frames_generator_with_stride(dummy_video_path):
204213
"""
205214
Verify that get_video_frames_generator correctly handles the stride parameter.

0 commit comments

Comments
 (0)