[Perf] Qwen3-Omni audio encoder: sglang-native + CUDA-graph wrapper (up to 6.4x) #333
jiaoew1991 wants to merge 2 commits into sgl-project:main from
Conversation
…POC)
Adds a shape-static CUDA-graph-capturable wrapper around sglang main's
native Qwen3OmniMoeAudioEncoder, together with a micro-benchmark.
On 1x H100 bf16 the full graphed + CUDA graph path achieves up to 6.42x
speedup over the existing HF-wrapped audio encoder (B=1 L=300:
2.11 ms vs 13.54 ms), tapering to 2.06x at B=16 L=3000 as GPU compute
starts to dominate launch overhead.
Not a drop-in replacement yet: fixed (batch, seq_len) per instance,
and dispatch into the pipeline stage is left for a follow-up RFC.
Files:
- sglang_omni/models/qwen3_omni/components/audio_encoder_native.py
Thin wrapper around sglang.srt.models.qwen3_omni_moe.Qwen3OmniMoeAudioEncoder
with explicit init_sglang_env_for_encoder() bootstrap and strict weight
loading (q/k/v -> qkv_proj, out_proj -> proj, fail-loud on any unplaced
or untouched param).
- sglang_omni/models/qwen3_omni/components/audio_encoder_graphed.py
Fixed-shape wrapper that precomputes chunk_lengths, flat_mask_indices,
cu_seqlens and the sliced positional embedding at __init__. Replaces
the three CUDA-graph-incompatible patterns in base.forward:
1. torch.tensor([python_list], device=...) -> preallocated buffer
2. .item() / .tolist() / .max().item() -> precomputed int
  3. padded_embed[bool_mask] -> index_select on a precomputed long-index tensor
     (a toy sketch of all three rewrites follows the file list below)
cu_seqlens is passed in the (cu_tensor, max_seqlen) tuple form expected
by sglang attention backends when SGLANG_VIT_ENABLE_CUDA_GRAPH is set,
falling back to a plain tensor otherwise.
- benchmarks/micro/bench_audio_encoder.py
Four modes: eager (HF vs sglang side-by-side), cuda_graph, graphed
(Phase 1 wrapper, with and without CUDA graph), and tp (TP=N via
torchrun). Enforces numerical parity vs HF across the (B, L) matrix
(max_abs_diff < 3e-1, mean_abs_diff < 5e-3, consistent with
industry-standard FA-vs-SDPA equivalence bounds).
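A toy sketch of the three rewrites listed above under audio_encoder_graphed.py (variable names follow this description and are illustrative, not the actual module code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
chunk_lengths_list = [300, 300, 300]  # known at __init__ under the fixed-shape contract

# 1. torch.tensor([python_list], device=...) in forward -> buffer built once at __init__
chunk_lengths = torch.tensor(chunk_lengths_list, device=device)

# 2. .item() / .tolist() / .max().item() in forward -> plain Python ints at __init__
max_seqlen = max(chunk_lengths_list)

# 3. padded_embed[bool_mask] (data-dependent shape) -> index_select on precomputed indices
bool_mask = torch.zeros(3, 320, dtype=torch.bool, device=device)
for i, n in enumerate(chunk_lengths_list):
    bool_mask[i, :n] = True
flat_mask_indices = bool_mask.flatten().nonzero(as_tuple=False).squeeze(-1)

padded_embed = torch.randn(3, 320, 1280, device=device)
selected = padded_embed.flatten(0, 1).index_select(0, flat_mask_indices)
assert torch.equal(selected, padded_embed[bool_mask])  # same rows, but shape fixed at capture time
```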
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Pull request overview
Adds an sglang-native Qwen3-Omni audio encoder implementation plus a fixed-shape wrapper intended for CUDA graph capture/replay, along with a micro-benchmark to compare latency, memory, and numerical parity versus the current HF-based encoder.
Changes:
- Introduce Qwen3OmniAudioEncoderNative and a standalone sglang environment bootstrap + strict HF→sglang weight mapping/validation.
- Add GraphedAudioEncoder to precompute layout tensors and provide a graph-capturable "static forward" for fixed (batch, seq_len).
- Add bench_audio_encoder.py micro-benchmark supporting eager / CUDA graph / graphed / TP modes plus parity checks.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| sglang_omni/models/qwen3_omni/components/audio_encoder_native.py | Boots sglang distributed/server-args state and loads HF audio_tower weights into the sglang-native audio encoder. |
| sglang_omni/models/qwen3_omni/components/audio_encoder_graphed.py | Implements a fixed-shape wrapper that removes non-graph-capturable ops via precomputed indices and cu_seqlens handling. |
| benchmarks/micro/bench_audio_encoder.py | Adds a benchmark harness to measure latency/memory and enforce numerical parity across implementations and modes. |
benchmarks/micro/bench_audio_encoder.py (review context):

```python
for _ in range(warmup):
    _ = _run_once(encoder, input_features, lengths)
torch.cuda.synchronize(device)
torch.cuda.reset_peak_memory_stats(device)

latencies_ms = []
out_sample = None
for i in range(iters):
    torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    out = _run_once(encoder, input_features, lengths)
    torch.cuda.synchronize(device)
    t1 = time.perf_counter()
    latencies_ms.append((t1 - t0) * 1000.0)
    if i == 0:
        out_sample = out.detach().to(torch.float32).cpu()

peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024**2)
```
Benchmark helpers unconditionally call torch.cuda.synchronize/reset_peak_memory_stats/max_memory_allocated. If --device is set to a non-CUDA device (or CUDA is unavailable), this will raise at runtime. Consider validating torch.device(args.device).type == "cuda" early (or guarding the CUDA-only calls) so failures are clearer and CLI behavior is well-defined.
Suggested change:

```diff
-for _ in range(warmup):
-    _ = _run_once(encoder, input_features, lengths)
-torch.cuda.synchronize(device)
-torch.cuda.reset_peak_memory_stats(device)
-latencies_ms = []
-out_sample = None
-for i in range(iters):
-    torch.cuda.synchronize(device)
-    t0 = time.perf_counter()
-    out = _run_once(encoder, input_features, lengths)
-    torch.cuda.synchronize(device)
-    t1 = time.perf_counter()
-    latencies_ms.append((t1 - t0) * 1000.0)
-    if i == 0:
-        out_sample = out.detach().to(torch.float32).cpu()
-peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024**2)
+use_cuda_timing = device.type == "cuda" and torch.cuda.is_available()
+for _ in range(warmup):
+    _ = _run_once(encoder, input_features, lengths)
+if use_cuda_timing:
+    torch.cuda.synchronize(device)
+    torch.cuda.reset_peak_memory_stats(device)
+latencies_ms = []
+out_sample = None
+for i in range(iters):
+    if use_cuda_timing:
+        torch.cuda.synchronize(device)
+    t0 = time.perf_counter()
+    out = _run_once(encoder, input_features, lengths)
+    if use_cuda_timing:
+        torch.cuda.synchronize(device)
+    t1 = time.perf_counter()
+    latencies_ms.append((t1 - t0) * 1000.0)
+    if i == 0:
+        out_sample = out.detach().to(torch.float32).cpu()
+peak_mem_mb = (
+    torch.cuda.max_memory_allocated(device) / (1024**2)
+    if use_cuda_timing
+    else 0.0
+)
```
sglang_omni/models/qwen3_omni/components/audio_encoder_native.py (review context):

```python
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
```
init_sglang_env_for_encoder hard-codes a default MASTER_PORT=29501. When running multiple local benchmarks/tests on the same host, this can lead to nondeterministic “address already in use” failures. Consider following the repo’s existing pattern of selecting a free ephemeral port when MASTER_PORT is unset (see models/ming_omni/components/image_encoder.py), or reusing the _resolve_nccl_port() approach used elsewhere.
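For reference, the ephemeral-port pattern the comment points at could look roughly like this (a sketch only; the repo's _resolve_nccl_port() may differ in details):

```python
import os
import socket

def _pick_free_port() -> int:
    # Bind to port 0 so the OS assigns a free ephemeral port, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", str(_pick_free_port()))
```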
audio_encoder_native.py (review context):

```python
import torch.distributed as dist
```
Unused import: torch.distributed as dist is never referenced in this function/module. Please remove it to avoid lint noise and to keep the initialization logic focused on the actual parallel_state helpers being used.
Suggested change (remove the unused import):

```diff
-import torch.distributed as dist
```
audio_encoder_graphed.py (review context):

```python
def __init__(
    self,
    base_encoder: nn.Module,
    *,
    batch: int,
    seq_len: int,
    device: torch.device | str,
    dtype: torch.dtype,
) -> None:
```
dtype is accepted as a constructor argument but never stored or used; the wrapper instead implicitly relies on next(self.base.parameters()).dtype. Either remove the dtype parameter, or enforce/cast inputs and precomputed buffers to it so the API matches the docstring (“device, dtype for precomputation and I/O dtype”).
audio_encoder_graphed.py (review context):

```python
def _static_forward(self, input_features: torch.Tensor) -> torch.Tensor:
    base = self.base
    chunk_list = input_features.T.split(self.chunk_lengths_list, dim=0)
    padded_feature = (
        nn.utils.rnn.pad_sequence(chunk_list, batch_first=True)
        .transpose(1, 2)
        .unsqueeze(1)
    )
```
_static_forward uses input_features directly without moving/casting to the base encoder’s device/dtype. This makes the wrapper easy to misuse (e.g., FP32/CPU inputs) and can cause runtime dtype/device errors (Conv2d/Linear typically require matching dtypes/devices). Consider mirroring Qwen3OmniAudioEncoderNative.forward by explicitly sending input_features to self._device and the base param dtype before the conv stack.
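One hedged way to express that guard as a standalone helper (the follow-up commit instead derives dtype/device from self.base inside _static_forward; this is just an illustration):

```python
import torch

def to_encoder_dtype_device(x: torch.Tensor, module: torch.nn.Module) -> torch.Tensor:
    # Move/cast the input to match the wrapped encoder's first parameter,
    # so CPU or FP32 inputs don't fail inside the conv stack.
    param = next(module.parameters())
    return x.to(device=param.device, dtype=param.dtype)
```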
audio_encoder_graphed.py (review context):

```python
if not skip_shape_check:
    self._check_shape_contract(
        input_features, feature_attention_mask, audio_feature_lengths
    )
out = self._static_forward(input_features)
return {"audio_embeds": out}
```
forward returns only {"audio_embeds": ...} whereas the existing audio encoder components return audio_embeds, audio_feature_lengths, and audio_output_lengths. Since downstream pipeline code expects the length fields (e.g., merged thinker inputs include audio_feature_lengths), this wrapper won’t be drop-in compatible. Consider returning the same dict shape (length tensors can be constant/precomputed for the fixed-shape contract).
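A minimal sketch of the drop-in return shape the comment asks for, assuming the length tensors can be precomputed once under the fixed (batch, seq_len) contract (names here are illustrative):

```python
import torch

def make_fixed_shape_outputs(
    audio_embeds: torch.Tensor,
    audio_feature_lengths: torch.Tensor,  # precomputed once at __init__
    audio_output_lengths: torch.Tensor,   # constant under the fixed-shape contract
) -> dict:
    # Mirror the dict shape of the existing audio encoder components so the
    # downstream thinker-input merge still sees the length fields it expects.
    return {
        "audio_embeds": audio_embeds,
        "audio_feature_lengths": audio_feature_lengths,
        "audio_output_lengths": audio_output_lengths,
    }
```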
- native: use free ephemeral port instead of hardcoded MASTER_PORT=29501
(matches _resolve_nccl_port pattern); drop unused `import torch.distributed`.
- graphed: drop unused `dtype` ctor arg (was never stored/used). Derive
param dtype + device from `self.base` and use them inside
`_static_forward` so CPU/FP32 inputs auto-cast instead of blowing up at
Conv2d. Return the full {audio_embeds, audio_feature_lengths,
audio_output_lengths} dict matching `Qwen3OmniAudioEncoder` — the length
tensors are precomputed once (constant under our fixed-shape contract).
- bench: fail-fast in `main()` if `--device` resolves to non-CUDA; avoids
the first `torch.cuda.synchronize(device)` throwing with a confusing
trace. Drop the now-unused `dtype=` kwarg at GraphedAudioEncoder call site.
Verified on H100: graphed-eager 12.28ms, graphed+CUDA graph 2.45ms at
B=1 L=1000, parity still OK (same 1.51e-02 max / 8.62e-04 mean vs HF).
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Addressed all 6 Copilot review comments in e0a75da:
Re-verified on H100: graphed-eager 12.28 ms, graphed + CUDA graph 2.45 ms at B=1 L=1000, numerical parity unchanged (max 1.51e-02 / mean 8.62e-04 vs HF).
Summary
Adds a shape-static, CUDA-graph-capturable wrapper around sglang main's native Qwen3OmniMoeAudioEncoder, plus a standalone micro-benchmark that establishes numerical and performance equivalence vs the current HF-wrapped audio encoder.

On 1×H100 (bf16, fixed shapes), the full graphed + CUDA graph path delivers:
All 9 (B, L) combinations pass parity vs the HF baseline (max_abs_diff < 3e-1, mean_abs_diff < 5e-3, consistent with industry-standard FA-vs-SDPA bounds).

Peak GPU memory is lower with CUDA graph than eager at large shapes (e.g. B=16 L=3000: 2.7 GB vs 6.1 GB) because the captured graph reuses static buffers.
Motivation
`Qwen3OmniMoeAudioEncoder.forward` in its current form is launch-bound on H100: with 32 transformer layers × 6–7 ops/layer, Python dispatch time is ~70% of wall-clock at B=1. pmon confirms <15% SM util across all (B, L).
Three ops in the base `forward` block CUDA graph capture:
1. `torch.tensor([python_list], device=...)` built on every call
2. `.item()` / `.tolist()` / `.max().item()` host-device syncs
3. `padded_embed[bool_mask]` boolean-mask indexing with a data-dependent output shape
This PR eliminates all three by precomputing the chunk layout, mask→index tensor, and `(cu_seqlens, max_seqlen)` tuple at `init`, leaving `forward` as pure tensor ops.
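For context, the capture/replay pattern this enables is the standard PyTorch static-buffer recipe, sketched below with a toy model and shapes (requires a CUDA device; not the PR's actual wrapper code):

```python
import torch

device = torch.device("cuda")
model = torch.nn.Sequential(
    torch.nn.Linear(1280, 1280), torch.nn.GELU(), torch.nn.Linear(1280, 1280)
).to(device=device, dtype=torch.bfloat16)

# Static input buffer: replay always reads this exact tensor.
static_in = torch.zeros(300, 1280, device=device, dtype=torch.bfloat16)

# Warm up on a side stream, then capture a single forward into a graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: copy fresh data into the static input, replay, read the static output.
static_in.copy_(torch.randn_like(static_in))
graph.replay()
result = static_out.clone()
```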
Three layers of speedup
The perf win decomposes cleanly:
The graphed-eager layer is free speedup even without CUDA graph, and is safe to adopt incrementally.
Files
Known limitations (POC scope, not production-ready)
These would need to be resolved before this replaces the default audio encoder in the pipeline:
TP=2 result: negative
For completeness, I also ran sglang native with TP=2 (weights split across 2 GPUs). At every tested `(B, L)` it is 15–25% slower than TP=1:
The audio encoder is launch-bound with small kernels; splitting the weights across TP ranks makes each kernel even smaller while adding an NCCL all-reduce per layer. Recommendation: the audio encoder should always run at TP=1, even when the surrounding thinker uses TP>1.
Numerical equivalence
The graphed wrapper uses the same FA3 kernels as sglang's native audio encoder, which produces bf16 outputs that differ from HF's SDPA path by approximately 1–2e-3 mean / 1–2e-1 max (the latter dominated by outlier positions in long sequences, expected from FA vs SDPA softmax reduction order). `mean_abs_diff` is the load-bearing threshold (1–2e-3 is at the bf16 precision floor for our activation range). The benchmark harness enforces both bounds across the full (B, L) matrix.
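The check itself boils down to something like the following (a sketch of the quoted thresholds, not the exact harness code):

```python
import torch

def check_parity(ref: torch.Tensor, test: torch.Tensor,
                 max_abs_tol: float = 3e-1, mean_abs_tol: float = 5e-3) -> None:
    # Compare in fp32 so the reduction itself doesn't add bf16 rounding error.
    diff = (ref.float() - test.float()).abs()
    max_abs, mean_abs = diff.max().item(), diff.mean().item()
    assert max_abs < max_abs_tol, f"max_abs_diff {max_abs:.3e} >= {max_abs_tol:g}"
    assert mean_abs < mean_abs_tol, f"mean_abs_diff {mean_abs:.3e} >= {mean_abs_tol:g}"
```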
Test plan
Follow-ups (separate PRs)
🤖 Generated with Claude Code