Annotate VLM/audio tower nn.Linear calls in PyTorch profiles #934
Open
mgehre-amd wants to merge 1 commit into
## Purpose
VLM tower models loaded via `_mark_tower_model` (e.g. Gemma-4's vision
tower from `AutoModel.from_config(...)`) use raw `torch.nn.Linear`. Their
`F.linear -> aten::mm -> rocBLAS` path bypasses both vLLM annotation sites
(`rocm_unquantized_gemm`, AWQ pytorch fallback), so the resulting hipBLASLt
`Cijk_*` kernels show up unattributed in PyTorch profiles, defeating
per-shape attribution and bandwidth analysis.
## Approach
Add `annotate_module_linears_for_profile(module)` and call it from
`_mark_tower_model` after the collected-children list is finalized. Each
raw `nn.Linear` descendant of every tower module gets its forward wrapped
in `record_function_or_nullcontext(f"BLAS {n}x{m}x{k} {dt}")` where `dt`
is the short input dtype name (`fp16` / `bf16` / `fp32` / ...) -- same
label format as existing call sites, plus a dtype suffix so non-fp16
vision/audio towers (e.g. bf16 SigLIP) are distinguishable. Strict
`type(child) is torch.nn.Linear` avoids double-annotating vLLM's
`LinearBase` subclasses (which already emit BLAS/wvSplitK/LLMM1 via their
quantization method). The wrapper is a no-op unless
`VLLM_CUSTOM_SCOPES_FOR_PROFILING` /
`VLLM_NVTX_SCOPES_FOR_PROFILING` is set, so the steady-state cost with
profiling scopes disabled is one extra Python call per Linear invocation.
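
A minimal sketch of the wrapping described above, not the code in this PR. Assumptions: `_scope` is a hypothetical stand-in for `record_function_or_nullcontext` and its env-var check is a guess at the gating; `n` is taken as the number of flattened input rows, `m` as `out_features` and `k` as `in_features`; the `_profile_annotated` marker and the returned count are illustrative only.

```python
import contextlib
import os

import torch
from torch.profiler import record_function

# Short dtype names for the label suffix; other dtypes fall back to the
# torch name with the "torch." prefix stripped.
_SHORT_DTYPE = {
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float32: "fp32",
}


def _scope(name: str):
    """Hypothetical stand-in for record_function_or_nullcontext."""
    enabled = any(
        os.environ.get(var) == "1"
        for var in ("VLLM_CUSTOM_SCOPES_FOR_PROFILING",
                    "VLLM_NVTX_SCOPES_FOR_PROFILING")
    )
    return record_function(name) if enabled else contextlib.nullcontext()


def blas_label(x: torch.Tensor, weight: torch.Tensor) -> str:
    """Build the "BLAS {n}x{m}x{k} {dt}" label (dimension order assumed)."""
    k = x.shape[-1]
    n = x.numel() // k if k else 0  # rows after flattening leading dims
    m = weight.shape[0]             # out_features
    dt = _SHORT_DTYPE.get(x.dtype, str(x.dtype).removeprefix("torch."))
    return f"BLAS {n}x{m}x{k} {dt}"


def annotate_module_linears_for_profile(module: torch.nn.Module) -> int:
    """Wrap forward of every raw nn.Linear under `module`; return the count."""
    wrapped = 0
    for child in module.modules():
        # Strict type check: subclasses of nn.Linear are left untouched.
        if type(child) is not torch.nn.Linear:
            continue
        if getattr(child, "_profile_annotated", False):
            continue  # already wrapped, avoid nesting scopes

        def forward(x, _orig=child.forward, _mod=child):
            with _scope(blas_label(x, _mod.weight)):
                return _orig(x)

        child.forward = forward  # instance attribute shadows the class method
        child._profile_annotated = True
        wrapped += 1
    return wrapped
```

Default arguments bind `child` and its original `forward` at definition time, so each wrapper keeps its own module rather than the last one visited by the loop.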
## Test results
Gemma-4-26B-A4B-IT_VLM_AWQ-4bit on Strix Halo (gfx1151), input=512,
output=1, num-prompts=3, max-num-seqs=1, --no-cudagraph,
VLLM_CUSTOM_SCOPES_FOR_PROFILING=1:
- TTFT median 675 ms (was 677 ms; within noise).
- bench-script print_profile_bandwidth TOTAL: 53.1% -> 64.4% (+11.3 pp).
- 4 new BLAS rows surfaced, all batch dim 2520 (= ViT patches), K/N in
{768, 1152, 4304} -- the SigLIP/Gemma vision tower's patch-projection,
attn-projection, MLP-up and MLP-down GEMMs.
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>