@@ -4,16 +4,18 @@
 import json
 
 import pytest
-import pytest_asyncio
 from mistral_common.audio import Audio
 from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
 from mistral_common.protocol.instruct.messages import UserMessage
+from transformers import VoxtralForConditionalGeneration
 
 from vllm.tokenizers.mistral import MistralTokenizer
 
 from ....conftest import AudioTestAssets
 from ....utils import RemoteOpenAIServer
+from ...utils import check_logprobs_close
 from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test
+from .vlm_utils import model_utils
 
 MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"
 MISTRAL_FORMAT_ARGS = [
@@ -26,40 +28,21 @@
 ]
 
 
-@pytest.fixture()
-def server(request, audio_assets: AudioTestAssets):
-    args = [
-        "--enforce-eager",
-        "--limit-mm-per-prompt",
-        json.dumps({"audio": len(audio_assets)}),
-    ] + MISTRAL_FORMAT_ARGS
-
-    with RemoteOpenAIServer(
-        MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"}
-    ) as remote_server:
-        yield remote_server
-
-
-@pytest_asyncio.fixture
-async def client(server):
-    async with server.get_async_client() as async_client:
-        yield async_client
-
-
-def _get_prompt(audio_assets, question):
+def _get_prompt(audio_assets: AudioTestAssets, question: str) -> list[int]:
+    """Build a token-ID prompt via mistral_common for vLLM offline inference."""
     tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)
 
     audios = [
-        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
-        for i in range(len(audio_assets))
+        Audio.from_file(str(asset.get_local_path()), strict=False)
+        for asset in audio_assets
     ]
     audio_chunks = [
         AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
     ]
 
-    text_chunk = TextChunk(text=question)
-    messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()]
-
+    messages = [
+        UserMessage(content=[*audio_chunks, TextChunk(text=question)]).to_openai()
+    ]
     return tokenizer.apply_chat_template(messages=messages)
 
 
@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
     vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
     run_multi_audio_test(
         vllm_runner,
-        [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])],
+        [(vllm_prompt, [a.audio_and_sample_rate for a in audio_assets])],  # type: ignore[list-item]
         MODEL_NAME,
         dtype=dtype,
         max_tokens=max_tokens,
@@ -86,30 +69,142 @@ def test_models_with_multiple_audios(
     )
 
 
-@pytest.mark.asyncio
-async def test_online_serving(client, audio_assets: AudioTestAssets):
-    """Exercises online serving with/without chunked prefill enabled."""
+def test_online_serving(vllm_runner, audio_assets: AudioTestAssets):
+    """Two-layer accuracy and serving validation using the Mistral format.
+
+    1. Offline vLLM greedy output (runs first to avoid CUDA fork issues
+       with multiprocessing - see vlm_utils/core.py).
+    2. Online OpenAI-compatible API output must match the offline output,
+       which validates that the serving path (chat template, audio encoding,
+       tokenization) does not corrupt anything.
+
+    The steps run sequentially so each releases the GPU before the next starts.
+    """
 
-    def asset_to_chunk(asset):
+    question = f"What's happening in these {len(audio_assets)} audio clips?"
+    max_tokens = 10
+    audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
+
+    vllm_prompt = _get_prompt(audio_assets, question)
+    with vllm_runner(
+        MODEL_NAME,
+        dtype="half",
+        enforce_eager=True,
+        tokenizer_mode="mistral",
+        config_format="mistral",
+        load_format="mistral",
+        limit_mm_per_prompt={"audio": len(audio_assets)},
+    ) as vllm_model:
+        offline_outputs = vllm_model.generate_greedy(
+            [vllm_prompt],
+            max_tokens,
+            audios=[audio_data],
+        )
+
+    offline_text = offline_outputs[0][1]
+    assert offline_text, "Offline vLLM inference produced empty output"
+
+    def _asset_to_openai_chunk(asset):
         audio = Audio.from_file(str(asset.get_local_path()), strict=False)
         audio.format = "wav"
-        audio_dict = AudioChunk.from_audio(audio).to_openai()
-        return audio_dict
+        return AudioChunk.from_audio(audio).to_openai()
 
-    audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
-    text = f"What's happening in these {len(audio_assets)} audio clips?"
     messages = [
         {
             "role": "user",
-            "content": [*audio_chunks, {"type": "text", "text": text}],
+            "content": [
+                *[_asset_to_openai_chunk(a) for a in audio_assets],
+                {"type": "text", "text": question},
+            ],
         }
     ]
 
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME, messages=messages, max_tokens=10
-    )
+    server_args = [
+        "--enforce-eager",
+        "--limit-mm-per-prompt",
+        json.dumps({"audio": len(audio_assets)}),
+        *MISTRAL_FORMAT_ARGS,
+    ]
 
-    assert len(chat_completion.choices) == 1
-    choice = chat_completion.choices[0]
-    assert choice.message.content == "In the first audio clip, you hear a brief"
+    with RemoteOpenAIServer(
+        MODEL_NAME,
+        server_args,
+        env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"},
+    ) as remote_server:
+        client = remote_server.get_client()
+        completion = client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=0,
+        )
+
+    assert len(completion.choices) == 1
+    choice = completion.choices[0]
     assert choice.finish_reason == "length"
+    assert choice.message.content == offline_text, (
+        f"Online serving output does not match offline inference.\n"
+        f"  Online: {choice.message.content!r}\n"
+        f"  Offline: {offline_text!r}"
+    )
+
+
+def test_hf_reference(hf_runner, vllm_runner, audio_assets: AudioTestAssets):
+    """Compare vLLM Mistral-format output against an HF Transformers reference.
+
+    Instead of requiring an exact text match (which is brittle across
+    attention backends), we compare per-token logprobs using the standard
+    check_logprobs_close helper: when tokens diverge at a position, each
+    runner's chosen token must appear in the other's top-k logprobs.
+
+    If edge-case mismatches remain, the test can be marked
+    xfail(strict=False) so they don't block CI.
+    """
+    question = f"What's happening in these {len(audio_assets)} audio clips?"
+    max_tokens = 10
+    num_logprobs = 5
+    audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
+
+    vllm_prompt = _get_prompt(audio_assets, question)
+    with vllm_runner(
+        MODEL_NAME,
+        dtype="half",
+        enforce_eager=True,
+        tokenizer_mode="mistral",
+        config_format="mistral",
+        load_format="mistral",
+        limit_mm_per_prompt={"audio": len(audio_assets)},
+    ) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            [vllm_prompt],
+            max_tokens,
+            num_logprobs,
+            audios=[audio_data],
+        )
+        assert vllm_outputs[0][1], "vLLM inference produced empty output"
+
+    with hf_runner(
+        MODEL_NAME,
+        dtype="half",
+        auto_cls=VoxtralForConditionalGeneration,
+    ) as hf_model:
+        hf_model = model_utils.voxtral_patch_hf_runner(hf_model)
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            [question],
+            max_tokens,
+            num_logprobs,
+            audios=[audio_data],
+        )
+        assert hf_outputs[0][1], "HF Transformers produced empty output"
+
+    print(
+        f"HF Reference Comparison\n"
+        f"  vLLM: {vllm_outputs[0][1]!r}\n"
+        f"  HF:   {hf_outputs[0][1]!r}"
+    )
+    check_logprobs_close(
+        outputs_0_lst=vllm_outputs,
+        outputs_1_lst=hf_outputs,
+        name_0="vllm",
+        name_1="hf",
+    )