83 changes: 83 additions & 0 deletions vllm/model_executor/layers/utils.py
@@ -89,6 +89,87 @@ def apply_penalties(
return logits


def _tinygemm_bf16_impl(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
"""Real implementation: calls FlashInfer tinygemm via lazy wrapper."""
from vllm.utils.flashinfer import flashinfer_tinygemm_bf16
out = torch.empty(
input.shape[0], weight.shape[0],
dtype=torch.bfloat16, device=input.device,
)
flashinfer_tinygemm_bf16(input, weight, out, bias=bias)
return out


def _tinygemm_bf16_fake(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
"""Fake implementation for torch.compile graph tracing."""
return torch.empty(
input.shape[0], weight.shape[0],
dtype=torch.bfloat16, device=input.device,
)


_TINYGEMM_AVAILABLE = False


def _init_tinygemm():
"""Register tinygemm custom op if FlashInfer is available on SM90+."""
global _TINYGEMM_AVAILABLE
try:
from vllm.utils.flashinfer import has_flashinfer
if not has_flashinfer():
return
capability = current_platform.get_device_capability()
if capability is None or capability[0] < 9:
Comment on lines +129 to +130
Member

Can we do the capability check only for arches that we have validated? I am wary of enabling this on all NVIDIA GPUs >= Hopper without benchmarks on each platform.

Contributor Author

I will do microbenchmarking across multiple platforms to make sure we are honest here.

return
direct_register_custom_op(
"tinygemm_bf16",
_tinygemm_bf16_impl,
fake_impl=_tinygemm_bf16_fake,
)
_TINYGEMM_AVAILABLE = True
except Exception:
pass
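
Following the review comment above about limiting the capability check to validated arches, a minimal sketch of gating on an explicit allow-list instead of every device with major capability >= 9; both the allow-list contents and the _capability_is_validated helper are illustrative assumptions, not part of this PR.

# Hypothetical allow-list: enable tinygemm only on capabilities that have been
# explicitly benchmarked, rather than on every NVIDIA GPU with major version >= 9.
_VALIDATED_CAPABILITIES = {(9, 0)}  # extend only after benchmarking each arch


def _capability_is_validated() -> bool:
    # Assumes get_device_capability() returns an indexable (major, minor) pair,
    # matching the capability[0] usage in _init_tinygemm above.
    capability = current_platform.get_device_capability()
    if capability is None:
        return False
    return (capability[0], capability[1]) in _VALIDATED_CAPABILITIES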


_init_tinygemm()
Member

We should not import tinygemm on import of this file. It should be delayed to the first call of the unquantized gemm function itself.
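
A minimal sketch of that deferral, assuming the _init_tinygemm helper above is kept; the _ensure_tinygemm_registered guard is a hypothetical addition (not part of this PR) that _tinygemm_unquantized_gemm would call on its first invocation instead of registering the op at import time.

_TINYGEMM_INIT_DONE = False


def _ensure_tinygemm_registered() -> None:
    # Perform the FlashInfer import and custom-op registration only once, on the
    # first call into the gemm path, rather than when this module is imported.
    global _TINYGEMM_INIT_DONE
    if not _TINYGEMM_INIT_DONE:
        _init_tinygemm()
        _TINYGEMM_INIT_DONE = True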



def _tinygemm_unquantized_gemm(
layer: torch.nn.Module,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
):
num_tokens = x.numel() // x.shape[-1]
if (
num_tokens <= 8
and x.dtype == torch.bfloat16
and weight.dtype == torch.bfloat16
and weight.shape[0] % 16 == 0
Member

I would like to see a microbenchmark script to make sure this gemm is good across a range of weight shapes; currently this will apply to most weights, small or large.

and x.is_contiguous()
and weight.is_contiguous()
and (bias is None or bias.dtype == torch.bfloat16)
):
if bias is None:
bias = torch.zeros(
weight.shape[0], dtype=torch.bfloat16, device=x.device,
)
Comment on lines +161 to +164
Member

It seems very suboptimal to be always allocating and filling bias on a low latency path. I would avoid this

Collaborator

https://github.com/flashinfer-ai/flashinfer/blob/9e3d8b9d1d11af893f7e7389baa25192db34bd8e/flashinfer/gemm/routergemm.py#L313

Seems non-optional. Not sure what a better solution looks like, but agree that filling the bias every time is a pain

Collaborator

maybe we can use a pattern like the workspace buffer for FlashInfer attention?

Contributor Author

Yes, indeed. I will talk to the FlashInfer kernel author to see if that is something that can be changed.
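
A minimal sketch of the workspace-buffer idea discussed above, assuming the FlashInfer kernel keeps requiring a non-optional bias; the _get_zero_bias cache is a hypothetical helper, not part of this PR.

# Hypothetical per-(device, out_features) cache of zero-filled bias tensors, so the
# low-latency path reuses one buffer instead of allocating and zeroing on every call.
_ZERO_BIAS_CACHE: dict[tuple[torch.device, int], torch.Tensor] = {}


def _get_zero_bias(out_features: int, device: torch.device) -> torch.Tensor:
    key = (device, out_features)
    cached = _ZERO_BIAS_CACHE.get(key)
    if cached is None:
        cached = torch.zeros(out_features, dtype=torch.bfloat16, device=device)
        _ZERO_BIAS_CACHE[key] = cached
    return cached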

out_shape = (*x.shape[:-1], weight.shape[0])
result = torch.ops.vllm.tinygemm_bf16(
x.view(num_tokens, -1), weight, bias,
)
return result.view(out_shape)
return torch.nn.functional.linear(x, weight, bias)
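
To go with the earlier review request for a microbenchmark across weight shapes, a minimal sketch that times the tinygemm path against torch.nn.functional.linear for small token counts; the shapes, token counts, and iteration count are illustrative, and it assumes the tinygemm_bf16 custom op has already been registered on a CUDA device.

def _bench_tinygemm(
    shapes=((2048, 2048), (4096, 4096), (14336, 4096)),
    tokens=(1, 4, 8),
    iters=100,
):
    # Rough CUDA-event timing of F.linear vs. the registered tinygemm custom op.
    for out_f, in_f in shapes:
        weight = torch.randn(out_f, in_f, dtype=torch.bfloat16, device="cuda")
        bias = torch.zeros(out_f, dtype=torch.bfloat16, device="cuda")
        for n in tokens:
            x = torch.randn(n, in_f, dtype=torch.bfloat16, device="cuda")
            candidates = (
                ("linear", lambda: torch.nn.functional.linear(x, weight, bias)),
                ("tinygemm", lambda: torch.ops.vllm.tinygemm_bf16(x, weight, bias)),
            )
            for name, fn in candidates:
                torch.cuda.synchronize()
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                for _ in range(iters):
                    fn()
                end.record()
                torch.cuda.synchronize()
                ms = start.elapsed_time(end) / iters
                print(f"{name:>8} out={out_f} in={in_f} tokens={n}: {ms:.3f} ms")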


def default_unquantized_gemm(
layer: torch.nn.Module,
x: torch.Tensor,
@@ -304,5 +385,7 @@ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
return rocm_unquantized_gemm
elif current_platform.is_cpu():
return cpu_unquantized_gemm
elif _TINYGEMM_AVAILABLE:
return _tinygemm_unquantized_gemm
else:
return default_unquantized_gemm
3 changes: 3 additions & 0 deletions vllm/utils/flashinfer.py
@@ -137,6 +137,9 @@ def wrapper(*args, **kwargs):
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
"flashinfer", "trtllm_fp4_block_scale_moe"
)
flashinfer_tinygemm_bf16 = _lazy_import_wrapper(
"flashinfer.gemm", "tinygemm_bf16"
)
# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
"flashinfer.autotuner",