Skip to content

Commit 2696751

Browse files
[Perf] Qwen3-Omni performance optimization (vllm-project#3878)
Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com> Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent 8c4a42b commit 2696751

2 files changed

Lines changed: 16 additions & 4 deletions

File tree

vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from vllm.config import ModelConfig, VllmConfig
2121
from vllm.inputs import PromptType, TokensPrompt
2222
from vllm.logger import init_logger
23-
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
2423
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP, SupportsRealtime
2524
from vllm.model_executor.models.qwen3_asr_realtime import Qwen3ASRRealtimeBuffer
2625
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
@@ -180,6 +179,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
180179
("hidden_states", "last"),
181180
("hidden_states", "trailing_text"),
182181
("embed", "tts_pad_projected"),
182+
# talker MTP codec codes must stay on GPU to avoid a per-step D2H
183+
# sync stall; build_mm_cpu handles the eventual D2H at payload time.
184+
("codes", "audio"),
183185
}
184186
# Keys that need to be accumulated across streaming inputs
185187
self.streaming_accumulated_keys: set[tuple[str, str]] = {
@@ -323,7 +325,12 @@ def get_mrope_input_positions(
323325
msg = "Qwen3 Omni thinker get_mrope_input_positions requires mm_features"
324326
raise ValueError(msg)
325327
return self.thinker.get_mrope_input_positions(input_tokens, mm_features)
326-
return MRotaryEmbedding.get_input_positions_tensor(input_tokens, **kwargs)
328+
# Talker/code2wav stages are text/codec-only and do not need
329+
# multimodal M-RoPE position computation. Return a cheap linear
330+
# position tensor to avoid unnecessary per-request M-RoPE work.
331+
seq_len = len(input_tokens)
332+
linear = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(3, seq_len)
333+
return linear, 0
327334

328335
def forward(
329336
self,
@@ -1253,7 +1260,7 @@ def compute_logits(
12531260
if (
12541261
getattr(self, "model_stage", None) == "talker"
12551262
and sampling_metadata is not None
1256-
and (sampling_metadata.temperature is None or (sampling_metadata.temperature <= 0).any())
1263+
and (sampling_metadata.temperature is None)
12571264
):
12581265
self._warn_talker_sampling_temperature(sampling_metadata)
12591266

vllm_omni/worker/gpu_model_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,12 @@ def _store_value(self, dest: dict, key: str, value: Any, gpu_keys: set) -> None:
17651765
if key in gpu_keys:
17661766
dest[key] = value.detach().clone()
17671767
else:
1768-
dest[key] = value.detach().to("cpu").contiguous()
1768+
t = value.detach()
1769+
if t.is_cuda:
1770+
dest[key] = t.to("cpu").contiguous()
1771+
else:
1772+
# Not a CUDA tensor, so it is assumed to already be on the CPU; skip the device-to-host copy.
1773+
dest[key] = t.contiguous()
17691774
elif isinstance(value, list):
17701775
dest[key] = [
17711776
(item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in value

0 commit comments

Comments
 (0)