Skip to content

Commit 91f9240

Browse files
committed
fix: preserve qwen2.5-vl video metadata in processor and add regression tests
1 parent d45286e commit 91f9240

4 files changed

Lines changed: 398 additions & 19 deletions

File tree

lmms_eval/models/chat/qwen2_5_vl.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,29 @@ def _collate(x):
9090
video_kwargs["nframes"] = self.max_num_frames
9191
batched_messages = [chat_message.to_hf_messages(video_kwargs=video_kwargs) for chat_message in chat_messages]
9292
texts = self.processor.apply_chat_template(batched_messages, tokenize=False, add_generation_prompt=True)
93-
image_inputs, video_inputs = process_vision_info(batched_messages)
93+
image_inputs, video_inputs, video_kwargs_qwen = process_vision_info(
94+
batched_messages,
95+
return_video_kwargs=True,
96+
image_patch_size=14,
97+
return_video_metadata=True,
98+
)
99+
video_kwargs = {**video_kwargs_qwen, "do_resize":False}
100+
101+
video_metadatas = None
102+
if video_inputs is not None:
103+
video_inputs, video_metadatas = zip(*video_inputs)
104+
video_inputs, video_metadatas = (
105+
list(video_inputs),
106+
list(video_metadatas),
107+
)
108+
94109
padding_side = "left" if self.batch_size > 1 else "right"
95110
inputs = self.processor(
96111
text=texts,
97112
images=image_inputs,
98113
videos=video_inputs,
114+
video_metadata=video_metadatas,
115+
**video_kwargs,
99116
padding=True,
100117
padding_side=padding_side,
101118
return_tensors="pt",

lmms_eval/models/simple/qwen2_5_vl.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,44 @@ def flatten(self, input):
167167
new_list.append(j)
168168
return new_list
169169

170+
def _subsample_video_inputs(self, video_inputs, video_metadatas=None) -> None:
    """Subsample each video's frames in place and mirror the choice in its metadata.

    For every entry of ``video_inputs`` the frame axis (axis 0) is reduced to
    indices taken from ``np.linspace(0, last_frame, self.max_num_frames)``,
    de-duplicated, with the last frame always included.  When a matching
    metadata entry exposes ``frames_indices`` (as a dict key or an attribute),
    the same positions are selected from it so that it keeps one entry per
    retained frame.

    Args:
        video_inputs: Mutable list of per-video frame arrays/tensors indexed
            along axis 0, or ``None`` (in which case nothing happens).
        video_metadatas: Optional list aligned with ``video_inputs``; entries
            are dicts or objects that may carry ``frames_indices``.

    Returns:
        None. Both lists are updated in place.
    """
    if video_inputs is None:
        return

    for index, video_input in enumerate(video_inputs):
        total_frames = video_input.shape[0]
        if total_frames == 0:
            # An empty clip has nothing to sample.  Without this guard,
            # linspace(0, -1, ...) yields index -1 and indexing the empty
            # clip raises IndexError.
            continue

        last_frame = total_frames - 1
        indices = np.unique(np.linspace(0, last_frame, self.max_num_frames, dtype=int))
        if last_frame not in indices:
            # linspace omits the endpoint when fewer than two samples are requested.
            indices = np.unique(np.append(indices, last_frame))
        video_inputs[index] = video_input[indices]

        if video_metadatas is None or index >= len(video_metadatas):
            continue

        video_metadata = video_metadatas[index]
        is_mapping = isinstance(video_metadata, dict)
        if is_mapping:
            metadata_frames = video_metadata.get("frames_indices")
        else:
            metadata_frames = getattr(video_metadata, "frames_indices", None)

        if metadata_frames is None:
            continue

        frame_indices = np.asarray(metadata_frames)
        # Leave metadata untouched when it cannot be indexed by the chosen positions.
        if frame_indices.ndim != 1 or len(frame_indices) <= indices[-1]:
            continue

        selected_frame_indices = frame_indices[indices]
        if isinstance(metadata_frames, list):
            # Hand back the same container type the metadata originally used.
            selected_frame_indices = selected_frame_indices.tolist()

        if is_mapping:
            video_metadata["frames_indices"] = selected_frame_indices
        else:
            video_metadata.frames_indices = selected_frame_indices
207+
170208
def _encode_image_data_url(self, image: Image.Image) -> str:
171209
return encode_image_to_data_url(
172210
image,
@@ -288,22 +326,27 @@ def _collate(x):
288326
batched_messages.append(message)
289327

290328
texts = self.processor.apply_chat_template(batched_messages, tokenize=False, add_generation_prompt=True)
291-
image_inputs, video_inputs = process_vision_info(batched_messages)
329+
image_inputs, video_inputs, video_kwargs_qwen = process_vision_info(
330+
batched_messages,
331+
return_video_kwargs=True,
332+
image_patch_size=14,
333+
return_video_metadata=True,
334+
)
335+
video_kwargs = {**video_kwargs_qwen, "do_resize":False}
336+
337+
video_metadatas = None
292338
if video_inputs is not None:
293-
total_frames = video_inputs[0].shape[0]
294-
indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int)
295-
# Ensure unique indices if linspace produces duplicates for few frames
296-
indices = np.unique(indices)
297-
# Append the last frame index if not already included
298-
if total_frames - 1 not in indices:
299-
indices = np.append(indices, total_frames - 1)
300-
indices = np.unique(indices) # Ensure uniqueness again
301-
video_inputs[0] = video_inputs[0][indices]
339+
video_inputs, video_metadatas = zip(*video_inputs)
340+
video_inputs, video_metadatas = list(video_inputs), list(video_metadatas)
341+
self._subsample_video_inputs(video_inputs, video_metadatas)
342+
302343
padding_side = "left" if self.batch_size > 1 else "right"
303344
inputs = self.processor(
304345
text=texts,
305346
images=image_inputs,
306347
videos=video_inputs,
348+
video_metadata=video_metadatas,
349+
**video_kwargs,
307350
padding=True,
308351
padding_side=padding_side,
309352
return_tensors="pt",
@@ -504,19 +547,26 @@ def _collate(x):
504547
batched_messages.append(message)
505548

506549
texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batched_messages]
507-
image_inputs, video_inputs = process_vision_info(batched_messages)
550+
image_inputs, video_inputs, video_kwargs_qwen = process_vision_info(
551+
batched_messages,
552+
return_video_kwargs=True,
553+
image_patch_size=14,
554+
return_video_metadata=True,
555+
)
556+
video_kwargs = {**video_kwargs_qwen, "do_resize":False}
557+
558+
video_metadatas = None
508559
if video_inputs is not None:
509-
total_frames = video_inputs[0].shape[0]
510-
indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int)
511-
indices = np.unique(indices)
512-
if total_frames - 1 not in indices:
513-
indices = np.append(indices, total_frames - 1)
514-
indices = np.unique(indices)
515-
video_inputs[0] = video_inputs[0][indices]
560+
video_inputs, video_metadatas = zip(*video_inputs)
561+
video_inputs, video_metadatas = list(video_inputs), list(video_metadatas)
562+
self._subsample_video_inputs(video_inputs, video_metadatas)
563+
516564
inputs = self.processor(
517565
text=texts,
518566
images=image_inputs,
519567
videos=video_inputs,
568+
video_metadata=video_metadatas,
569+
**video_kwargs,
520570
padding=True,
521571
return_tensors="pt",
522572
)
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import types
2+
import unittest
3+
from unittest.mock import patch
4+
5+
import numpy as np
6+
import torch
7+
8+
from lmms_eval.models.chat.qwen2_5_vl import Qwen2_5_VL
9+
10+
11+
class _FakeTokenizer:
    """Tokenizer stand-in exposing only the special-token ids set below."""

    # Both special tokens share id 0 in this stub.
    eos_token_id = pad_token_id = 0
