Commit de42abb

[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
1 parent 60ca798 · commit de42abb

11 files changed (+350, -70 lines)

requirements/rocm-test.txt

Lines changed: 2 additions & 0 deletions
@@ -96,3 +96,5 @@ albumentations==1.4.6
 transformers==4.57.3
 # Pin HF Hub version
 huggingface-hub==0.36.2
+# Pin Mistral Common
+mistral-common[image,audio]==1.9.1
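
Note: the mistral-common[image,audio] pin supplies the audio prompt-building API that the refactored Voxtral tests below depend on. As a quick illustrative sketch (the WAV path and prompt text are placeholders, not part of this commit), building one audio chunk looks like:

from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage

# Placeholder WAV path; the tests below use AudioTestAssets instead.
audio = Audio.from_file("sample.wav", strict=False)
chunk = AudioChunk(input_audio=RawAudio.from_audio(audio))
message = UserMessage(content=[chunk, TextChunk(text="Describe this clip.")]).to_openai()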

tests/conftest.py

Lines changed: 0 additions & 2 deletions
@@ -419,7 +419,6 @@ def _init(
         self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
             AutoTokenizer.from_pretrained(
                 model_name,
-                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         )
@@ -430,7 +429,6 @@ def _init(
 
         self.processor = AutoProcessor.from_pretrained(
             model_name,
-            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
         if skip_tokenizer_init:

tests/models/multimodal/generation/test_voxtral.py

Lines changed: 138 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44
import json
55

66
import pytest
7-
import pytest_asyncio
87
from mistral_common.audio import Audio
98
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
109
from mistral_common.protocol.instruct.messages import UserMessage
10+
from transformers import VoxtralForConditionalGeneration
1111

1212
from vllm.tokenizers.mistral import MistralTokenizer
1313

1414
from ....conftest import AudioTestAssets
1515
from ....utils import RemoteOpenAIServer
16+
from ...utils import check_logprobs_close
1617
from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test
18+
from .vlm_utils import model_utils
1719

1820
MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"
1921
MISTRAL_FORMAT_ARGS = [
@@ -26,40 +28,21 @@
2628
]
2729

2830

29-
@pytest.fixture()
30-
def server(request, audio_assets: AudioTestAssets):
31-
args = [
32-
"--enforce-eager",
33-
"--limit-mm-per-prompt",
34-
json.dumps({"audio": len(audio_assets)}),
35-
] + MISTRAL_FORMAT_ARGS
36-
37-
with RemoteOpenAIServer(
38-
MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"}
39-
) as remote_server:
40-
yield remote_server
41-
42-
43-
@pytest_asyncio.fixture
44-
async def client(server):
45-
async with server.get_async_client() as async_client:
46-
yield async_client
47-
48-
49-
def _get_prompt(audio_assets, question):
31+
def _get_prompt(audio_assets: AudioTestAssets, question: str) -> list[int]:
32+
"""Build a token-ID prompt via mistral_common for vLLM offline inference."""
5033
tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)
5134

5235
audios = [
53-
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
54-
for i in range(len(audio_assets))
36+
Audio.from_file(str(asset.get_local_path()), strict=False)
37+
for asset in audio_assets
5538
]
5639
audio_chunks = [
5740
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
5841
]
5942

60-
text_chunk = TextChunk(text=question)
61-
messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()]
62-
43+
messages = [
44+
UserMessage(content=[*audio_chunks, TextChunk(text=question)]).to_openai()
45+
]
6346
return tokenizer.apply_chat_template(messages=messages)
6447

6548

@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
7760
vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
7861
run_multi_audio_test(
7962
vllm_runner,
80-
[(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])],
63+
[(vllm_prompt, [a.audio_and_sample_rate for a in audio_assets])], # type: ignore[list-item]
8164
MODEL_NAME,
8265
dtype=dtype,
8366
max_tokens=max_tokens,
@@ -86,30 +69,142 @@ def test_models_with_multiple_audios(
8669
)
8770

8871

89-
@pytest.mark.asyncio
90-
async def test_online_serving(client, audio_assets: AudioTestAssets):
91-
"""Exercises online serving with/without chunked prefill enabled."""
72+
def test_online_serving(vllm_runner, audio_assets: AudioTestAssets):
73+
"""Two-layer accuracy and serving validation using Mistral format.
74+
75+
1. Offline vLLM greedy output (runs first to avoid CUDA fork issues
76+
with multiprocessing - see vlm_utils/core.py).
77+
2. Online OpenAI-compatible API output must match offline — validates
78+
that the serving path (chat template, audio encoding, tokenization)
79+
does not corrupt anything.
80+
81+
Steps run sequentially so each releases the GPU before the next starts.
82+
"""
9283

93-
def asset_to_chunk(asset):
84+
question = f"What's happening in these {len(audio_assets)} audio clips?"
85+
max_tokens = 10
86+
audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
87+
88+
vllm_prompt = _get_prompt(audio_assets, question)
89+
with vllm_runner(
90+
MODEL_NAME,
91+
dtype="half",
92+
enforce_eager=True,
93+
tokenizer_mode="mistral",
94+
config_format="mistral",
95+
load_format="mistral",
96+
limit_mm_per_prompt={"audio": len(audio_assets)},
97+
) as vllm_model:
98+
offline_outputs = vllm_model.generate_greedy(
99+
[vllm_prompt],
100+
max_tokens,
101+
audios=[audio_data],
102+
)
103+
104+
offline_text = offline_outputs[0][1]
105+
assert offline_text, "Offline vLLM inference produced empty output"
106+
107+
def _asset_to_openai_chunk(asset):
94108
audio = Audio.from_file(str(asset.get_local_path()), strict=False)
95109
audio.format = "wav"
96-
audio_dict = AudioChunk.from_audio(audio).to_openai()
97-
return audio_dict
110+
return AudioChunk.from_audio(audio).to_openai()
98111

99-
audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
100-
text = f"What's happening in these {len(audio_assets)} audio clips?"
101112
messages = [
102113
{
103114
"role": "user",
104-
"content": [*audio_chunks, {"type": "text", "text": text}],
115+
"content": [
116+
*[_asset_to_openai_chunk(a) for a in audio_assets],
117+
{"type": "text", "text": question},
118+
],
105119
}
106120
]
107121

108-
chat_completion = await client.chat.completions.create(
109-
model=MODEL_NAME, messages=messages, max_tokens=10
110-
)
122+
server_args = [
123+
"--enforce-eager",
124+
"--limit-mm-per-prompt",
125+
json.dumps({"audio": len(audio_assets)}),
126+
*MISTRAL_FORMAT_ARGS,
127+
]
111128

112-
assert len(chat_completion.choices) == 1
113-
choice = chat_completion.choices[0]
114-
assert choice.message.content == "In the first audio clip, you hear a brief"
129+
with RemoteOpenAIServer(
130+
MODEL_NAME,
131+
server_args,
132+
env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"},
133+
) as remote_server:
134+
client = remote_server.get_client()
135+
completion = client.chat.completions.create(
136+
model=MODEL_NAME,
137+
messages=messages,
138+
max_tokens=max_tokens,
139+
temperature=0,
140+
)
141+
142+
assert len(completion.choices) == 1
143+
choice = completion.choices[0]
115144
assert choice.finish_reason == "length"
145+
assert choice.message.content == offline_text, (
146+
f"Online serving output does not match offline inference.\n"
147+
f" Online: {choice.message.content!r}\n"
148+
f" Offline: {offline_text!r}"
149+
)
150+
151+
152+
def test_hf_reference(hf_runner, vllm_runner, audio_assets: AudioTestAssets):
153+
"""Compare vLLM Mistral-format output against HF Transformers reference.
154+
155+
Instead of requiring an exact text match (which is brittle across
156+
attention backends), we compare per-token logprobs using the standard
157+
check_logprobs_close helper: when tokens diverge at a position, each
158+
runner's chosen token must appear in the other's top-k logprobs.
159+
160+
Marked xfail(strict=False) so remaining edge-case mismatches
161+
don't block CI.
162+
"""
163+
question = f"What's happening in these {len(audio_assets)} audio clips?"
164+
max_tokens = 10
165+
num_logprobs = 5
166+
audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
167+
168+
vllm_prompt = _get_prompt(audio_assets, question)
169+
with vllm_runner(
170+
MODEL_NAME,
171+
dtype="half",
172+
enforce_eager=True,
173+
tokenizer_mode="mistral",
174+
config_format="mistral",
175+
load_format="mistral",
176+
limit_mm_per_prompt={"audio": len(audio_assets)},
177+
) as vllm_model:
178+
vllm_outputs = vllm_model.generate_greedy_logprobs(
179+
[vllm_prompt],
180+
max_tokens,
181+
num_logprobs,
182+
audios=[audio_data],
183+
)
184+
assert vllm_outputs[0][1], "vLLM inference produced empty output"
185+
186+
with hf_runner(
187+
MODEL_NAME,
188+
dtype="half",
189+
auto_cls=VoxtralForConditionalGeneration,
190+
) as hf_model:
191+
hf_model = model_utils.voxtral_patch_hf_runner(hf_model)
192+
hf_outputs = hf_model.generate_greedy_logprobs_limit(
193+
[question],
194+
max_tokens,
195+
num_logprobs,
196+
audios=[audio_data],
197+
)
198+
assert hf_outputs[0][1], "HF Transformers produced empty output"
199+
200+
print(
201+
f"HF Reference Comparison\n"
202+
f" vLLM: {vllm_outputs[0][1]!r}\n"
203+
f" HF: {hf_outputs[0][1]!r}"
204+
)
205+
check_logprobs_close(
206+
outputs_0_lst=vllm_outputs,
207+
outputs_1_lst=hf_outputs,
208+
name_0="vllm",
209+
name_1="hf",
210+
)
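
Aside: the top-k containment rule described in the test_hf_reference docstring is provided by the existing check_logprobs_close helper; the sketch below only illustrates that rule, under the assumption that each output is a (token_ids, text, logprobs) tuple whose logprobs are per-position dicts keyed by token id.

# Illustrative only: simplified version of the containment check performed
# by check_logprobs_close (assumed output shape: token_ids, text, logprobs).
def sketch_logprobs_close(outputs_0, outputs_1, name_0="vllm", name_1="hf"):
    for (ids_0, _, logprobs_0), (ids_1, _, logprobs_1) in zip(outputs_0, outputs_1):
        for idx, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue  # same greedy token; nothing to compare
            # At the first divergence, each runner's token must still appear
            # in the other runner's top-k candidates for that position.
            assert tok_0 in logprobs_1[idx], (
                f"{name_0} token {tok_0} not in {name_1} top-k at position {idx}")
            assert tok_1 in logprobs_0[idx], (
                f"{name_1} token {tok_1} not in {name_0} top-k at position {idx}")
            break  # sequences diverge here; later positions are not comparable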

tests/models/multimodal/generation/test_voxtral_realtime.py

Lines changed: 6 additions & 2 deletions
@@ -10,6 +10,7 @@
     TranscriptionRequest,
 )
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy
 
 from vllm import LLM, EngineArgs, SamplingParams
 from vllm.assets.audio import AudioAsset
@@ -26,7 +27,7 @@
     load_format="mistral",
     tokenizer_mode="mistral",
     enforce_eager=True,
-    gpu_memory_utilization=0.4,
+    gpu_memory_utilization=0.9,
 )
 
 
@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
 
         output_tokens_list.append(output_tokens)
 
-    texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
+    texts = [
+        tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
+        for output_tokens in output_tokens_list
+    ]
     texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
     assert texts == EXPECTED_TEXT

tests/models/multimodal/generation/vlm_utils/model_utils.py

Lines changed: 88 additions & 0 deletions
@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     hf_processor.patch_size = vision_encoder_info.get_patch_size()
 
     return hf_model
+
+
+def voxtral_patch_hf_runner(hf_model: "HfRunner") -> "HfRunner":
+    """Patch HfRunner for Voxtral's conversation-based processor.
+
+    Two issues in HfRunner require patching:
+
+    1. VoxtralProcessor requires ``apply_chat_template()`` with conversation
+       dicts (accepting ``url``, ``path``, or ``base64`` audio) rather than
+       the standard ``processor(text=, audio=, sampling_rate=)`` interface.
+    2. HfRunner.get_inputs cannot handle multi-audio per prompt because it
+       mis-unpacks ``[(arr1, sr1), (arr2, sr2)]`` via a ``len == 2`` check.
+
+    We override ``get_inputs`` to build conversation dicts and call
+    ``apply_chat_template`` directly, bypassing both issues. We also wrap
+    ``model.generate`` to strip prompt tokens before decoding, since
+    HfRunner.generate calls batch_decode on the full sequence (prompt +
+    generated).
+    """
+
+    import base64
+    import io
+
+    import soundfile as sf
+
+    processor = hf_model.processor
+
+    def _audio_to_base64(audio_array, sample_rate: int) -> str:
+        """Encode a numpy audio array as a base64 WAV string."""
+        buf = io.BytesIO()
+        sf.write(buf, audio_array, int(sample_rate), format="WAV")
+        return base64.b64encode(buf.getvalue()).decode("ascii")
+
+    def patched_get_inputs(prompts, images=None, videos=None, audios=None, **kwargs):
+        all_inputs = []
+        for i, prompt in enumerate(prompts):
+            content: list[dict] = []
+
+            if audios is not None and audios[i] is not None:
+                items = audios[i]
+                if not isinstance(items, list):
+                    items = [items]
+                for item in items:
+                    if isinstance(item, (list, tuple)) and len(item) == 2:
+                        arr, sr = item
+                    else:
+                        arr, sr = item, 16_000
+                    content.append(
+                        {
+                            "type": "audio",
+                            "base64": _audio_to_base64(arr, sr),
+                        }
+                    )
+
+            content.append({"type": "text", "text": prompt})
+
+            inputs = processor.apply_chat_template(
+                [{"role": "user", "content": content}]
+            )
+            if hasattr(inputs, "to"):
+                inputs = inputs.to(dtype=hf_model.dtype)
+            all_inputs.append(inputs)
+
+        return all_inputs
+
+    _orig_generate = hf_model.model.generate
+
+    def patched_generate(*args, **kwargs):
+        """Strip prompt tokens so only generated tokens are decoded."""
+        input_ids = kwargs.get("input_ids")
+        if input_ids is None and args:
+            input_ids = args[0]
+        prompt_len = input_ids.shape[1] if input_ids is not None else 0
+
+        output = _orig_generate(*args, **kwargs)
+        if prompt_len:
+            if isinstance(output, torch.Tensor):
+                output = output[:, prompt_len:]
+            else:
+                # GenerateDecoderOnlyOutput - trim sequences but preserve
+                # scores/logits so generate_greedy_logprobs_limit can
+                # extract per-token logprobs.
+                output.sequences = output.sequences[:, prompt_len:]
+        return output
+
+    hf_model.get_inputs = patched_get_inputs  # type: ignore[method-assign, assignment]
+    hf_model.model.generate = patched_generate  # type: ignore[method-assign]
+    return hf_model
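
For illustration only (not part of this commit), the conversation payload that patched_get_inputs hands to VoxtralProcessor.apply_chat_template looks roughly like the sketch below; the silent one-second clip and the prompt text are synthetic placeholders.

import base64
import io

import numpy as np
import soundfile as sf


def audio_to_base64_wav(audio_array: np.ndarray, sample_rate: int) -> str:
    """Encode an audio array as a base64 WAV string, as the patch above does."""
    buf = io.BytesIO()
    sf.write(buf, audio_array, int(sample_rate), format="WAV")
    return base64.b64encode(buf.getvalue()).decode("ascii")


# One second of silence at 16 kHz stands in for a real test asset.
dummy_audio = np.zeros(16_000, dtype=np.float32)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "base64": audio_to_base64_wav(dummy_audio, 16_000)},
            {"type": "text", "text": "What's happening in this audio clip?"},
        ],
    }
]
# processor.apply_chat_template(conversation) would then return model inputs,
# mirroring patched_get_inputs above.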
