
[Perf] Support Qwen3 Omni code2wav batch inference with async chunk #1246

Open
ZeldaHuang wants to merge 2 commits into vllm-project:main from ZeldaHuang:support_code2wav_batch

Conversation

@ZeldaHuang
Contributor

@ZeldaHuang commented Feb 6, 2026


Purpose

ref #1211

Test Plan

python openai_chat_completion_client_for_multimodal_generation.py \
 --query-type text \
 --model /mnt/data/models/Qwen3-Omni-30B-A3B-Instruct \
 --num-concurrent-requests 16

Test Result

Before this PR:

'e2e_total_time_ms': 9697.282075881958

After this PR:

'e2e_total_time_ms': 6435.441255569458

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.


Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e2e6a6fa99


Comment on lines +174 to +176
max_context_length = 50
self.codes_buffer = torch.zeros(
(self.vllm_config.scheduler_config.max_num_seqs, code2wav_config.num_quantizers, max_context_length),


P1: Avoid fixed 50-step codes_buffer for batched code2wav

When ubatch_slices is present, the forward path copies each request’s codes into self.codes_buffer using seq_len = code.shape[0] // 16. The buffer is hard‑coded to length 50, so any request with more than 50 codec frames will hit a shape mismatch in copy_ (dest is [num_quantizers, 50], source is [num_quantizers, seq_len]) and crash. This only happens in the async‑chunk/batched path, but it means legitimate longer code sequences will fail at runtime unless the buffer is sized from the actual max sequence length.
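One way to address this would be to size the buffer from the actual per-request token budget instead of a hard-coded 50. A minimal sketch, where max_tokens_per_seq and the 16-codes-per-frame packing are assumptions taken from the diff above rather than this PR's final implementation:

import torch

def allocate_codes_buffer(max_num_seqs: int, num_quantizers: int,
                          max_tokens_per_seq: int, codes_per_frame: int = 16,
                          device: str = "cpu") -> torch.Tensor:
    # Derive the frame capacity from the real token budget so copy_() can
    # never receive a source longer than the destination slice.
    max_frames = max_tokens_per_seq // codes_per_frame
    return torch.zeros(
        (max_num_seqs, num_quantizers, max_frames),
        dtype=torch.long,
        device=device,
    )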


Contributor

Copilot AI left a comment


Pull request overview

This pull request adds batch inference support for the Qwen3 Omni model's code2wav stage when using async chunk processing. The changes enable multiple requests to be processed together in the code2wav stage, addressing a performance bottleneck identified when concurrency > 1. Testing shows a significant performance improvement from 9697ms to 6435ms in end-to-end processing time.

Changes:

  • Modified the code2wav methods to return lists of tensors instead of single tensors to support batching
  • Added batch processing logic in the forward pass using a pre-allocated codes_buffer and ubatch_slices from forward context
  • Updated the model runner to handle list-based multimodal outputs for batched processing
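To make the batching idea concrete, the following is a minimal, self-contained sketch of how variable-length code sequences can be split by ubatch_slices and right-padded into one batch, mirroring the codes_buffer logic shown later in this review. The helper name and the 16-codes-per-frame layout are illustrative assumptions, not code from this PR.

import torch

def pack_codes(input_ids: torch.Tensor, ubatch_slices: list[int],
               num_quantizers: int = 16) -> torch.Tensor:
    # Split the flat code stream per request, then right-pad each request's
    # codes into a [batch, num_quantizers, max_frames] tensor.
    per_req = torch.split(input_ids, ubatch_slices, dim=0)
    max_frames = max(ubatch_slices) // num_quantizers
    batch = torch.zeros(len(per_req), num_quantizers, max_frames,
                        dtype=input_ids.dtype)
    for i, code in enumerate(per_req):
        frames = code.shape[0] // num_quantizers
        batch[i, :, :frames] = code.reshape(num_quantizers, -1)
    return batch

# Example: pack_codes(torch.zeros(480, dtype=torch.long), [160, 160, 160])
# returns a tensor of shape [3, 16, 10].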

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.

File descriptions:
  • vllm_omni/worker/gpu_generation_model_runner.py: Added request state management for async chunk mode, reused ubatch_slices for batch processing, and updated multimodal output handling to support lists of tensors
  • vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py: Changed return types from single tensor to list of tensors, updated chunked_decode_streaming to process batches using ubatch_slices from forward context
  • vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py: Added codes_buffer for batch processing, updated forward method to handle batched input with ubatch_slices, changed generate_audio to return list of tensors


Returns:
waveform: [batch, 1, waveform_len] - Complete waveform
"""
# TODO Support batch_size > 1

Copilot AI Feb 6, 2026


The TODO comment indicates that batch_size > 1 is not supported in the non-streaming chunked_decode method, but the streaming version (chunked_decode_streaming) appears to support batching. This inconsistency could lead to unexpected behavior if users call chunked_decode with batch_size > 1. Consider either implementing batch support for chunked_decode or adding a runtime check to raise an error if batch_size > 1 is passed.
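If batch support is not implemented for the non-streaming path, a runtime guard along these lines would turn the silent assumption into an explicit error. This is a sketch, assuming codes is laid out as [batch, num_quantizers, seq_len]:

import torch

def assert_single_sequence(codes: torch.Tensor) -> None:
    # chunked_decode assumes one sequence; fail fast on batched input
    # instead of producing incorrect audio.
    if codes.dim() == 3 and codes.shape[0] > 1:
        raise NotImplementedError(
            "chunked_decode does not support batch_size > 1; "
            "use chunked_decode_streaming for batched requests."
        )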

Collaborator


@ZeldaHuang should it be removed? Have you tested the non-streaming cases as well?

@@ -176,6 +177,7 @@ def chunked_decode(
Returns:
waveform: [batch, 1, waveform_len] - Complete waveform

Copilot AI Feb 6, 2026


The return type documentation states "waveform: [batch, 1, waveform_len] - Complete waveform" but the actual return type has been changed to list[torch.Tensor]. The documentation should be updated to reflect that this method now returns a list containing one tensor for batch_size=1, where each tensor has shape [1, waveform_len].

Suggested change
waveform: [batch, 1, waveform_len] - Complete waveform
list[torch.Tensor]: Complete waveform decoded from the input
codes. For ``batch_size == 1``, this is a list containing a
single tensor with shape ``[1, waveform_len]``.

# ==================== Audio Generation ====================

def generate_audio(self, code: torch.Tensor, voice_type: str) -> torch.Tensor:
def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tensor]:

Copilot AI Feb 6, 2026


The function signature has been changed to return list[torch.Tensor], but the docstring still documents the return type as "audio_tensor: [1, waveform_len] - Audio waveform". Update the Returns section of the docstring to reflect that this method returns a list of audio tensors.
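For example, the Returns section could be rewritten roughly as follows; this is an illustrative sketch of the docstring only, with the method body and the self parameter elided:

import torch

def generate_audio(code: torch.Tensor, voice_type: str) -> list[torch.Tensor]:
    """Decode codec tokens into audio.

    Returns:
        list[torch.Tensor]: one waveform per request in the batch,
        each with shape [1, waveform_len].
    """
    ...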

Comment on lines +365 to +368
if isinstance(out, list) and len(out) == num_reqs:
mm_payload[key] = out[i].detach().to("cpu").contiguous()
elif isinstance(out, torch.Tensor):
mm_payload[key] = out.detach().to("cpu").contiguous()

Copilot AI Feb 6, 2026


The code handles two cases: when out is a list with length matching num_reqs, and when out is a tensor. However, if out is a list with a different length, or is neither a list nor a tensor, the key won't be added to mm_payload, potentially leading to incomplete output data. Consider adding an else clause to handle unexpected types or list lengths with a warning or error message.

Suggested change
if isinstance(out, list) and len(out) == num_reqs:
mm_payload[key] = out[i].detach().to("cpu").contiguous()
elif isinstance(out, torch.Tensor):
mm_payload[key] = out.detach().to("cpu").contiguous()
if isinstance(out, list):
if len(out) != num_reqs:
raise ValueError(
f"Multimodal output list for key '{key}' has length {len(out)} "
f"but expected {num_reqs} (one entry per request)."
)
mm_payload[key] = out[i].detach().to("cpu").contiguous()
elif isinstance(out, torch.Tensor):
mm_payload[key] = out.detach().to("cpu").contiguous()
else:
raise TypeError(
f"Unsupported multimodal output type for key '{key}': {type(out)}"
)

Comment on lines +174 to +179
max_context_length = 50
self.codes_buffer = torch.zeros(
(self.vllm_config.scheduler_config.max_num_seqs, code2wav_config.num_quantizers, max_context_length),
device=self._module_device(self.code2wav),
dtype=torch.long,
)

Copilot AI Feb 6, 2026


The codes_buffer is allocated with a fixed size based on max_num_seqs and a hard-coded max_context_length of 50. When processing batches, only the relevant portion is sliced (line 374), but the entire buffer remains allocated. For large max_num_seqs values, this could lead to significant memory overhead. Consider whether a dynamic allocation approach or a smaller buffer size would be more memory-efficient, especially if typical sequence lengths are much shorter than 50.
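If the fixed pre-allocation ever becomes a memory concern, one alternative is a grow-on-demand buffer that reallocates only when a batch exceeds the previous capacity. A hedged sketch, where the class and its name are hypothetical and not part of this PR:

import torch

class GrowableCodesBuffer:
    """Reuse one buffer and reallocate only when a batch needs more capacity."""

    def __init__(self, num_quantizers: int, device: str = "cpu") -> None:
        self.num_quantizers = num_quantizers
        self.device = device
        self.buf = torch.zeros(0, num_quantizers, 0,
                               dtype=torch.long, device=device)

    def get(self, batch_size: int, max_frames: int) -> torch.Tensor:
        # Grow (never shrink) so steady-state batches hit the fast path.
        if self.buf.shape[0] < batch_size or self.buf.shape[2] < max_frames:
            rows = max(batch_size, self.buf.shape[0])
            frames = max(max_frames, self.buf.shape[2])
            self.buf = torch.zeros(rows, self.num_quantizers, frames,
                                   dtype=torch.long, device=self.device)
        return self.buf[:batch_size, :, :max_frames]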

Comment on lines 217 to 218
Returns:
waveform: [batch, 1, waveform_len] - Complete waveform

Copilot AI Feb 6, 2026


The return type documentation states "waveform: [batch, 1, waveform_len] - Complete waveform" but the actual return type has been changed to list[torch.Tensor]. The documentation should be updated to reflect that this method now returns a list of tensors, where each tensor has shape [1, waveform_len], with one element per request in the batch.

# Remove context from output (context_size * total_upsample samples)
wavs.append(wav_chunk[..., context_size * self.total_upsample :])
return torch.cat(wavs, dim=-1)
code_seq_lens = [codes.shape[-1]]

Copilot AI Feb 6, 2026


When ubatch_slices is None, the code defaults to using codes.shape[-1] as the sequence length. However, if codes has a batch dimension > 1 (codes.shape[0] > 1), this will process all batch elements with the same sequence length assumption, which may be incorrect if different batch elements have different actual sequence lengths. Consider whether sequence length information should be tracked per batch element when ubatch_slices is None.

Suggested change
code_seq_lens = [codes.shape[-1]]
# Fallback: assume all batch elements share the same sequence length.
# Create one entry per batch so that each element is processed.
code_seq_lens = [codes.shape[-1]] * codes.shape[0]

Comment on lines +366 to +374
ubatch_slices = get_forward_context().ubatch_slices
if ubatch_slices is not None:
max_seq_len = max(ubatch_slices) // 16
batch_size = len(ubatch_slices)
split_codes = torch.split(input_ids, ubatch_slices, dim=0)
for idx, code in enumerate(split_codes):
seq_len = code.shape[0] // 16
self.codes_buffer[idx, :, :seq_len].copy_(code.reshape(16, -1))
codes = self.codes_buffer[:batch_size, :, :max_seq_len]

Copilot AI Feb 6, 2026


The codes_buffer is reused across multiple forward passes without being explicitly cleared. While the code uses .copy_() to overwrite relevant portions (line 373), if a subsequent batch has shorter sequences than a previous batch, stale data from the previous batch could remain in the buffer beyond the current sequence length. While the slicing on line 374 ensures only the valid portion is used, consider adding a comment to clarify this is intentional, or explicitly zero out the buffer regions that will be used before copying to make the behavior more explicit.
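One way to make the reuse explicit, as the comment suggests, is to zero only the slice the current batch will read before copying. A minimal sketch that follows the buffer layout from the diff above; the helper name is illustrative:

import torch

def fill_codes_buffer(codes_buffer: torch.Tensor,
                      split_codes: list[torch.Tensor],
                      num_quantizers: int = 16) -> torch.Tensor:
    batch_size = len(split_codes)
    max_seq_len = max(c.shape[0] for c in split_codes) // num_quantizers
    view = codes_buffer[:batch_size, :, :max_seq_len]
    view.zero_()  # drop any stale frames left by a previous, longer batch
    for idx, code in enumerate(split_codes):
        seq_len = code.shape[0] // num_quantizers
        view[idx, :, :seq_len].copy_(code.reshape(num_quantizers, -1))
    return view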

assert req_state is not None
req_state.prompt_token_ids = cached_reqs.prompt_token_ids.get(req_id)
req_states.append(req_state)
self.input_batch.remove_request(req_id)

Copilot AI Feb 6, 2026


In the request update logic, requests from cached_reqs.req_ids are removed from the input batch at line 69, but some of these requests may have already been removed earlier at line 61 if they were in unscheduled_req_ids. If remove_request doesn't gracefully handle removing a request that's already been removed, this could cause an error. Verify that remove_request is idempotent, or adjust the logic to only remove requests that are still present in the batch.

Suggested change
self.input_batch.remove_request(req_id)
# Remove the request from the current input batch only if it is still present.
if req_id in self.input_batch.req_id_to_index:
self.input_batch.remove_request(req_id)

self.input_batch.remove_request(req_id)
cached_reqs = scheduler_output.scheduled_cached_reqs
req_states = []
for _, req_id in enumerate(cached_reqs.req_ids):

Copilot AI Feb 6, 2026


The loop at line 64 already iterates over cached_reqs.req_ids, so the enumerate variable (first element of the tuple) is never used. Consider using for req_id in cached_reqs.req_ids: instead of for _, req_id in enumerate(cached_reqs.req_ids): for cleaner code.

Suggested change
for _, req_id in enumerate(cached_reqs.req_ids):
for req_id in cached_reqs.req_ids:

@amy-why-3459
Contributor

Thank you very much for your contribution. Could you please modify the max_batch_size configuration in the qwen3_omni_moe_async_chunk.yaml file for stage-2?

Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
@ZeldaHuang
Contributor Author

Thank you very much for your contribution. Could you please modify the max_batch_size configuration in the qwen3_omni_moe_async_chunk.yaml file for stage-2?

Done

@amy-why-3459
Contributor

LGTM

@amy-why-3459
Contributor

test plan

vllm bench serve --omni --dataset-name random \
  --port 28889 --max-concurrency 32 \
  --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
  --endpoint /v1/chat/completions --backend openai-chat-omni \
  --request-rate 1 --num-prompts 32 --random-input-len 100 \
  --ignore-eos \
  --percentile-metrics ttft,tpot,itl,e2el,audio_ttfp,audio_rtf \
  --random-output-len 100 \
  --extra_body '{"modalities": ["text", "audio"]}'

test result

============ Serving Benchmark Result ============
Successful requests:                     32
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           1.00
Benchmark duration (s):                  697.88
Request throughput (req/s):              0.05
Peak concurrent requests:                32.00
----------------End-to-end Latency----------------
Mean E2EL (ms):                          501468.78
Median E2EL (ms):                        623488.85
P99 E2EL (ms):                           665237.60
================== Text Result ===================
Total input tokens:                      3200
Total generated tokens:                  161610
Output token throughput (tok/s):         231.57
Peak output token throughput (tok/s):    590.00
Peak concurrent requests:                32.00
Total Token throughput (tok/s):          236.16
---------------Time to First Token----------------
Mean TTFT (ms):                          10782.98
Median TTFT (ms):                        13712.31
P99 TTFT (ms):                           16459.95
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          97.18
Median TPOT (ms):                        120.58
P99 TPOT (ms):                           128.48
---------------Inter-token Latency----------------
Mean ITL (ms):                           19.19
Median ITL (ms):                         0.01
P99 ITL (ms):                            125.99
================== Audio Result ==================
Total audio duration generated(s):       10351.68
Total audio frames generated:            248440140
Audio throughput(audio duration/s):      14.83
---------------Time to First Packet---------------
Mean AUDIO_TTFP (ms):                    181522.32
Median AUDIO_TTFP (ms):                  271592.73
P99 AUDIO_TTFP (ms):                     345857.61
-----------------Real Time Factor-----------------
Mean AUDIO_RTF:                          1.55
Median AUDIO_RTF:                        1.93
P99 AUDIO_RTF:                           2.06
==================================================

@hsliuustc0106
Collaborator


please also attach the result before this PR

Collaborator

@hsliuustc0106 left a comment


Many of the Copilot comments make sense, please check.

)
self.model = self.code2wav
self.requires_raw_input_tokens = True
max_context_length = 50
Collaborator


Is 50 a model-specific number for the context_length?


runtime:
devices: "1"
max_batch_size: 1
max_batch_size: 64
Collaborator


How about the non-async and non-streaming cases?



Development

Successfully merging this pull request may close these issues.

[RFC]: Qwen3 Omni code2wav stage support batching

3 participants