14+
15+
16+
class _FakeInputs(dict):
    """Dict-backed processor output holding a fixed two-token ``input_ids`` tensor."""

    def __init__(self):
        prompt_ids = torch.tensor([[10, 11]])
        super().__init__({"input_ids": prompt_ids})

    @property
    def input_ids(self):
        """Attribute-style access to the ``"input_ids"`` entry."""
        return self["input_ids"]

    def to(self, device):
        """Ignore ``device`` and return this same object."""
        return self
26+
27+
28+
class _FakeProcessor:
    """Processor stub that records the keyword arguments of every call."""

    def __init__(self):
        # One kwargs dict is stored per __call__ invocation, in call order.
        self.calls = []

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
        """Return a single fixed prompt regardless of the arguments."""
        return ["prompt"]

    def __call__(self, **kwargs):
        """Record ``kwargs`` and return a fresh ``_FakeInputs``."""
        self.calls += [kwargs]
        return _FakeInputs()

    def batch_decode(self, generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False):
        """Return a single fixed decoded string regardless of the arguments."""
        return ["final answer"]
41+
42+
43+
class _FakeModel:
    """Model stub whose ``generate`` always yields the same token ids."""

    def generate(self, **kwargs):
        """Ignore ``kwargs`` and return a fixed 1x3 tensor of ids 10, 11, 12."""
        generated_ids = [[10, 11, 12]]
        return torch.tensor(generated_ids)
