
[Perf] Qwen3-Omni audio encoder: sglang-native + CUDA-graph wrapper (up to 6.4x)#333

Open
jiaoew1991 wants to merge 2 commits into sgl-project:main from
jiaoew1991:feat/audio-encoder-cuda-graph

Conversation

@jiaoew1991

Summary

Adds a shape-static, CUDA-graph-capturable wrapper around sglang main's native Qwen3OmniMoeAudioEncoder, plus a standalone micro-benchmark that verifies numerical parity against, and measures performance relative to, the current HF-wrapped audio encoder.

On 1×H100 (bf16, fixed shapes), the full graphed + CUDA graph path delivers:

| (B, L)   | HF eager | graphed + graph | Speedup |
|----------|----------|-----------------|---------|
| 1, 300   | 13.54 ms | 2.11 ms         | 6.42×   |
| 1, 1000  | 13.66 ms | 2.46 ms         | 5.55×   |
| 1, 3000  | 14.82 ms | 3.39 ms         | 4.37×   |
| 4, 1000  | 15.58 ms | 4.13 ms         | 3.77×   |
| 16, 1000 | 23.57 ms | 11.43 ms        | 2.06×   |
| 16, 3000 | 66.54 ms | 32.31 ms        | 2.06×   |

All 9 (B, L) combinations pass parity vs the HF baseline (max_abs_diff < 3e-1, mean_abs_diff < 5e-3, consistent with typical FlashAttention-vs-SDPA tolerance bounds).

Peak GPU memory is lower with CUDA graph than eager at large shapes (e.g. B=16 L=3000: 2.7 GB vs 6.1 GB) because the captured graph reuses static buffers.

Motivation

`Qwen3OmniMoeAudioEncoder.forward` in its current form is launch-bound on H100: with 32 transformer layers × 6–7 ops/layer, Python dispatch time is ~70% of wall-clock at B=1. pmon confirms <15% SM util across all (B, L).
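Back-of-envelope arithmetic behind that claim, using only the numbers quoted in this PR (the per-launch overhead is derived, not measured):

```python
# Rough launch-overhead estimate from the figures above (illustrative only).
layers, ops_per_layer = 32, 6.5
launches = layers * ops_per_layer          # ~208 kernel launches per forward pass
eager_ms = 13.54                           # B=1, L=300 HF eager latency (table above)
dispatch_ms = 0.70 * eager_ms              # "~70% of wall-clock is Python dispatch"
per_launch_us = dispatch_ms / launches * 1000
print(f"~{per_launch_us:.0f} us of host-side overhead per op")   # ≈ 45–46 µs
```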

Three ops in the base encoder's `forward` block CUDA graph capture:

  1. `torch.tensor([python_list], device=...)` driven by a runtime-computed length
  2. `.item()` / `.tolist()` / `seq_lens.max().item()` (implicit device→host syncs)
  3. `padded_embed[bool_mask]` (boolean mask indexing requires syncing `mask.sum()` for output allocation)

This PR eliminates all three by precomputing the chunk layout, the mask→index tensor, and the `(cu_seqlens, max_seqlen)` tuple at `__init__`, leaving `forward` as pure tensor ops.
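To make the pattern concrete, here is a minimal, self-contained sketch of the same three substitutions (names mirror the wrapper's attributes; the shapes are illustrative, not the wrapper's exact code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch, padded_len, valid_len, dim = 2, 8, 5, 4

# --- done once at __init__, outside any CUDA graph ---
bool_mask = torch.zeros(batch, padded_len, dtype=torch.bool, device=device)
bool_mask[:, :valid_len] = True
# pattern 3: boolean-mask indexing -> precomputed long indices
flat_mask_indices = bool_mask.flatten().nonzero(as_tuple=False).squeeze(-1)
# pattern 1: torch.tensor([python_list], device=...) built at runtime -> fixed Python list
chunk_lengths_list = [valid_len] * batch
# pattern 2: seq_lens.max().item() and friends -> precomputed Python int / static tensor
cu_seqlens = torch.arange(0, (batch + 1) * valid_len, valid_len,
                          dtype=torch.int32, device=device)
max_seqlen = valid_len

# --- forward becomes pure tensor ops and can be captured ---
def gather_valid(padded_embed: torch.Tensor) -> torch.Tensor:
    flat = padded_embed.flatten(0, 1)               # (batch * padded_len, dim)
    return flat.index_select(0, flat_mask_indices)  # no mask.sum() host sync

padded_embed = torch.randn(batch, padded_len, dim, device=device)
valid = gather_valid(padded_embed)                  # (batch * valid_len, dim)
```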

Three layers of speedup

The perf win decomposes cleanly:

| Layer | What it does | Alone gets you |
|-------|--------------|----------------|
| Shape-static refactor (graphed-eager) | Removes dynamic ops from forward | Already 8–48% faster than HF eager — no CUDA graph needed |
| `(cu, max)` tuple attention path | Trips sglang's `SGLANG_VIT_ENABLE_CUDA_GRAPH` branch (available on fa3, triton, ascend backends) | Enables capture to succeed |
| `torch.cuda.graph()` capture + replay | Eliminates all Python dispatch during inference | Up to 6.4× total |

The graphed-eager layer is free speedup even without CUDA graph, and is safe to adopt incrementally.
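For reference, the capture/replay mechanics in the third layer are the standard `torch.cuda.CUDAGraph` pattern from PyTorch; a minimal sketch (the module and buffer names below are placeholders, not the `GraphedAudioEncoder` API):

```python
import torch

# Generic torch.cuda.CUDAGraph capture/replay pattern (requires a CUDA device).
model = torch.nn.Linear(256, 256).cuda().eval()
static_in = torch.randn(16, 256, device="cuda")

# warm up on a side stream before capture, as the PyTorch docs recommend
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        _ = model(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)   # output lives in a static buffer

# replay: copy new data into the static input, replay, read the static output
new_batch = torch.randn_like(static_in)
static_in.copy_(new_batch)
graph.replay()
result = static_out.clone()
```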

Files

  • `sglang_omni/models/qwen3_omni/components/audio_encoder_native.py` — Port of existing `Qwen3OmniAudioEncoder` to sglang main's `Qwen3OmniMoeAudioEncoder` (fused QKV, ColumnParallel/RowParallel linear). Adds `init_sglang_env_for_encoder(model_path)` as a standalone bootstrap helper, and a strict weight loader (q/k/v → `qkv_proj` with `shard_id`, `out_proj` → `proj`; fails loud on unplaced HF keys or untouched native params — a loader sketch follows this list).
  • `sglang_omni/models/qwen3_omni/components/audio_encoder_graphed.py` — `GraphedAudioEncoder` wrapper, fixed `(batch, seq_len)` per instance. Precomputes `chunk_lengths_list`, `flat_mask_indices`, `cu_seqlens`, `max_seqlen`, and the sliced+cast positional embedding. Enforces the shape contract at `forward` (`skip_shape_check=True` escape hatch for graph replay).
  • `benchmarks/micro/bench_audio_encoder.py` — Four-mode harness (`eager` / `cuda_graph` / `graphed` / `tp`), with enforced HF-vs-graphed numerical parity and support for TP>1 via `torchrun`.
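A rough sketch of the strict-loading idea in `audio_encoder_native.py` (key names follow the PR text; the `weight_loader(param, weight, shard_id)` call is the fused-QKV loading convention used by sglang-style parallel linears and should be treated as an assumption here, not this file's exact code):

```python
import torch

_QKV_SHARDS = {"q_proj": "q", "k_proj": "k", "v_proj": "v"}

def load_audio_tower(native_model: torch.nn.Module,
                     hf_state: dict[str, torch.Tensor]) -> None:
    params = dict(native_model.named_parameters())
    touched: set[str] = set()
    for hf_key, weight in hf_state.items():
        shard_id = next((s for frag, s in _QKV_SHARDS.items() if frag in hf_key), None)
        if shard_id is not None:
            # q/k/v shards all land in the fused qkv_proj parameter
            key = (hf_key.replace("q_proj", "qkv_proj")
                         .replace("k_proj", "qkv_proj")
                         .replace("v_proj", "qkv_proj"))
            param = params[key]                 # KeyError here = unplaced HF key, fail loud
            param.weight_loader(param, weight, shard_id)
        else:
            key = hf_key.replace("out_proj", "proj")
            params[key].data.copy_(weight)
        touched.add(key)
    untouched = set(params) - touched
    if untouched:                               # fail loud on untouched native params
        raise RuntimeError(f"native params never loaded: {sorted(untouched)}")
```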

Known limitations (POC scope, not production-ready)

These would need to be resolved before this replaces the default audio encoder in the pipeline:

  1. Fixed shape per instance: one wrapper handles one `(batch, seq_len)`. Production needs shape bucketing (pool of wrappers keyed by shape — see the bucketing sketch after this list). This is Phase 2.
  2. Reaches into base encoder internals (`conv2d1/2/3`, `layers`, `positional_embedding`, `ln_post`, `proj1/2`). Upstream refactors in sglang main will require re-syncing. A `Qwen3OmniMoeAudioEncoder.static_forward(input_features, layout)` hook upstream would make this clean — recommended as a parallel upstream PR.
  3. AMD AITER backend does not honor `SGLANG_VIT_ENABLE_CUDA_GRAPH` (vision.py:650 has an unconditional `.max().item()`). AMD users must force `--mm-attention-backend=triton` for now, or accept that graph capture will fail on aiter. A ~5-line upstream patch would fix this.
  4. Not integrated into the pipeline stage: this PR ships the wrapper + bench but does not wire it into `audio_encoder` stage dispatch. That's intentional — we want the wrapper + perf data reviewed first.
  5. Verification scope: parity checked on synthetic mel inputs; end-to-end accuracy verification on MMSU (audio understanding) is recommended as a CI follow-up.
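A sketch of what the Phase-2 bucketing in limitation 1 could look like (class and constructor names are placeholders; the `GraphedAudioEncoder(...)` call in the usage comment assumes the post-review constructor and is illustrative only):

```python
import torch
from typing import Callable

# Hypothetical Phase-2 dispatcher: one graphed wrapper per (batch, seq_len)
# bucket, built lazily so each shape pays its capture cost exactly once.
class GraphedEncoderPool:
    def __init__(self, build: Callable[[int, int], torch.nn.Module]) -> None:
        self._build = build
        self._wrappers: dict[tuple[int, int], torch.nn.Module] = {}

    def get(self, batch: int, seq_len: int) -> torch.nn.Module:
        key = (batch, seq_len)
        if key not in self._wrappers:
            self._wrappers[key] = self._build(batch, seq_len)
        return self._wrappers[key]

# usage sketch:
# pool = GraphedEncoderPool(lambda b, l: GraphedAudioEncoder(base, batch=b, seq_len=l, device="cuda"))
# encoder = pool.get(batch=4, seq_len=1000)
```

In practice the lookup key would be a bucketed shape (e.g. `seq_len` rounded up to the nearest configured bucket) rather than the raw request shape.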

TP=2 result: negative

For completeness, I also ran sglang native with TP=2 (weights split across 2 GPUs). At every tested `(B, L)` it is 15–25% slower than TP=1:

| (B, L)  | sglang TP=1 | sglang TP=2 |
|---------|-------------|-------------|
| 1, 300  | 15.17 ms    | 17.59 ms    |
| 4, 3000 | 21.76 ms    | 24.30 ms    |

The audio encoder is launch-bound with small kernels; splitting TP makes each kernel smaller while adding NCCL all-reduce per layer. Recommendation: audio encoder should always run TP=1, even when the surrounding thinker uses TP>1.

Numerical equivalence

The graphed wrapper uses the same FA3 kernels as sglang's native audio encoder, which produces bf16 outputs that differ from HF's SDPA path by approximately 1–2e-3 mean / 1–2e-1 max (the latter dominated by outlier positions in long sequences, expected from FA vs SDPA softmax reduction order). `mean_abs_diff` is the load-bearing threshold (1–2e-3 is at the bf16 precision floor for our activation range). The benchmark harness enforces both bounds across the full (B, L) matrix.
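The check itself is simple; a sketch of what the harness enforces (threshold values from this PR, the function name is a placeholder):

```python
import torch

def assert_parity(ref: torch.Tensor, test: torch.Tensor,
                  max_tol: float = 3e-1, mean_tol: float = 5e-3) -> None:
    # compare in fp32 so the diff itself is not quantized to bf16
    diff = (ref.float() - test.float()).abs()
    max_abs_diff = diff.max().item()
    mean_abs_diff = diff.mean().item()
    assert max_abs_diff < max_tol, f"max_abs_diff {max_abs_diff:.3e} >= {max_tol}"
    assert mean_abs_diff < mean_tol, f"mean_abs_diff {mean_abs_diff:.3e} >= {mean_tol}"
```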

Test plan

  • `python -m benchmarks.micro.bench_audio_encoder --mode eager --seq-lens 300,1000,3000 --batch-sizes 1,4,16` — passes
  • `SGLANG_VIT_ENABLE_CUDA_GRAPH=1 python -m benchmarks.micro.bench_audio_encoder --mode graphed --seq-lens 300,1000,3000 --batch-sizes 1,4,16` — all 9 parity checks pass, 2.11 ms → 32.31 ms depending on (B, L)
  • `torchrun --nproc_per_node=2 ... --mode tp --tp-size 2 --seq-lens 300,1000,3000 --batch-sizes 1,4` — runs, confirms TP=2 regression
  • End-to-end MMSU accuracy with graphed encoder wired into pipeline stage (follow-up)
  • AMD MI300X verification with triton backend (follow-up, needs AMD hardware)

Follow-ups (separate PRs)

  1. Phase 2: shape bucketing — pool of `GraphedAudioEncoder` keyed by `(B, L)`, dispatch at stage-input time
  2. Pipeline integration: switch the `audio_encoder` stage to use graphed wrapper when `SGLANG_VIT_ENABLE_CUDA_GRAPH` is set
  3. Upstream sglang PR: add `static_forward` hook to `Qwen3OmniMoeAudioEncoder` to eliminate the internals-reaching in this wrapper
  4. Upstream sglang PR: add `SGLANG_VIT_ENABLE_CUDA_GRAPH` branch to `VisionAiterAttention`
  5. Apply same pattern to `talker_ar` / `code_predictor` — these are the SeedTTS pipeline bottleneck (issue [Perf] Talker pipeline too slow for unified MMMU accuracy + audio CI #276) and have the same launch-bound profile

🤖 Generated with Claude Code

…POC)

Adds a shape-static CUDA-graph-capturable wrapper around sglang main's
native Qwen3OmniMoeAudioEncoder, together with a micro-benchmark.

On 1x H100 bf16 the full graphed + CUDA graph path achieves up to 6.42x
speedup over the existing HF-wrapped audio encoder (B=1 L=300:
2.11 ms vs 13.54 ms), tapering to 2.06x at B=16 L=3000 as GPU compute
starts to dominate launch overhead.

Not a drop-in replacement yet: fixed (batch, seq_len) per instance,
and dispatch into the pipeline stage is left for a follow-up RFC.

Files:
  - sglang_omni/models/qwen3_omni/components/audio_encoder_native.py
    Thin wrapper around sglang.srt.models.qwen3_omni_moe.Qwen3OmniMoeAudioEncoder
    with explicit init_sglang_env_for_encoder() bootstrap and strict weight
    loading (q/k/v -> qkv_proj, out_proj -> proj, fail-loud on any unplaced
    or untouched param).

  - sglang_omni/models/qwen3_omni/components/audio_encoder_graphed.py
    Fixed-shape wrapper that precomputes chunk_lengths, flat_mask_indices,
    cu_seqlens and the sliced positional embedding at __init__. Replaces
    the three CUDA-graph-incompatible patterns in base.forward:
      1. torch.tensor([python_list], device=...)   -> preallocated buffer
      2. .item() / .tolist() / .max().item()       -> precomputed int
      3. padded_embed[bool_mask]                   -> index_select on a
         precomputed long-index tensor
    cu_seqlens is passed in the (cu_tensor, max_seqlen) tuple form expected
    by sglang attention backends when SGLANG_VIT_ENABLE_CUDA_GRAPH is set,
    falling back to a plain tensor otherwise.

  - benchmarks/micro/bench_audio_encoder.py
    Four modes: eager (HF vs sglang side-by-side), cuda_graph, graphed
    (Phase 1 wrapper, with and without CUDA graph), and tp (TP=N via
    torchrun). Enforces numerical parity vs HF across the (B, L) matrix
    (max_abs_diff < 3e-1, mean_abs_diff < 5e-3, consistent with
    industry-standard FA-vs-SDPA equivalence bounds).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@jiaoew1991 jiaoew1991 requested a review from shuaills as a code owner April 22, 2026 21:31
Copilot AI review requested due to automatic review settings April 22, 2026 21:31

Copilot AI left a comment


Pull request overview

Adds an sglang-native Qwen3-Omni audio encoder implementation plus a fixed-shape wrapper intended for CUDA graph capture/replay, along with a micro-benchmark to compare latency, memory, and numerical parity versus the current HF-based encoder.

Changes:

  • Introduce Qwen3OmniAudioEncoderNative and a standalone sglang environment bootstrap + strict HF→sglang weight mapping/validation.
  • Add GraphedAudioEncoder to precompute layout tensors and provide a graph-capturable “static forward” for fixed (batch, seq_len).
  • Add bench_audio_encoder.py micro-benchmark supporting eager / CUDA graph / graphed / TP modes plus parity checks.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

| File | Description |
|------|-------------|
| sglang_omni/models/qwen3_omni/components/audio_encoder_native.py | Boots sglang distributed/server-args state and loads HF audio_tower weights into the sglang-native audio encoder. |
| sglang_omni/models/qwen3_omni/components/audio_encoder_graphed.py | Implements a fixed-shape wrapper that removes non-graph-capturable ops via precomputed indices and cu_seqlens handling. |
| benchmarks/micro/bench_audio_encoder.py | Adds a benchmark harness to measure latency/memory and enforce numerical parity across implementations and modes. |


Comment on lines +128 to +145
```python
for _ in range(warmup):
    _ = _run_once(encoder, input_features, lengths)
torch.cuda.synchronize(device)
torch.cuda.reset_peak_memory_stats(device)

latencies_ms = []
out_sample = None
for i in range(iters):
    torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    out = _run_once(encoder, input_features, lengths)
    torch.cuda.synchronize(device)
    t1 = time.perf_counter()
    latencies_ms.append((t1 - t0) * 1000.0)
    if i == 0:
        out_sample = out.detach().to(torch.float32).cpu()

peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024**2)
```

Copilot AI Apr 22, 2026


Benchmark helpers unconditionally call torch.cuda.synchronize/reset_peak_memory_stats/max_memory_allocated. If --device is set to a non-CUDA device (or CUDA is unavailable), this will raise at runtime. Consider validating torch.device(args.device).type == "cuda" early (or guarding the CUDA-only calls) so failures are clearer and CLI behavior is well-defined.

Suggested change:

```python
use_cuda_timing = device.type == "cuda" and torch.cuda.is_available()

for _ in range(warmup):
    _ = _run_once(encoder, input_features, lengths)
if use_cuda_timing:
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)

latencies_ms = []
out_sample = None
for i in range(iters):
    if use_cuda_timing:
        torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    out = _run_once(encoder, input_features, lengths)
    if use_cuda_timing:
        torch.cuda.synchronize(device)
    t1 = time.perf_counter()
    latencies_ms.append((t1 - t0) * 1000.0)
    if i == 0:
        out_sample = out.detach().to(torch.float32).cpu()

peak_mem_mb = (
    torch.cuda.max_memory_allocated(device) / (1024**2)
    if use_cuda_timing
    else 0.0
)
```

Comment on lines +54 to +58
```python
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
```

Copilot AI Apr 22, 2026


init_sglang_env_for_encoder hard-codes a default MASTER_PORT=29501. When running multiple local benchmarks/tests on the same host, this can lead to nondeterministic “address already in use” failures. Consider following the repo’s existing pattern of selecting a free ephemeral port when MASTER_PORT is unset (see models/ming_omni/components/image_encoder.py), or reusing the _resolve_nccl_port() approach used elsewhere.

Comment on lines +52 to +53
```python
import torch.distributed as dist
```

Copilot AI Apr 22, 2026


Unused import: torch.distributed as dist is never referenced in this function/module. Please remove it to avoid lint noise and to keep the initialization logic focused on the actual parallel_state helpers being used.

Suggested change: delete the `import torch.distributed as dist` line.
Comment on lines +69 to +77
```python
def __init__(
    self,
    base_encoder: nn.Module,
    *,
    batch: int,
    seq_len: int,
    device: torch.device | str,
    dtype: torch.dtype,
) -> None:
```

Copilot AI Apr 22, 2026


dtype is accepted as a constructor argument but never stored or used; the wrapper instead implicitly relies on next(self.base.parameters()).dtype. Either remove the dtype parameter, or enforce/cast inputs and precomputed buffers to it so the API matches the docstring (“device, dtype for precomputation and I/O dtype”).

Comment on lines +192 to +199
```python
def _static_forward(self, input_features: torch.Tensor) -> torch.Tensor:
    base = self.base
    chunk_list = input_features.T.split(self.chunk_lengths_list, dim=0)
    padded_feature = (
        nn.utils.rnn.pad_sequence(chunk_list, batch_first=True)
        .transpose(1, 2)
        .unsqueeze(1)
    )
```

Copilot AI Apr 22, 2026


_static_forward uses input_features directly without moving/casting to the base encoder’s device/dtype. This makes the wrapper easy to misuse (e.g., FP32/CPU inputs) and can cause runtime dtype/device errors (Conv2d/Linear typically require matching dtypes/devices). Consider mirroring Qwen3OmniAudioEncoderNative.forward by explicitly sending input_features to self._device and the base param dtype before the conv stack.

Comment on lines +277 to +282
```python
if not skip_shape_check:
    self._check_shape_contract(
        input_features, feature_attention_mask, audio_feature_lengths
    )
out = self._static_forward(input_features)
return {"audio_embeds": out}
```

Copilot AI Apr 22, 2026


forward returns only {"audio_embeds": ...} whereas the existing audio encoder components return audio_embeds, audio_feature_lengths, and audio_output_lengths. Since downstream pipeline code expects the length fields (e.g., merged thinker inputs include audio_feature_lengths), this wrapper won’t be drop-in compatible. Consider returning the same dict shape (length tensors can be constant/precomputed for the fixed-shape contract).

- native: use free ephemeral port instead of hardcoded MASTER_PORT=29501
  (matches _resolve_nccl_port pattern); drop unused `import torch.distributed`.
- graphed: drop unused `dtype` ctor arg (was never stored/used). Derive
  param dtype + device from `self.base` and use them inside
  `_static_forward` so CPU/FP32 inputs auto-cast instead of blowing up at
  Conv2d. Return the full {audio_embeds, audio_feature_lengths,
  audio_output_lengths} dict matching `Qwen3OmniAudioEncoder` — the length
  tensors are precomputed once (constant under our fixed-shape contract).
- bench: fail-fast in `main()` if `--device` resolves to non-CUDA; avoids
  the first `torch.cuda.synchronize(device)` throwing with a confusing
  trace. Drop the now-unused `dtype=` kwarg at GraphedAudioEncoder call site.

Verified on H100: graphed-eager 12.28ms, graphed+CUDA graph 2.45ms at
B=1 L=1000, parity still OK (same 1.51e-02 max / 8.62e-04 mean vs HF).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@jiaoew1991
Author

Addressed all 6 Copilot review comments in e0a75da:

| # | File / line (orig) | Fix |
|---|--------------------|-----|
| 1 | bench L145 — unconditional `torch.cuda.synchronize` | `main()` now fails fast with a clear error if `--device` is non-CUDA (the whole harness is CUDA-only by construction; explicit guard > confusing `.synchronize()` trace). |
| 2 | native L58 — hardcoded `MASTER_PORT=29501` | Picks a free ephemeral port via `socket.bind(("", 0))` (sketch below), matching the `_resolve_nccl_port` pattern in `engines/ar/sglang_backend/model_worker.py`. |
| 3 | native L53 — unused `import torch.distributed as dist` | Removed. |
| 4 | graphed L77 — `dtype` ctor arg never stored/used | Dropped the `dtype` parameter. Device/dtype are now derived from `self.base.parameters()` (`self._param_device`, `self._param_dtype`) and applied inside `_static_forward`. |
| 5 | graphed L199 — `_static_forward` doesn't cast input device/dtype | Added `input_features.to(device=self._param_device, dtype=self._param_dtype)` at the top of `_static_forward`. No-op when already matching → safe under graph capture. |
| 6 | graphed L282 — return dict missing length fields | Now returns `{audio_embeds, audio_feature_lengths, audio_output_lengths}`, matching `Qwen3OmniAudioEncoder` / `Qwen3OmniAudioEncoderNative`. Length tensors are precomputed at `__init__` under the fixed-shape contract and reused per forward (no allocation in hot path). |

Re-verified on H100: graphed-eager 12.28 ms, graphed + CUDA graph 2.45 ms at B=1 L=1000, numerical parity unchanged (max 1.51e-02 / mean 8.62e-04 vs HF).
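For reference, the ephemeral-port pattern mentioned in fix #2 is the standard `bind(("", 0))` trick; a minimal sketch (the helper name here is a placeholder — the repo's `_resolve_nccl_port` is the authoritative version):

```python
import socket

def pick_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))            # port 0 -> OS assigns a free ephemeral port
        return s.getsockname()[1]

# usage sketch: os.environ.setdefault("MASTER_PORT", str(pick_free_port()))
```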
