[Perf] Support Qwen3 Omni code2wav batch inference with async chunk #1246
ZeldaHuang wants to merge 2 commits into vllm-project:main
Conversation
Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e2e6a6fa99
| max_context_length = 50 | ||
| self.codes_buffer = torch.zeros( | ||
| (self.vllm_config.scheduler_config.max_num_seqs, code2wav_config.num_quantizers, max_context_length), |
Avoid fixed 50-step codes_buffer for batched code2wav
When ubatch_slices is present, the forward path copies each request’s codes into self.codes_buffer using seq_len = code.shape[0] // 16. The buffer is hard‑coded to length 50, so any request with more than 50 codec frames will hit a shape mismatch in copy_ (dest is [num_quantizers, 50], source is [num_quantizers, seq_len]) and crash. This only happens in the async‑chunk/batched path, but it means legitimate longer code sequences will fail at runtime unless the buffer is sized from the actual max sequence length.
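A minimal sketch of sizing the buffer from configuration instead of the literal 50, reusing the PR's own names; `max_codec_frames` is a hypothetical config field used only for illustration, and the real bound would have to come from the actual scheduler/model limits:

```python
# Sketch only: derive the frame capacity from config rather than hard-coding 50.
# `max_codec_frames` is a hypothetical field, not a known attribute of
# code2wav_config; 50 stays as the fallback.
max_context_length = getattr(code2wav_config, "max_codec_frames", 50)
self.codes_buffer = torch.zeros(
    (
        self.vllm_config.scheduler_config.max_num_seqs,
        code2wav_config.num_quantizers,
        max_context_length,
    ),
    device=self._module_device(self.code2wav),
    dtype=torch.long,
)
```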
Pull request overview
This pull request adds batch inference support for the Qwen3 Omni model's code2wav stage when using async chunk processing. The changes enable multiple requests to be processed together in the code2wav stage, addressing a performance bottleneck identified when concurrency > 1. Testing shows a significant performance improvement from 9697ms to 6435ms in end-to-end processing time.
Changes:
- Modified the code2wav methods to return lists of tensors instead of single tensors to support batching
- Added batch processing logic in the forward pass using a pre-allocated codes_buffer and ubatch_slices from forward context
- Updated the model runner to handle list-based multimodal outputs for batched processing
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| vllm_omni/worker/gpu_generation_model_runner.py | Added request state management for async chunk mode, reused ubatch_slices for batch processing, and updated multimodal output handling to support lists of tensors |
| vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py | Changed return types from single tensor to list of tensors, updated chunked_decode_streaming to process batches using ubatch_slices from forward context |
| vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py | Added codes_buffer for batch processing, updated forward method to handle batched input with ubatch_slices, changed generate_audio to return list of tensors |
| Returns: | ||
| waveform: [batch, 1, waveform_len] - Complete waveform | ||
| """ | ||
| # TODO Support batch_size > 1 |
The TODO comment indicates that batch_size > 1 is not supported in the non-streaming chunked_decode method, but the streaming version (chunked_decode_streaming) appears to support batching. This inconsistency could lead to unexpected behavior if users call chunked_decode with batch_size > 1. Consider either implementing batch support for chunked_decode or adding a runtime check to raise an error if batch_size > 1 is passed.
@ZeldaHuang should it be removed? Have you tested the non-streaming cases as well?
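If the non-streaming path is kept as-is for now, a minimal guard along the lines Copilot suggests could look like this; it assumes the codes tensor passed to `chunked_decode` is laid out as `[batch, num_quantizers, seq_len]`, so the checked dimension may need adjusting:

```python
# Sketch of a batch-size guard at the top of chunked_decode; assumes dim 0 is
# the batch dimension of the incoming codes tensor.
if codes.shape[0] > 1:
    raise NotImplementedError(
        "chunked_decode currently supports batch_size == 1 only; "
        "use chunked_decode_streaming for batched inputs."
    )
```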
| @@ -176,6 +177,7 @@ def chunked_decode( | |||
| Returns: | |||
| waveform: [batch, 1, waveform_len] - Complete waveform | |||
The return type documentation states "waveform: [batch, 1, waveform_len] - Complete waveform" but the actual return type has been changed to list[torch.Tensor]. The documentation should be updated to reflect that this method now returns a list containing one tensor for batch_size=1, where each tensor has shape [1, waveform_len].
| waveform: [batch, 1, waveform_len] - Complete waveform | |
| list[torch.Tensor]: Complete waveform decoded from the input | |
| codes. For ``batch_size == 1``, this is a list containing a | |
| single tensor with shape ``[1, waveform_len]``. |
| # ==================== Audio Generation ==================== | ||
|
|
| def generate_audio(self, code: torch.Tensor, voice_type: str) -> torch.Tensor: | ||
| def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tensor]: |
The function signature has been changed to return list[torch.Tensor], but the docstring still documents the return type as "audio_tensor: [1, waveform_len] - Audio waveform". Update the Returns section of the docstring to reflect that this method returns a list of audio tensors.
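One possible wording for the updated Returns section (illustrative phrasing, not taken from the PR):

```python
    Returns:
        list[torch.Tensor]: Audio waveforms, one per request in the batch,
        each with shape [1, waveform_len].
```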
| if isinstance(out, list) and len(out) == num_reqs: | ||
| mm_payload[key] = out[i].detach().to("cpu").contiguous() | ||
| elif isinstance(out, torch.Tensor): | ||
| mm_payload[key] = out.detach().to("cpu").contiguous() |
The code handles two cases: when out is a list with length matching num_reqs, and when out is a tensor. However, if out is a list with a different length, or is neither a list nor a tensor, the key won't be added to mm_payload, potentially leading to incomplete output data. Consider adding an else clause to handle unexpected types or list lengths with a warning or error message.
| if isinstance(out, list) and len(out) == num_reqs: | |
| mm_payload[key] = out[i].detach().to("cpu").contiguous() | |
| elif isinstance(out, torch.Tensor): | |
| mm_payload[key] = out.detach().to("cpu").contiguous() | |
| if isinstance(out, list): | |
| if len(out) != num_reqs: | |
| raise ValueError( | |
| f"Multimodal output list for key '{key}' has length {len(out)} " | |
| f"but expected {num_reqs} (one entry per request)." | |
| ) | |
| mm_payload[key] = out[i].detach().to("cpu").contiguous() | |
| elif isinstance(out, torch.Tensor): | |
| mm_payload[key] = out.detach().to("cpu").contiguous() | |
| else: | |
| raise TypeError( | |
| f"Unsupported multimodal output type for key '{key}': {type(out)}" | |
| ) |
| max_context_length = 50 | ||
| self.codes_buffer = torch.zeros( | ||
| (self.vllm_config.scheduler_config.max_num_seqs, code2wav_config.num_quantizers, max_context_length), | ||
| device=self._module_device(self.code2wav), | ||
| dtype=torch.long, | ||
| ) |
The codes_buffer is allocated with a fixed size based on max_num_seqs and a hard-coded max_context_length of 50. When processing batches, only the relevant portion is sliced (line 374), but the entire buffer remains allocated. For large max_num_seqs values, this could lead to significant memory overhead. Consider whether a dynamic allocation approach or a smaller buffer size would be more memory-efficient, especially if typical sequence lengths are much shorter than 50.
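A rough sketch of the lazy, per-step allocation alternative, reusing the PR's names where they exist (`code2wav_config`, `batch_size`, `max_seq_len`); whether the saved memory outweighs the occasional reallocation has not been measured:

```python
# Sketch: allocate (or grow) the buffer on demand instead of reserving
# max_num_seqs * num_quantizers * 50 entries up front.
needed = (batch_size, code2wav_config.num_quantizers, max_seq_len)
if self.codes_buffer is None or any(
    have < need for have, need in zip(self.codes_buffer.shape, needed)
):
    self.codes_buffer = torch.zeros(
        needed,
        device=self._module_device(self.code2wav),
        dtype=torch.long,
    )
codes = self.codes_buffer[:batch_size, :, :max_seq_len]
```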
| Returns: | ||
| waveform: [batch, 1, waveform_len] - Complete waveform |
The return type documentation states "waveform: [batch, 1, waveform_len] - Complete waveform" but the actual return type has been changed to list[torch.Tensor]. The documentation should be updated to reflect that this method now returns a list of tensors, where each tensor has shape [1, waveform_len], with one element per request in the batch.
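A possible rewording, mirroring the suggestion given above for `chunked_decode` (phrasing illustrative):

```python
    Returns:
        list[torch.Tensor]: Decoded waveforms, one per request in the batch,
        each with shape [1, waveform_len].
```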
| # Remove context from output (context_size * total_upsample samples) | ||
| wavs.append(wav_chunk[..., context_size * self.total_upsample :]) | ||
| return torch.cat(wavs, dim=-1) | ||
| code_seq_lens = [codes.shape[-1]] |
When ubatch_slices is None, the code defaults to using codes.shape[-1] as the sequence length. However, if codes has a batch dimension > 1 (codes.shape[0] > 1), this will process all batch elements with the same sequence length assumption, which may be incorrect if different batch elements have different actual sequence lengths. Consider whether sequence length information should be tracked per batch element when ubatch_slices is None.
| code_seq_lens = [codes.shape[-1]] | |
| # Fallback: assume all batch elements share the same sequence length. | |
| # Create one entry per batch so that each element is processed. | |
| code_seq_lens = [codes.shape[-1]] * codes.shape[0] |
| ubatch_slices = get_forward_context().ubatch_slices | ||
| if ubatch_slices is not None: | ||
| max_seq_len = max(ubatch_slices) // 16 | ||
| batch_size = len(ubatch_slices) | ||
| split_codes = torch.split(input_ids, ubatch_slices, dim=0) | ||
| for idx, code in enumerate(split_codes): | ||
| seq_len = code.shape[0] // 16 | ||
| self.codes_buffer[idx, :, :seq_len].copy_(code.reshape(16, -1)) | ||
| codes = self.codes_buffer[:batch_size, :, :max_seq_len] |
The codes_buffer is reused across multiple forward passes without being explicitly cleared. While the code uses .copy_() to overwrite relevant portions (line 373), if a subsequent batch has shorter sequences than a previous batch, stale data from the previous batch could remain in the buffer beyond the current sequence length. While the slicing on line 374 ensures only the valid portion is used, consider adding a comment to clarify this is intentional, or explicitly zero out the buffer regions that will be used before copying to make the behavior more explicit.
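If explicit clearing is preferred over a comment, a minimal sketch that reuses the PR's own loop (the added `zero_()` is the only change; its cost on a 50-frame buffer should be negligible, though that has not been measured):

```python
# Clear the slice consumed this step so frames left over from a previous,
# longer batch can never appear past each request's seq_len.
self.codes_buffer[:batch_size, :, :max_seq_len].zero_()
for idx, code in enumerate(split_codes):
    seq_len = code.shape[0] // 16
    self.codes_buffer[idx, :, :seq_len].copy_(code.reshape(16, -1))
codes = self.codes_buffer[:batch_size, :, :max_seq_len]
```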
| assert req_state is not None | ||
| req_state.prompt_token_ids = cached_reqs.prompt_token_ids.get(req_id) | ||
| req_states.append(req_state) | ||
| self.input_batch.remove_request(req_id) |
In the request update logic, requests from cached_reqs.req_ids are removed from the input batch at line 69, but some of these requests may have already been removed earlier at line 61 if they were in unscheduled_req_ids. If remove_request doesn't gracefully handle removing a request that's already been removed, this could cause an error. Verify that remove_request is idempotent, or adjust the logic to only remove requests that are still present in the batch.
| self.input_batch.remove_request(req_id) | |
| # Remove the request from the current input batch only if it is still present. | |
| if req_id in self.input_batch.req_id_to_index: | |
| self.input_batch.remove_request(req_id) |
| self.input_batch.remove_request(req_id) | ||
| cached_reqs = scheduler_output.scheduled_cached_reqs | ||
| req_states = [] | ||
| for _, req_id in enumerate(cached_reqs.req_ids): |
The loop at line 64 already iterates over cached_reqs.req_ids, so the enumerate variable (first element of the tuple) is never used. Consider using for req_id in cached_reqs.req_ids: instead of for _, req_id in enumerate(cached_reqs.req_ids): for cleaner code.
| for _, req_id in enumerate(cached_reqs.req_ids): | |
| for req_id in cached_reqs.req_ids: |
Thank you very much for your contribution. Could you please modify the
Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
Done
LGTM |
Test Plan / Test Result
Please also attach the result from before this PR.
hsliuustc0106
left a comment
Many of the Copilot comments make sense; please check.
| ) | ||
| self.model = self.code2wav | ||
| self.requires_raw_input_tokens = True | ||
| max_context_length = 50 |
Is 50 a model-specific number for the context_length?
| runtime: | ||
| devices: "1" | ||
| max_batch_size: 1 | ||
| max_batch_size: 64 |
How about the non-async and non-streaming cases?
Purpose
ref #1211
Test Plan
Test Result
Before this PR: 9697 ms end-to-end processing time
After this PR: 6435 ms end-to-end processing time