46+
47+
48+
class _FakeVideoReader:
    """Video-reader stub that stores the path and reports a length of 5."""

    def __init__(self, path):
        self.path = path

    def __len__(self):
        """Return the fixed length 5."""
        frame_count = 5
        return frame_count
54+
55+
56+
class _FakeVideoMetadata:
    """Metadata stub holding ``frames_indices``, ``fps`` and ``total_num_frames``."""

    def __init__(self, frames_indices, fps=10.0, total_num_frames=5):
        self.total_num_frames = total_num_frames
        self.fps = fps
        # Stored as an ndarray regardless of the input container.
        self.frames_indices = np.asarray(frames_indices)

    @property
    def sampled_fps(self):
        """Return ``len(frames_indices) / total_num_frames * fps``."""
        kept_fraction = len(self.frames_indices) / self.total_num_frames
        return kept_fraction * self.fps
65+
66+
67+
class _FakeChatMessages:
    """Chat-message stub that always describes a single ``demo.mp4`` video."""

    # Copy of the video kwargs from the most recent to_hf_messages call (class-wide).
    last_video_kwargs = None

    def __init__(self, messages):
        self.messages = messages

    def extract_media(self):
        """Return three lists; only the second is non-empty and holds ``demo.mp4``."""
        return [], ["demo.mp4"], []

    def to_hf_messages(self, video_kwargs=None):
        """Record ``video_kwargs`` on the class and return one user message."""
        extra = video_kwargs or {}
        type(self).last_video_kwargs = dict(extra)

        # Caller-provided kwargs are merged last, so they override the base keys.
        video_part = {"type": "video", "video": "demo.mp4"}
        video_part.update(extra)
        text_part = {"type": "text", "text": "Describe the video"}

        return [{"role": "user", "content": [video_part, text_part]}]
