|
20 | 20 | from vllm.config import ModelConfig, VllmConfig |
21 | 21 | from vllm.inputs import PromptType, TokensPrompt |
22 | 22 | from vllm.logger import init_logger |
23 | | -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding |
24 | 23 | from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP, SupportsRealtime |
25 | 24 | from vllm.model_executor.models.qwen3_asr_realtime import Qwen3ASRRealtimeBuffer |
26 | 25 | from vllm.model_executor.models.qwen3_omni_moe_thinker import ( |
@@ -180,6 +179,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
180 | 179 | ("hidden_states", "last"), |
181 | 180 | ("hidden_states", "trailing_text"), |
182 | 181 | ("embed", "tts_pad_projected"), |
| 182 | + # talker MTP codec codes must stay on GPU to avoid a per-step D2H |
| 183 | + # sync stall; build_mm_cpu handles the eventual D2H at payload time. |
| 184 | + ("codes", "audio"), |
183 | 185 | } |
184 | 186 | # Keys that need to be accumulated across streaming inputs |
185 | 187 | self.streaming_accumulated_keys: set[tuple[str, str]] = { |
@@ -323,7 +325,12 @@ def get_mrope_input_positions( |
323 | 325 | msg = "Qwen3 Omni thinker get_mrope_input_positions requires mm_features" |
324 | 326 | raise ValueError(msg) |
325 | 327 | return self.thinker.get_mrope_input_positions(input_tokens, mm_features) |
326 | | - return MRotaryEmbedding.get_input_positions_tensor(input_tokens, **kwargs) |
| 328 | + # Talker/code2wav stages are text/codec-only and do not need |
| 329 | + # multimodal M-RoPE position computation. Return a cheap linear |
| 330 | + # position tensor to avoid unnecessary per-request M-RoPE work. |
| 331 | + seq_len = len(input_tokens) |
| 332 | + linear = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(3, seq_len) |
| 333 | + return linear, 0 |
327 | 334 |
|
328 | 335 | def forward( |
329 | 336 | self, |
@@ -1253,7 +1260,7 @@ def compute_logits( |
1253 | 1260 | if ( |
1254 | 1261 | getattr(self, "model_stage", None) == "talker" |
1255 | 1262 | and sampling_metadata is not None |
1256 | | - and (sampling_metadata.temperature is None or (sampling_metadata.temperature <= 0).any()) |
| 1263 | + and (sampling_metadata.temperature is None) |
1257 | 1264 | ): |
1258 | 1265 | self._warn_talker_sampling_temperature(sampling_metadata) |
1259 | 1266 |
|
|
0 commit comments