|
| 1 | +import types |
| 2 | +import unittest |
| 3 | +from unittest.mock import patch |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | + |
| 8 | +from lmms_eval.models.chat.qwen2_5_vl import Qwen2_5_VL |
| 9 | + |
| 10 | + |
| 11 | +class _FakeTokenizer: |
| 12 | + eos_token_id = 0 |
| 13 | + pad_token_id = 0 |
| 14 | + |
| 15 | + |
| 16 | +class _FakeInputs(dict): |
| 17 | + def __init__(self): |
| 18 | + super().__init__(input_ids=torch.tensor([[10, 11]])) |
| 19 | + |
| 20 | + @property |
| 21 | + def input_ids(self): |
| 22 | + return self["input_ids"] |
| 23 | + |
| 24 | + def to(self, device): |
| 25 | + return self |
| 26 | + |
| 27 | + |
| 28 | +class _FakeProcessor: |
| 29 | + def __init__(self): |
| 30 | + self.calls = [] |
| 31 | + |
| 32 | + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): |
| 33 | + return ["prompt"] |
| 34 | + |
| 35 | + def __call__(self, **kwargs): |
| 36 | + self.calls.append(kwargs) |
| 37 | + return _FakeInputs() |
| 38 | + |
| 39 | + def batch_decode(self, generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False): |
| 40 | + return ["final answer"] |
| 41 | + |
| 42 | + |
| 43 | +class _FakeModel: |
| 44 | + def generate(self, **kwargs): |
| 45 | + return torch.tensor([[10, 11, 12]]) |
| 46 | + |
| 47 | + |
| 48 | +class _FakeVideoReader: |
| 49 | + def __init__(self, path): |
| 50 | + self.path = path |
| 51 | + |
| 52 | + def __len__(self): |
| 53 | + return 5 |
| 54 | + |
| 55 | + |
| 56 | +class _FakeVideoMetadata: |
| 57 | + def __init__(self, frames_indices, fps=10.0, total_num_frames=5): |
| 58 | + self.frames_indices = np.asarray(frames_indices) |
| 59 | + self.fps = fps |
| 60 | + self.total_num_frames = total_num_frames |
| 61 | + |
| 62 | + @property |
| 63 | + def sampled_fps(self): |
| 64 | + return len(self.frames_indices) / self.total_num_frames * self.fps |
| 65 | + |
| 66 | + |
| 67 | +class _FakeChatMessages: |
| 68 | + last_video_kwargs = None |
| 69 | + |
| 70 | + def __init__(self, messages): |
| 71 | + self.messages = messages |
| 72 | + |
| 73 | + def extract_media(self): |
| 74 | + return [], ["demo.mp4"], [] |
| 75 | + |
| 76 | + def to_hf_messages(self, video_kwargs=None): |
| 77 | + type(self).last_video_kwargs = dict(video_kwargs or {}) |
| 78 | + return [ |
| 79 | + { |
| 80 | + "role": "user", |
| 81 | + "content": [ |
| 82 | + {"type": "video", "video": "demo.mp4", **(video_kwargs or {})}, |
| 83 | + {"type": "text", "text": "Describe the video"}, |
| 84 | + ], |
| 85 | + } |
| 86 | + ] |
| 87 | + |
| 88 | + |
| 89 | +class TestQwen25VLChat(unittest.TestCase): |
| 90 | + def _make_model(self, max_num_frames=3, fps=None): |
| 91 | + model = Qwen2_5_VL.__new__(Qwen2_5_VL) |
| 92 | + model._tokenizer = _FakeTokenizer() |
| 93 | + model.processor = _FakeProcessor() |
| 94 | + model._model = _FakeModel() |
| 95 | + model.max_pixels = 1024 |
| 96 | + model.min_pixels = 256 |
| 97 | + model.max_num_frames = max_num_frames |
| 98 | + model.fps = fps |
| 99 | + model.batch_size_per_gpu = 1 |
| 100 | + model.use_cache = False |
| 101 | + model.device_map = "cpu" |
| 102 | + model._device = torch.device("cpu") |
| 103 | + model._rank = 0 |
| 104 | + model._world_size = 1 |
| 105 | + model.task_dict = {"demo_task": {"test": [{"id": 0}]}} |
| 106 | + model.cache_hook = types.SimpleNamespace(add_partial=lambda *args, **kwargs: None) |
| 107 | + return model |
| 108 | + |
| 109 | + def test_generate_until_passes_video_metadata_and_kwargs_to_processor_with_fps(self): |
| 110 | + model = self._make_model(fps=2.5) |
| 111 | + metadata = _FakeVideoMetadata([0, 2, 4], fps=6.0, total_num_frames=5) |
| 112 | + video_tensor = torch.arange(12, dtype=torch.float32).reshape(3, 4) |
| 113 | + request = types.SimpleNamespace( |
| 114 | + args=("Describe the video", lambda doc: [{"role": "user", "content": []}], {}, 0, "demo_task", "test"), |
| 115 | + ) |
| 116 | + |
| 117 | + with ( |
| 118 | + patch("lmms_eval.models.chat.qwen2_5_vl.ChatMessages", _FakeChatMessages), |
| 119 | + patch("lmms_eval.models.chat.qwen2_5_vl.process_vision_info", return_value=(None, [(video_tensor, metadata)], {"do_sample_frames": False})), |
| 120 | + patch("lmms_eval.models.chat.qwen2_5_vl.decord", types.SimpleNamespace(VideoReader=_FakeVideoReader)), |
| 121 | + patch("lmms_eval.models.chat.qwen2_5_vl.log_metrics", lambda **kwargs: None), |
| 122 | + ): |
| 123 | + result = model.generate_until([request]) |
| 124 | + |
| 125 | + self.assertEqual(len(result), 1) |
| 126 | + self.assertEqual(result[0].text, "final answer") |
| 127 | + self.assertEqual(_FakeChatMessages.last_video_kwargs["fps"], 2.5) |
| 128 | + |
| 129 | + processor_call = model.processor.calls[0] |
| 130 | + self.assertTrue(torch.equal(processor_call["videos"][0], video_tensor)) |
| 131 | + self.assertIs(processor_call["video_metadata"][0], metadata) |
| 132 | + self.assertFalse(processor_call["do_sample_frames"]) |
| 133 | + self.assertFalse(processor_call["do_resize"]) |
| 134 | + self.assertAlmostEqual(processor_call["video_metadata"][0].sampled_fps, 3.6) |
| 135 | + |
| 136 | + def test_generate_until_keeps_sampled_metadata_in_sync_when_using_nframes(self): |
| 137 | + model = self._make_model(max_num_frames=3, fps=None) |
| 138 | + request = types.SimpleNamespace( |
| 139 | + args=("Describe the video", lambda doc: [{"role": "user", "content": []}], {}, 0, "demo_task", "test"), |
| 140 | + ) |
| 141 | + |
| 142 | + def fake_process_vision_info(batched_messages, return_video_kwargs=False, image_patch_size=14, return_video_metadata=False): |
| 143 | + video_content = batched_messages[0][0]["content"][0] |
| 144 | + nframes = video_content["nframes"] |
| 145 | + sampled_indices = np.linspace(0, 4, nframes, dtype=int) |
| 146 | + video_tensor = torch.arange(nframes * 4, dtype=torch.float32).reshape(nframes, 4) |
| 147 | + metadata = _FakeVideoMetadata(sampled_indices.tolist(), fps=10.0, total_num_frames=5) |
| 148 | + return None, [(video_tensor, metadata)], {"do_sample_frames": False} |
| 149 | + |
| 150 | + with ( |
| 151 | + patch("lmms_eval.models.chat.qwen2_5_vl.ChatMessages", _FakeChatMessages), |
| 152 | + patch("lmms_eval.models.chat.qwen2_5_vl.process_vision_info", side_effect=fake_process_vision_info), |
| 153 | + patch("lmms_eval.models.chat.qwen2_5_vl.decord", types.SimpleNamespace(VideoReader=_FakeVideoReader)), |
| 154 | + patch("lmms_eval.models.chat.qwen2_5_vl.log_metrics", lambda **kwargs: None), |
| 155 | + ): |
| 156 | + result = model.generate_until([request]) |
| 157 | + |
| 158 | + self.assertEqual(len(result), 1) |
| 159 | + self.assertEqual(result[0].text, "final answer") |
| 160 | + self.assertEqual(_FakeChatMessages.last_video_kwargs["nframes"], 2) |
| 161 | + |
| 162 | + processor_call = model.processor.calls[0] |
| 163 | + metadata = processor_call["video_metadata"][0] |
| 164 | + self.assertEqual(processor_call["videos"][0].shape[0], 2) |
| 165 | + self.assertTrue(np.array_equal(metadata.frames_indices, np.array([0, 4]))) |
| 166 | + self.assertEqual(len(metadata.frames_indices), processor_call["videos"][0].shape[0]) |
| 167 | + self.assertAlmostEqual(metadata.sampled_fps, 4.0) |
| 168 | + self.assertFalse(processor_call["do_sample_frames"]) |
| 169 | + self.assertFalse(processor_call["do_resize"]) |
| 170 | + |
| 171 | + |
| 172 | +if __name__ == "__main__": |
| 173 | + unittest.main() |
0 commit comments