Annotate VLM/audio tower nn.Linear calls in PyTorch profiles #934

Open: mgehre-amd wants to merge 1 commit into gfx11 from matthias.profile-tower-linears

mgehre-amd commented May 13, 2026

Purpose

VLM tower models loaded via _mark_tower_model (e.g. Gemma-4's vision tower from AutoModel.from_config(...)) use raw torch.nn.Linear. Their F.linear -> aten::mm -> hipBLASLt path bypasses both vLLM annotation sites (rocm_unquantized_gemm and the AWQ PyTorch fallback), so the resulting hipBLASLt Cijk_* kernels show up unattributed in PyTorch profiles, which defeats per-shape attribution and bandwidth analysis.

Approach

Add annotate_module_linears_for_profile(module) and call it from _mark_tower_model after the collected-children list is finalized. Each raw nn.Linear descendant of every tower module gets its forward wrapped in record_function_or_nullcontext(f"BLAS {n}x{m}x{k}") -- same label format as existing call sites. The strict type(child) is torch.nn.Linear check avoids double-annotating vLLM's LinearBase subclasses (which already emit BLAS/wvSplitK/LLMM1 via their quantization method). The wrapper is a no-op unless VLLM_CUSTOM_SCOPES_FOR_PROFILING / VLLM_NVTX_SCOPES_FOR_PROFILING is set, so the cost with profiling scopes disabled is one extra Python call per Linear invocation.

Test results

Gemma-4-26B-A4B-IT_VLM_AWQ-4bit on Strix Halo (gfx1151), input=512, output=1, num-prompts=3, max-num-seqs=1, --no-cudagraph, VLLM_CUSTOM_SCOPES_FOR_PROFILING=1:

  • TTFT median 675 ms (was 677 ms; within noise).
  • bench-script print_profile_bandwidth TOTAL: 53.1% -> 64.4% (+11.3 pp).
  • 4 new BLAS rows surfaced, all batch dim 2520 (= ViT patches), K/N in {768, 1152, 4304} -- the SigLIP/Gemma vision tower's patch-projection, attn-projection, MLP-up and MLP-down GEMMs.

mgehre-amd force-pushed the matthias.profile-tower-linears branch from 3d8e3fe to 2b79b41 on May 13, 2026 10:38

## Purpose

VLM tower models loaded via `_mark_tower_model` (e.g. Gemma-4's vision
tower from `AutoModel.from_config(...)`) use raw `torch.nn.Linear`. Their
`F.linear -> aten::mm -> hipBLASLt` path bypasses both vLLM annotation sites
(`rocm_unquantized_gemm`, AWQ PyTorch fallback), so the resulting hipBLASLt
`Cijk_*` kernels show up unattributed in PyTorch profiles, defeating
per-shape attribution and bandwidth analysis.

## Approach

Add `annotate_module_linears_for_profile(module)` and call it from
`_mark_tower_model` after the collected-children list is finalized. Each
raw `nn.Linear` descendant of every tower module gets its forward wrapped
in `record_function_or_nullcontext(f"BLAS {n}x{m}x{k} {dt}")` where `dt`
is the short input dtype name (`fp16` / `bf16` / `fp32` / ...) -- same
label format as existing call sites, plus a dtype suffix so non-fp16
vision/audio towers (e.g. bf16 SigLIP) are distinguishable. The strict
`type(child) is torch.nn.Linear` check avoids double-annotating vLLM's
`LinearBase` subclasses (which already emit BLAS/wvSplitK/LLMM1 via their
quantization method). The wrapper is a no-op unless
`VLLM_CUSTOM_SCOPES_FOR_PROFILING` /
`VLLM_NVTX_SCOPES_FOR_PROFILING` is set, so the cost with profiling scopes
disabled is one extra Python call per Linear invocation.
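
For readers without the diff at hand, the approach above can be sketched
roughly as follows. This is an illustration under stated assumptions, not
the code in this PR: `torch.profiler.record_function` stands in for vLLM's
`record_function_or_nullcontext`, `scopes_enabled` and `blas_label` are
hypothetical helpers, the env-var parsing is a guess, and the mapping of
`n`/`m`/`k` to `out_features` / flattened rows / `in_features` is assumed.

```python
import math
import os

import torch
from torch.profiler import record_function

# Short dtype names; only fp16 / bf16 / fp32 are named in the description.
_SHORT_DTYPE = {
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float32: "fp32",
}
_SCOPE_VARS = ("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "VLLM_NVTX_SCOPES_FOR_PROFILING")


def scopes_enabled() -> bool:
    # ASSUMPTION: any value other than "" or "0" enables the scopes.
    return any(os.environ.get(v, "0") not in ("", "0") for v in _SCOPE_VARS)


def blas_label(x: torch.Tensor, linear: torch.nn.Linear) -> str:
    # ASSUMPTION: m = flattened leading dims, n = out_features, k = in_features.
    m = math.prod(x.shape[:-1])
    n, k = linear.out_features, linear.in_features
    dt = _SHORT_DTYPE.get(x.dtype, str(x.dtype).removeprefix("torch."))
    return f"BLAS {n}x{m}x{k} {dt}"


def annotate_module_linears_for_profile(module: torch.nn.Module) -> int:
    """Wrap forward() of every raw nn.Linear under `module`; return the count."""
    wrapped = 0
    for child in module.modules():
        # Exact type match: subclasses and other linear layer classes are skipped.
        if type(child) is not torch.nn.Linear:
            continue
        if hasattr(child, "_profile_wrapped"):  # already annotated
            continue
        orig_forward = child.forward

        def forward(x, _orig=orig_forward, _lin=child):
            if not scopes_enabled():
                return _orig(x)  # fast path: one extra Python call
            with record_function(blas_label(x, _lin)):
                return _orig(x)

        child.forward = forward  # instance attribute shadows the class method
        child._profile_wrapped = True
        wrapped += 1
    return wrapped


if __name__ == "__main__":
    os.environ["VLLM_CUSTOM_SCOPES_FOR_PROFILING"] = "1"
    tower = torch.nn.Sequential(
        torch.nn.Linear(768, 1152), torch.nn.GELU(), torch.nn.Linear(1152, 1152)
    )
    annotate_module_linears_for_profile(tower)
    with torch.profiler.profile() as prof:
        tower(torch.randn(2520, 768))
    print([e.key for e in prof.key_averages() if e.key.startswith("BLAS")])
```

Checking the environment inside the wrapper rather than at wrap time keeps
the sketch usable when the variables are set after model load; whether the
PR does the same is not stated here.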

## Test results

Gemma-4-26B-A4B-IT_VLM_AWQ-4bit on Strix Halo (gfx1151), input=512,
output=1, num-prompts=3, max-num-seqs=1, --no-cudagraph,
VLLM_CUSTOM_SCOPES_FOR_PROFILING=1:
- TTFT median 675 ms (was 677 ms; within noise).
- bench-script print_profile_bandwidth TOTAL: 53.1% -> 64.4% (+11.3 pp).
- 4 new BLAS rows surfaced, all batch dim 2520 (= ViT patches), K/N in
  {768, 1152, 4304} -- the SigLIP/Gemma vision tower's patch-projection,
  attn-projection, MLP-up and MLP-down GEMMs.
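
To make the per-shape bandwidth angle concrete, here is a small sketch of
how one might estimate the minimum memory traffic for the four GEMMs
listed above. How the bench script's `print_profile_bandwidth` actually
computes its numbers is not shown in this PR; the formula below is a
generic lower bound. The `(k, n)` pairing per layer is an assumption (the
results only give the value set {768, 1152, 4304}), as are the 2-byte
element size and the `Gemm` / `report` names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Gemm:
    name: str
    m: int  # rows of the activation matrix (ViT patches)
    k: int  # in_features (reduction dimension)
    n: int  # out_features

    def flops(self) -> int:
        # One multiply and one add per (m, n, k) triple.
        return 2 * self.m * self.n * self.k

    def min_bytes(self, itemsize: int = 2) -> int:
        # Read A (m x k) and W (k x n) once, write C (m x n) once.
        return itemsize * (self.m * self.k + self.k * self.n + self.m * self.n)


# ASSUMPTION: per-layer (k, n) pairings are guessed from the GEMM names.
VIT_GEMMS = [
    Gemm("patch-projection", 2520, 768, 1152),
    Gemm("attn-projection", 2520, 1152, 1152),
    Gemm("MLP-up", 2520, 1152, 4304),
    Gemm("MLP-down", 2520, 4304, 1152),
]


def report(gemms, itemsize: int = 2) -> list[str]:
    lines = []
    for g in gemms:
        mb = g.min_bytes(itemsize) / 1e6
        intensity = g.flops() / g.min_bytes(itemsize)  # FLOP per byte
        lines.append(f"{g.name:18s} {g.n}x{g.m}x{g.k} {mb:8.2f} MB {intensity:7.1f} FLOP/B")
    return lines


if __name__ == "__main__":
    print("\n".join(report(VIT_GEMMS)))
```

Dividing `min_bytes` by a kernel's measured duration from the profile
gives a rough achieved-bandwidth figure for that shape, which is only
possible once the kernel is attributed to a labelled BLAS row.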

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
mgehre-amd force-pushed the matthias.profile-tower-linears branch from 2b79b41 to 726f324 on May 13, 2026 11:40