11# SPDX-License-Identifier: Apache-2.0
2- """Tests for Qwen3-ASR model: config, encoder shapes, weight sanitization ."""
2+ """Tests for Qwen3-ASR model behavior and weight mapping ."""
33
44from __future__ import annotations
55
66import json
77import os
88from pathlib import Path
9- from unittest .mock import MagicMock
9+ from types import SimpleNamespace
10+ from typing import Any , cast
1011
1112import mlx .core as mx
13+ import numpy as np
1214import pytest
15+ from transformers import WhisperFeatureExtractor
1316
17+ from vllm_metal .stt .audio import load_audio
1418from vllm_metal .stt .detection import is_stt_model
1519from vllm_metal .stt .loader import load_model
1620from vllm_metal .stt .qwen3_asr .config import (
2630)
2731from vllm_metal .stt .qwen3_asr .transcriber import Qwen3ASRTranscriber
2832
29- # ===========================================================================
30- # Configuration
31- # ===========================================================================
32-
33-
34- class TestQwen3ASRConfig :
35- """Tests for Qwen3ASRConfig.from_dict with 0.6B config."""
36-
37- def test_from_dict_basic (self ) -> None :
38- """Config should be parsed from nested thinker_config dict."""
39- d = {
40- "model_type" : "qwen3_asr" ,
41- "thinker_config" : {
42- "audio_config" : {
43- "d_model" : 896 ,
44- "num_mel_bins" : 128 ,
45- "encoder_layers" : 18 ,
46- "encoder_attention_heads" : 14 ,
47- "encoder_ffn_dim" : 3584 ,
48- "downsample_hidden_size" : 480 ,
49- "output_dim" : 1024 ,
50- "max_source_positions" : 1500 ,
51- "n_window" : 50 ,
52- "n_window_infer" : 800 ,
53- },
54- "text_config" : {
55- "hidden_size" : 1024 ,
56- "num_hidden_layers" : 28 ,
57- "num_attention_heads" : 16 ,
58- "num_key_value_heads" : 8 ,
59- "head_dim" : 128 ,
60- "intermediate_size" : 3072 ,
61- "vocab_size" : 151936 ,
62- "rms_norm_eps" : 1e-6 ,
63- "rope_theta" : 1000000.0 ,
64- "tie_word_embeddings" : True ,
65- },
66- "audio_token_id" : 151676 ,
67- "audio_start_token_id" : 151669 ,
68- "audio_end_token_id" : 151670 ,
69- },
70- }
71- config = Qwen3ASRConfig .from_dict (d )
72- assert config .audio_config .d_model == 896
73- assert config .audio_config .encoder_layers == 18
74- assert config .audio_config .num_mel_bins == 128
75- assert config .audio_config .n_window == 50
76- assert config .text_config .hidden_size == 1024
77- assert config .text_config .num_hidden_layers == 28
78- assert config .text_config .num_attention_heads == 16
79- assert config .text_config .num_key_value_heads == 8
80- assert config .audio_token_id == 151676
81- assert config .n_mels == 128
82- assert config .n_audio_ctx == 1500
83-
84- def test_defaults (self ) -> None :
85- """Default config should have 0.6B model values."""
86- config = Qwen3ASRConfig ()
87- assert config .audio_config .d_model == 896
88- assert config .text_config .vocab_size == 151936
89- assert config .eos_token_id == 151643
9033
34+ class TestQwen3ASRConfigAdaptation :
35+ """Tests for adapting the upstream config into the local MLX config."""
36+
37+ def test_from_vllm_config_keeps_local_eos_default_when_upstream_omits_it (
38+ self ,
39+ ) -> None :
40+ upstream_config = SimpleNamespace (
41+ thinker_config = SimpleNamespace (
42+ audio_config = SimpleNamespace (
43+ d_model = 896 ,
44+ num_mel_bins = 128 ,
45+ encoder_layers = 18 ,
46+ encoder_attention_heads = 14 ,
47+ encoder_ffn_dim = 3584 ,
48+ downsample_hidden_size = 480 ,
49+ output_dim = 1024 ,
50+ max_source_positions = 1500 ,
51+ n_window = 50 ,
52+ n_window_infer = 800 ,
53+ activation_function = "gelu" ,
54+ ),
55+ text_config = SimpleNamespace (
56+ hidden_size = 1024 ,
57+ num_hidden_layers = 28 ,
58+ num_attention_heads = 16 ,
59+ num_key_value_heads = 8 ,
60+ head_dim = 128 ,
61+ intermediate_size = 3072 ,
62+ vocab_size = 151936 ,
63+ rms_norm_eps = 1e-6 ,
64+ rope_theta = 1000000.0 ,
65+ tie_word_embeddings = True ,
66+ ),
67+ audio_token_id = 151676 ,
68+ )
69+ )
70+ config = Qwen3ASRConfig ._from_vllm_config (cast (Any , upstream_config ))
9171
92- # ===========================================================================
93- # CNN output lengths
94- # ===========================================================================
72+ assert config .audio_token_id == 151676
73+ assert config .eos_token_id == 151643
9574
9675
9776class TestCNNOutputLengths :
@@ -127,11 +106,6 @@ def test_feat_extract_3000_frames(self) -> None:
127106 assert Qwen3ASRAudioConfig ().feat_extract_output_length (3000 ) == 390
128107
129108
130- # ===========================================================================
131- # Audio Encoder shapes
132- # ===========================================================================
133-
134-
135109class TestAudioEncoderShapes :
136110 """Tests for AudioEncoder output dimensions."""
137111
@@ -182,11 +156,6 @@ def test_with_batch_dim(self, tiny_encoder) -> None:
182156 assert out .shape == (13 , 48 )
183157
184158
185- # ===========================================================================
186- # Qwen3 Attention
187- # ===========================================================================
188-
189-
190159class TestQwen3Attention :
191160 """Tests for GQA with QK normalization."""
192161
@@ -243,11 +212,6 @@ def test_cached_decode(self) -> None:
243212 assert cache2 [0 ].shape == (1 , 2 , 6 , 16 ) # 5 + 1 = 6
244213
245214
246- # ===========================================================================
247- # Weight sanitization
248- # ===========================================================================
249-
250-
251215class TestWeightSanitize :
252216 """Tests for Qwen3ASRModel.sanitize() weight mapping."""
253217
@@ -353,11 +317,6 @@ def test_casts_dtype(self, model) -> None:
353317 assert sanitized ["audio_tower.ln_post.weight" ].dtype == mx .float32
354318
355319
356- # ===========================================================================
357- # Qwen3 LM forward
358- # ===========================================================================
359-
360-
361320class TestQwen3LM :
362321 """Tests for Qwen3LM forward pass."""
363322
@@ -393,11 +352,6 @@ def test_decode_step(self, tiny_lm) -> None:
393352 assert logits .shape == (1 , 1 , 100 )
394353
395354
396- # ===========================================================================
397- # Full model
398- # ===========================================================================
399-
400-
401355class TestQwen3ASRModel :
402356 """Tests for the full Qwen3ASRModel."""
403357
@@ -430,9 +384,6 @@ def tiny_model(self):
430384 )
431385 return Qwen3ASRModel (config , dtype = mx .float32 )
432386
433- def test_model_type (self , tiny_model ) -> None :
434- assert tiny_model .model_type == "qwen3_asr"
435-
436387 def test_encode (self , tiny_model ) -> None :
437388 """Encode should produce audio embeddings."""
438389 mel = mx .random .normal ((16 , 100 ))
@@ -464,11 +415,6 @@ def test_prefill_and_decode(self, tiny_model) -> None:
464415 assert logits2 .shape == (1 , 1 , 100 )
465416
466417
467- # ===========================================================================
468- # Post-process output
469- # ===========================================================================
470-
471-
472418class TestPostProcessOutput :
473419 """Tests for Qwen3ASRTranscriber.post_process_output."""
474420
@@ -484,11 +430,6 @@ def test_empty_string(self) -> None:
484430 assert Qwen3ASRTranscriber .post_process_output ("" ) == ""
485431
486432
487- # ===========================================================================
488- # Config detection
489- # ===========================================================================
490-
491-
492433class TestPostProcessOutputTruncation :
493434 """Tests for special token truncation in post_process_output."""
494435
@@ -513,88 +454,6 @@ def test_strips_whitespace(self) -> None:
513454 assert Qwen3ASRTranscriber .post_process_output (text ) == "Hello world"
514455
515456
516- # ===========================================================================
517- # Build prompt tokens
518- # ===========================================================================
519-
520-
521- class TestBuildPromptTokens :
522- """Tests for Qwen3ASRTranscriber.build_prompt_tokens structure."""
523-
524- @pytest .fixture ()
525- def transcriber (self , tmp_path ):
526- """Create a transcriber with a mock tokenizer for prompt tests."""
527- config = Qwen3ASRConfig (
528- audio_token_id = 99 ,
529- audio_start_token_id = 97 ,
530- audio_end_token_id = 98 ,
531- eos_token_id = 0 ,
532- )
533- model = MagicMock ()
534- model .config = config
535-
536- # Inject mock tokenizer with deterministic encode
537- mock_tok = MagicMock ()
538- _encode_map = {
539- "<|im_start|>" : [10 ],
540- "<|im_end|>" : [11 ],
541- "user\n " : [20 ],
542- "assistant\n " : [30 ],
543- "\n " : [40 ],
544- }
545- mock_tok .encode = MagicMock (
546- side_effect = lambda s , add_special_tokens = False : _encode_map .get (s , [0 ])
547- )
548- t = Qwen3ASRTranscriber (model , tokenizer = mock_tok )
549- return t
550-
551- def test_audio_pad_count_matches_frames (self , transcriber ) -> None :
552- """Number of audio_pad tokens should equal n_audio_frames."""
553- # Act
554- prompt = transcriber .build_prompt_tokens (50 )
555-
556- # Assert
557- audio_pad_count = prompt .count (99 ) # audio_token_id
558- assert audio_pad_count == 50
559-
560- def test_audio_pad_count_zero (self , transcriber ) -> None :
561- """Zero audio frames should produce no audio_pad tokens."""
562- # Act
563- prompt = transcriber .build_prompt_tokens (0 )
564-
565- # Assert
566- assert prompt .count (99 ) == 0
567-
568- def test_prompt_contains_structural_tokens (self , transcriber ) -> None :
569- """Prompt should contain audio_start, audio_end, im_start, user, assistant."""
570- # Act
571- prompt = transcriber .build_prompt_tokens (5 )
572-
573- # Assert
574- assert 97 in prompt # audio_start
575- assert 98 in prompt # audio_end
576- assert 10 in prompt # im_start
577- assert 20 in prompt # user
578- assert 30 in prompt # assistant
579-
580- def test_prompt_structure_order (self , transcriber ) -> None :
581- """Audio tokens should be between audio_start and audio_end."""
582- # Act
583- prompt = transcriber .build_prompt_tokens (3 )
584-
585- # Assert
586- start_idx = prompt .index (97 ) # audio_start
587- end_idx = prompt .index (98 ) # audio_end
588- for i , tok in enumerate (prompt ):
589- if tok == 99 :
590- assert start_idx < i < end_idx
591-
592-
593- # ===========================================================================
594- # Config detection
595- # ===========================================================================
596-
597-
598457class TestConfigDetection :
599458 """Tests for is_stt_model with Qwen3-ASR config."""
600459
@@ -605,11 +464,6 @@ def test_qwen3_asr_detected(self, tmp_path) -> None:
605464 assert is_stt_model (str (tmp_path )) is True
606465
607466
608- # ===========================================================================
609- # Slow tests (require real model)
610- # ===========================================================================
611-
612-
613467@pytest .mark .slow
614468class TestModelLoad :
615469 """Tests that load the real Qwen3-ASR-0.6B model.
@@ -629,7 +483,7 @@ def _model_path(self):
629483 def test_load_model (self ) -> None :
630484 """Should load model without errors."""
631485 model = load_model (self ._MODEL_PATH )
632- assert model . model_type == "qwen3_asr"
486+ assert isinstance ( model , Qwen3ASRModel )
633487
634488 def test_encode_dummy_mel (self ) -> None :
635489 """Should encode a dummy mel spectrogram."""
@@ -642,11 +496,6 @@ def test_encode_dummy_mel(self) -> None:
642496
643497 def test_greedy_decode (self ) -> None :
644498 """Should encode + decode a real audio file using WhisperFeatureExtractor."""
645- import numpy as np
646- from transformers import WhisperFeatureExtractor
647-
648- from vllm_metal .stt .audio import load_audio
649-
650499 audio_path = os .environ .get ("QWEN3_ASR_AUDIO_PATH" )
651500 if not audio_path or not Path (audio_path ).exists ():
652501 pytest .skip ("QWEN3_ASR_AUDIO_PATH not set or file not found" )
@@ -669,7 +518,24 @@ def test_greedy_decode(self) -> None:
669518
670519 # Build prompt
671520 n_audio = audio_emb .shape [0 ]
672- prompt = transcriber .build_prompt_tokens (n_audio )
521+ tokenizer = transcriber .tokenizer
522+ audio_start_token_id = tokenizer .convert_tokens_to_ids (
523+ tokenizer .audio_bos_token
524+ )
525+ audio_token_id = tokenizer .convert_tokens_to_ids (tokenizer .audio_token )
526+ audio_end_token_id = tokenizer .convert_tokens_to_ids (tokenizer .audio_eos_token )
527+ prompt = (
528+ tokenizer .encode ("<|im_start|>" , add_special_tokens = False )
529+ + tokenizer .encode ("user\n " , add_special_tokens = False )
530+ + [audio_start_token_id ]
531+ + [audio_token_id ] * n_audio
532+ + [audio_end_token_id ]
533+ + tokenizer .encode ("\n " , add_special_tokens = False )
534+ + tokenizer .encode ("<|im_end|>" , add_special_tokens = False )
535+ + tokenizer .encode ("\n " , add_special_tokens = False )
536+ + tokenizer .encode ("<|im_start|>" , add_special_tokens = False )
537+ + tokenizer .encode ("assistant\n " , add_special_tokens = False )
538+ )
673539
674540 # Decode
675541 tokens = transcriber .greedy_decode_tokens (audio_emb , prompt , max_tokens = 100 )
0 commit comments