87+
88+
89+
class TestQwen25VLChat(unittest.TestCase):
    """Checks that ``generate_until`` hands video metadata and kwargs to the processor."""

    @staticmethod
    def _make_request():
        """Build the single request object passed to ``generate_until`` in both tests."""

        def doc_to_messages(doc):
            return [{"role": "user", "content": []}]

        return types.SimpleNamespace(
            args=("Describe the video", doc_to_messages, {}, 0, "demo_task", "test"),
        )

    def _make_model(self, max_num_frames=3, fps=None):
        """Create a ``Qwen2_5_VL`` without running ``__init__`` and attach fake collaborators."""
        stub = Qwen2_5_VL.__new__(Qwen2_5_VL)
        attributes = {
            "_tokenizer": _FakeTokenizer(),
            "processor": _FakeProcessor(),
            "_model": _FakeModel(),
            "max_pixels": 1024,
            "min_pixels": 256,
            "max_num_frames": max_num_frames,
            "fps": fps,
            "batch_size_per_gpu": 1,
            "use_cache": False,
            "device_map": "cpu",
            "_device": torch.device("cpu"),
            "_rank": 0,
            "_world_size": 1,
            "task_dict": {"demo_task": {"test": [{"id": 0}]}},
            "cache_hook": types.SimpleNamespace(add_partial=lambda *args, **kwargs: None),
        }
        for attr_name, attr_value in attributes.items():
            setattr(stub, attr_name, attr_value)
        return stub

    def test_generate_until_passes_video_metadata_and_kwargs_to_processor_with_fps(self):
        model = self._make_model(fps=2.5)
        expected_metadata = _FakeVideoMetadata([0, 2, 4], fps=6.0, total_num_frames=5)
        expected_frames = torch.arange(12, dtype=torch.float32).reshape(3, 4)
        vision_info = (None, [(expected_frames, expected_metadata)], {"do_sample_frames": False})

        with (
            patch("lmms_eval.models.chat.qwen2_5_vl.ChatMessages", _FakeChatMessages),
            patch("lmms_eval.models.chat.qwen2_5_vl.process_vision_info", return_value=vision_info),
            patch("lmms_eval.models.chat.qwen2_5_vl.decord", types.SimpleNamespace(VideoReader=_FakeVideoReader)),
            patch("lmms_eval.models.chat.qwen2_5_vl.log_metrics", lambda **kwargs: None),
        ):
            outputs = model.generate_until([self._make_request()])

        self.assertEqual(len(outputs), 1)
        self.assertEqual(outputs[0].text, "final answer")
        self.assertEqual(_FakeChatMessages.last_video_kwargs["fps"], 2.5)

        # The processor must receive the frames and metadata object unchanged.
        recorded = model.processor.calls[0]
        self.assertTrue(torch.equal(recorded["videos"][0], expected_frames))
        self.assertIs(recorded["video_metadata"][0], expected_metadata)
        self.assertFalse(recorded["do_sample_frames"])
        self.assertFalse(recorded["do_resize"])
        self.assertAlmostEqual(recorded["video_metadata"][0].sampled_fps, 3.6)

    def test_generate_until_keeps_sampled_metadata_in_sync_when_using_nframes(self):
        model = self._make_model(max_num_frames=3, fps=None)

        def fake_process_vision_info(batched_messages, return_video_kwargs=False, image_patch_size=14, return_video_metadata=False):
            # Derive frames and metadata from the nframes value carried in the message.
            video_part = batched_messages[0][0]["content"][0]
            frame_count = video_part["nframes"]
            picked = np.linspace(0, 4, frame_count, dtype=int)
            frames = torch.arange(frame_count * 4, dtype=torch.float32).reshape(frame_count, 4)
            meta = _FakeVideoMetadata(picked.tolist(), fps=10.0, total_num_frames=5)
            return None, [(frames, meta)], {"do_sample_frames": False}

        with (
            patch("lmms_eval.models.chat.qwen2_5_vl.ChatMessages", _FakeChatMessages),
            patch("lmms_eval.models.chat.qwen2_5_vl.process_vision_info", side_effect=fake_process_vision_info),
            patch("lmms_eval.models.chat.qwen2_5_vl.decord", types.SimpleNamespace(VideoReader=_FakeVideoReader)),
            patch("lmms_eval.models.chat.qwen2_5_vl.log_metrics", lambda **kwargs: None),
        ):
            outputs = model.generate_until([self._make_request()])

        self.assertEqual(len(outputs), 1)
        self.assertEqual(outputs[0].text, "final answer")
        self.assertEqual(_FakeChatMessages.last_video_kwargs["nframes"], 2)

        # Frame count and metadata indices must describe the same two frames.
        recorded = model.processor.calls[0]
        recorded_metadata = recorded["video_metadata"][0]
        self.assertEqual(recorded["videos"][0].shape[0], 2)
        self.assertTrue(np.array_equal(recorded_metadata.frames_indices, np.array([0, 4])))
        self.assertEqual(len(recorded_metadata.frames_indices), recorded["videos"][0].shape[0])
        self.assertAlmostEqual(recorded_metadata.sampled_fps, 4.0)
        self.assertFalse(recorded["do_sample_frames"])
        self.assertFalse(recorded["do_resize"])
170+
171+
172+
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)