Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions vllm/model_executor/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,63 @@ def rocm_unquantized_gemm(
)


# Short dtype tags appended to the "BLAS NxMxK" profiler labels. Dtypes that
# are missing from this table are rendered by the label code as their torch
# name with the "torch." prefix stripped.
_BLAS_DTYPE_SHORT_NAMES: dict[torch.dtype, str] = dict(
    (
        # Floating-point types.
        (torch.float16, "fp16"),
        (torch.bfloat16, "bf16"),
        (torch.float32, "fp32"),
        (torch.float64, "fp64"),
        # Integer types.
        (torch.int8, "int8"),
        (torch.int16, "int16"),
        (torch.int32, "int32"),
        (torch.int64, "int64"),
        (torch.uint8, "uint8"),
    )
)


def _make_profiled_linear_forward(linear: torch.nn.Linear):
    """Build a drop-in replacement for ``linear.forward`` that labels the GEMM.

    The returned callable forwards its arguments unchanged to the original
    ``forward`` and only adds the surrounding profiler range.
    """
    orig_forward = linear.forward
    out_features = linear.out_features
    in_features = linear.in_features

    def profiled_forward(*args, **kwargs) -> torch.Tensor:
        # ``nn.Linear.forward`` names its only parameter ``input``; accept the
        # positional and the keyword spelling so the wrapper stays call
        # compatible with the method it replaces.
        x = args[0] if args else kwargs.get("input")
        if not isinstance(x, torch.Tensor) or x.dim() == 0:
            # No label can be derived; let the original forward handle (and
            # reject) the call exactly as it would without the annotation.
            return orig_forward(*args, **kwargs)

        last_dim = x.size(-1)
        if last_dim:
            n = x.numel() // last_dim
        else:
            # A zero-width last dimension would divide by zero above; take
            # the row count from the leading dimensions instead.
            n = torch.Size(x.shape[:-1]).numel()

        dt = _BLAS_DTYPE_SHORT_NAMES.get(
            x.dtype, str(x.dtype).removeprefix("torch.")
        )
        label = f"BLAS {n}x{out_features}x{in_features} {dt}"
        with record_function_or_nullcontext(label):
            return orig_forward(*args, **kwargs)

    return profiled_forward


def annotate_module_linears_for_profile(module: torch.nn.Module) -> None:
    """Wrap every raw ``torch.nn.Linear`` under ``module`` for profiling.

    Each module yielded by ``module.modules()`` whose type is exactly
    ``torch.nn.Linear`` gets its ``forward`` replaced, in place, by a wrapper
    that runs the original call inside ``record_function_or_nullcontext(
    f"BLAS {n}x{m}x{k} {dt}")`` where

    * ``n`` is the number of input rows (all input dimensions except the
      last, multiplied together),
    * ``m`` / ``k`` are the layer's ``out_features`` / ``in_features`` as
      captured when this function runs,
    * ``dt`` is the short input dtype name (``fp16`` / ``bf16`` / ``fp32`` /
      ...), falling back to the torch dtype name without its ``torch.``
      prefix.

    Used to label HuggingFace-stock vision/audio tower GEMMs (which call
    ``F.linear`` directly and bypass vLLM's ``rocm_unquantized_gemm`` path,
    leaving the resulting hipBLASLt ``Cijk_*`` kernels unattributed in
    PyTorch profiles).

    Anything that is not exactly ``torch.nn.Linear`` is skipped. That
    includes vLLM's own ``LinearBase`` subclasses (``ColumnParallelLinear``,
    ``RowParallelLinear``, ``ReplicatedLinear``, etc.) — those route through
    a quantization method that already emits
    ``BLAS``/``wvSplitK``/``LLMM1`` annotations.

    Idempotent: wrapped layers are tagged with ``_vllm_profile_annotated``
    and skipped on later calls, so it is safe to call twice on the same
    module.

    Args:
        module: Root of the module tree to annotate.
    """
    for child in module.modules():
        # Exact type check on purpose: subclasses may override ``forward``
        # and must not be wrapped.
        if type(child) is not torch.nn.Linear:
            continue
        if getattr(child, "_vllm_profile_annotated", False):
            continue
        child.forward = _make_profiled_linear_forward(child)
        child._vllm_profile_annotated = True


def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
return (
torch.cpu._is_amx_tile_supported()
Expand Down
13 changes: 13 additions & 0 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,19 @@ def _mark_tower_model(

self._tower_model_names = children_names

# Annotate raw nn.Linear modules in the freshly-constructed tower so
# their hipBLASLt Cijk_* kernels are grouped as "BLAS NxMxK <dtype>" in
# profiles. Local import: layers.utils -> _custom_ops would be a
# heavy top-level import here.
from vllm.model_executor.layers.utils import (
annotate_module_linears_for_profile,
)

for child_name in children_names:
child = getattr(self, child_name, None)
if isinstance(child, nn.Module):
annotate_module_linears_for_profile(child)

@contextmanager
def _mark_composite_model(
self,
Expand Down
Loading