Skip to content

Commit 1825ab4

Browse files
authored
Add parity test for streaming decoder (#469)
1 parent 4b05551 commit 1825ab4

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

tests/spdl_unittest/io/demuxer_test.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_demuxer_accept_torch_tensor(get_sample):
6767
assert not torch.any(src)
6868

6969

70-
def test_streaming_video_demuxing(get_sample):
70+
def test_streaming_video_demuxing_smoke_test(get_sample):
7171
"""`streaming_demux_video` can decode packets in streaming fashion."""
7272
cmd = "ffmpeg -hide_banner -y -f lavfi -i testsrc -f lavfi -i sine -frames:v 10 sample.mp4"
7373
sample = get_sample(cmd)
@@ -81,3 +81,24 @@ def test_streaming_video_demuxing(get_sample):
8181
packets = demuxer.demux_video()
8282

8383
assert num_packets == len(packets)
84+
85+
86+
def test_streaming_video_demuxing_parity(get_sample):
87+
"""`streaming_demux_video` can decode packets in streaming fashion."""
88+
cmd = "ffmpeg -hide_banner -y -f lavfi -i testsrc -f lavfi -i sine -frames:v 30 sample.mp4"
89+
sample = get_sample(cmd)
90+
91+
def _decode_packets(packets):
92+
frames = spdl.io.decode_packets(packets)
93+
buffer = spdl.io.convert_frames(frames)
94+
return spdl.io.to_numpy(buffer)
95+
96+
demuxer = spdl.io.Demuxer(sample.path)
97+
ite = iter(demuxer.streaming_demux_video(30))
98+
pkts = next(ite)
99+
hyp = _decode_packets(pkts)
100+
101+
pkts = spdl.io.demux_video(sample.path)
102+
ref = _decode_packets(pkts)
103+
104+
assert np.array_equal(hyp, ref)

0 commit comments

Comments
 (0)