Use FI tinygemm for faster BF16 GEMMs #39921
askliar wants to merge 7 commits into vllm-project:main
Conversation
…ion) For batch size 1 (and up to M=8), replace cuBLAS BF16 GEMM with FlashInfer's tinygemm_bf16 kernel which is latency-optimized for small M using warp-specialized TMA + HMMA design from TRT-LLM. This affects all unquantized BF16 linear layers: attention QKV/O projections, dense MLP up/down projections, shared MoE experts, and latent MoE projections. Requires SM90+ (Hopper/Blackwell). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
askliar force-pushed from e821244 to f6c2425
Code Review
This pull request introduces support for B12X NVFP4 dense linear and fused MoE kernels for Blackwell (SM120/SM121) GPUs and optimizes unquantized GEMM operations for small batch sizes on SM90+ hardware using FlashInfer's tinygemm kernel. A critical correctness issue was identified in the default_unquantized_gemm implementation where multi-dimensional input tensors were not correctly handled, potentially leading to shape mismatches or incorrect kernel triggering. A code suggestion was provided to flatten the tokens and correctly calculate the number of tokens for the threshold check and output tensor allocation.
```python
if (
    _tinygemm_bf16_available()
    and x.shape[0] <= 8
    and x.dtype == torch.bfloat16
    and weight.dtype == torch.bfloat16
    and weight.shape[0] % 16 == 0
    and x.is_contiguous()
    and weight.is_contiguous()
    and (bias is None or bias.dtype == torch.bfloat16)
):
    from vllm.utils.flashinfer import flashinfer_tinygemm_bf16

    out = torch.empty(
        x.shape[0], weight.shape[0], dtype=torch.bfloat16, device=x.device
    )
    flashinfer_tinygemm_bf16(x, weight, out, bias=bias)
    return out
```
The current implementation of default_unquantized_gemm using tinygemm_bf16 has a correctness issue when the input tensor x is multi-dimensional (e.g., [batch, seq, hidden]).
- Incorrect Tiny Detection: `x.shape[0] <= 8` only checks the first dimension. For a 3D tensor like `[1, 1024, hidden]`, this would incorrectly trigger the tiny GEMM path for 1024 tokens because `shape[0]` is 1.
- Shape Mismatch: The output tensor `out` is allocated as `[x.shape[0], weight.shape[0]]`, which ignores intermediate dimensions (like `seq`). This will lead to a shape mismatch or crash when returning to the caller.
- Kernel Application: The `tinygemm` kernel should be applied to a flattened view of the tokens to correctly handle multi-dimensional inputs.
I suggest calculating the total number of tokens M and using it for both the threshold check and the tensor views.
```diff
-if (
-    _tinygemm_bf16_available()
-    and x.shape[0] <= 8
-    and x.dtype == torch.bfloat16
-    and weight.dtype == torch.bfloat16
-    and weight.shape[0] % 16 == 0
-    and x.is_contiguous()
-    and weight.is_contiguous()
-    and (bias is None or bias.dtype == torch.bfloat16)
-):
-    from vllm.utils.flashinfer import flashinfer_tinygemm_bf16
-
-    out = torch.empty(
-        x.shape[0], weight.shape[0], dtype=torch.bfloat16, device=x.device
-    )
-    flashinfer_tinygemm_bf16(x, weight, out, bias=bias)
-    return out
+num_tokens = x.numel() // x.shape[-1]
+if (
+    _tinygemm_bf16_available()
+    and num_tokens <= 8
+    and x.dtype == torch.bfloat16
+    and weight.dtype == torch.bfloat16
+    and weight.shape[0] % 16 == 0
+    and x.is_contiguous()
+    and weight.is_contiguous()
+    and (bias is None or bias.dtype == torch.bfloat16)
+):
+    from vllm.utils.flashinfer import flashinfer_tinygemm_bf16
+
+    out = x.new_empty((*x.shape[:-1], weight.shape[0]))
+    flashinfer_tinygemm_bf16(x.view(num_tokens, -1), weight,
+                             out.view(num_tokens, -1), bias=bias)
+    return out
```
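As a quick sanity check of the flattening logic, this standalone snippet (illustrative shapes, not part of the PR) shows why `num_tokens` rather than `x.shape[0]` is the right quantity:

```python
import torch

x = torch.randn(1, 4, 1024, dtype=torch.bfloat16)  # [batch, seq, hidden]
weight = torch.randn(2048, 1024, dtype=torch.bfloat16)

num_tokens = x.numel() // x.shape[-1]  # 4 tokens, even though x.shape[0] == 1
out = x.new_empty((*x.shape[:-1], weight.shape[0]))

assert num_tokens == 4
assert out.shape == (1, 4, 2048)                   # leading dims preserved
assert x.view(num_tokens, -1).shape == (4, 1024)   # 2D view passed to the kernel
```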
Handle inputs that may be 3D (batch, seq, hidden) by computing the total token count and reshaping for tinygemm, preserving the original shape in the output. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add @torch._dynamo.assume_constant_result to _tinygemm_bf16_available() so dynamo doesn't trace into importlib/shutil calls that it can't handle. The function returns a hardware-capability constant, so this is safe. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Resolve tinygemm availability at import time and call torch.ops.flashinfer.tinygemm2_op directly instead of going through the lazy wrapper. This avoids importlib calls that dynamo cannot trace in fullgraph AOT compilation mode.

- _init_tinygemm() runs at import: checks SM90+, JIT-compiles and registers the flashinfer::tinygemm2_op custom op
- dispatch_unquantized_gemm() reads a plain bool global
- _tinygemm_unquantized_gemm() only uses tensor ops and the registered custom op — all dynamo-safe

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Follow the established vLLM pattern for FlashInfer ops in compiled code (matching scaled_mm/flashinfer.py):

- Register torch.ops.vllm.tinygemm_bf16 via direct_register_custom_op
- Real impl calls FlashInfer lazily (runs at execution time)
- Fake impl returns empty tensor (used during graph tracing)
- Compiled graph captures the custom op as an opaque FX node

This avoids calling importlib/FlashInfer wrappers during dynamo tracing while keeping the actual FlashInfer kernel in the graph.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
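For readers unfamiliar with that pattern, here is a minimal sketch of what the registration looks like. The import path and exact direct_register_custom_op signature are assumptions based on how other vLLM FlashInfer ops are registered, not the PR's literal code:

```python
import torch
from vllm.utils import direct_register_custom_op  # assumed import path

def _tinygemm_bf16(x: torch.Tensor, weight: torch.Tensor,
                   bias: torch.Tensor | None) -> torch.Tensor:
    # Real impl: the FlashInfer import happens lazily at execution time,
    # so dynamo never traces through importlib.
    from vllm.utils.flashinfer import flashinfer_tinygemm_bf16
    out = x.new_empty((x.shape[0], weight.shape[0]))
    flashinfer_tinygemm_bf16(x, weight, out, bias=bias)
    return out

def _tinygemm_bf16_fake(x: torch.Tensor, weight: torch.Tensor,
                        bias: torch.Tensor | None) -> torch.Tensor:
    # Fake impl: only produces correct shape/dtype metadata for tracing.
    return x.new_empty((x.shape[0], weight.shape[0]))

direct_register_custom_op(
    op_name="tinygemm_bf16",
    op_func=_tinygemm_bf16,
    mutates_args=[],
    fake_impl=_tinygemm_bf16_fake,
)
# Compiled code then calls torch.ops.vllm.tinygemm_bf16(...) as an opaque FX node.
```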
Hi @WoosukKwon @youkaichao @robertgshaw2 @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @pavanimajety - this PR is ready for review!

Please review when you have time. Thanks!
@Himan-D do you know if this tinygemm is PDL enabled? We are interested in a general gemm PDL replacement for low latency.
```python
pass


_init_tinygemm()
```
We should not import tinygemm on import of this file. It should be delayed until the first call of the unquantized GEMM function itself.
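A minimal sketch of the deferred-init shape this suggests (hypothetical helper; assumes the PR's _init_tinygemm() can be made to return whether registration succeeded):

```python
_TINYGEMM_AVAILABLE: bool | None = None  # resolved on first use, not at import

def _tinygemm_available() -> bool:
    global _TINYGEMM_AVAILABLE
    if _TINYGEMM_AVAILABLE is None:
        # The first GEMM call pays the one-time capability check / JIT
        # compile; module import stays cheap.
        _TINYGEMM_AVAILABLE = _init_tinygemm()
    return _TINYGEMM_AVAILABLE
```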
```python
capability = current_platform.get_device_capability()
if capability is None or capability[0] < 9:
```
Can we do the capability check only for arches that we have validated? I am wary of enabling this on all NVIDIA GPUs >= Hopper without benchmarks on each platform.
I will do microbenchmarking across multiple platforms to make sure we are honest here.
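One way to express that gating, sketched below with an illustrative allowlist (the actual set would come from the promised microbenchmarks):

```python
from vllm.platforms import current_platform

# Illustrative allowlist of validated (major, minor) compute capabilities,
# e.g. Hopper (9, 0) plus whichever Blackwell parts get benchmarked.
VALIDATED_CAPABILITIES = {(9, 0), (10, 0)}

def _tinygemm_arch_validated() -> bool:
    capability = current_platform.get_device_capability()
    return capability is not None and tuple(capability) in VALIDATED_CAPABILITIES
```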
```python
if bias is None:
    bias = torch.zeros(
        weight.shape[0], dtype=torch.bfloat16, device=x.device,
    )
```
It seems very suboptimal to always be allocating and filling a bias on a low-latency path. I would avoid this.
Seems non-optional (the kernel appears to require a bias tensor). Not sure what a better solution looks like, but agree that filling the bias every time is a pain.
Maybe we can use a pattern like the workspace buffer for FlashInfer attention?
Yes, indeed. I will talk to the FlashInfer kernel author to see if that is something that can be changed.
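A minimal sketch of that workspace-buffer-style idea, assuming the kernel keeps requiring a bias tensor (the cache helper below is hypothetical, not part of the PR):

```python
import torch

# Allocate the all-zeros bias once per (device, out_features) and reuse it,
# so the low-latency path never re-allocates or re-fills it.
_ZERO_BIAS_CACHE: dict[tuple[torch.device, int], torch.Tensor] = {}

def _get_zero_bias(out_features: int, device: torch.device) -> torch.Tensor:
    key = (device, out_features)
    buf = _ZERO_BIAS_CACHE.get(key)
    if buf is None:
        buf = torch.zeros(out_features, dtype=torch.bfloat16, device=device)
        _ZERO_BIAS_CACHE[key] = buf
    return buf
```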
```python
num_tokens <= 8
and x.dtype == torch.bfloat16
and weight.dtype == torch.bfloat16
and weight.shape[0] % 16 == 0
```
I would like to see a microbenchmark script to make sure this GEMM is good across a range of weight shapes; currently this will apply to most weights, small or large.
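A sketch of what such a script could look like (shapes and harness are illustrative; the commented-out lines assume the PR's flashinfer_tinygemm_bf16 wrapper is importable):

```python
import torch
import torch.nn.functional as F

def bench_ms(fn, iters=200, warmup=20):
    # CUDA-event timing; assumes a CUDA device is available.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

for m in (1, 4, 8):
    for n, k in [(2048, 2048), (4096, 4096), (6144, 4096), (4096, 14336)]:
        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
        w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        t_ref = bench_ms(lambda: F.linear(x, w))
        # To compare against the tinygemm path:
        # from vllm.utils.flashinfer import flashinfer_tinygemm_bf16
        # out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")
        # t_tiny = bench_ms(lambda: flashinfer_tinygemm_bf16(x, w, out, bias=None))
        print(f"M={m} N={n} K={k}: cuBLAS {t_ref * 1e3:.1f} us")
```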
Hi @askliar, the pre-commit checks have failed. Please run:

```bash
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
@mgoin tinygemm has PDL enabled. Check here: https://github.com/askliar/flashinfer/blob/main/csrc/tinygemm2.cu
Is M <= 8 an arbitrary heuristic or is this the recommended choice? Is this backed by benchmarks?
Hi @askliar, the pre-commit checks have failed. Please run:

```bash
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
vadiklyutiy left a comment:
What do you think about routing all BF16 GEMMs through FlashInfer?
We do it for fp4 and fp8.
FlashInfer has autotuning and can choose the best kernel, including this tinygemm.
Also, I found earlier some perf gaps for medium-size batches (M from 257 to 271) where FlashInfer can provide an improvement. See #35467 and #27173.
(But it needs two small FlashInfer modifications, flashinfer-ai/flashinfer#2958 and flashinfer-ai/flashinfer#2914 - both ready to be merged.)
Thanks for the review feedback! A few updates on the current implementation:
Would you like us to make any additional changes?
Thanks for the review! Here are the updates:
Let me know if you have more feedback!
Great question! Yes, the FlashInfer tinygemm_bf16 kernel supports PDL (Programmatic Dependent Launch) via the use_pdl parameter. Looking at the FlashInfer source (flashinfer/gemm/routergemm.py): when use_pdl is set, the kernel overlaps DMA with compute from preceding kernels - ideal for low-latency decode scenarios. Currently our implementation doesn't enable PDL - we could add this as an option. Would you like us to:

This would be particularly useful for PD (Prefill-Decode) disaggregation where we want to overlap the GEMM with previous kernels.
Hey @mgoin - to follow up on your question about PDL: yes, the FlashInfer tinygemm_bf16 supports PDL (Programmatic Dependent Launch) via the use_pdl parameter. Currently we have it disabled (use_pdl=False). For PD workloads, we could add PDL support - would you like us to:

This would help overlap GEMM compute with previous kernel DMA for lower latency.
Good question! Yes, the FlashInfer tinygemm supports PDL via the use_pdl parameter. Currently use_pdl=False. For PD workloads we could add PDL support - want us to add it as a configurable option?
Following up on the PDL question - yes, FlashInfer tinygemm supports PDL (the use_pdl param). Currently use_pdl=False. Could add it as a configurable option for PD workloads.
We should use PDL.
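If the wrapper exposes the flag, this would be a one-argument change at the call site (parameter name use_pdl per the thread above; whether the vLLM wrapper forwards it is an assumption):

```python
# Assumes flashinfer_tinygemm_bf16 forwards use_pdl to the kernel launch.
flashinfer_tinygemm_bf16(x, weight, out, bias=bias, use_pdl=True)
```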
I think if we can just call
Purpose
Route small-M BF16 linears (M ≤ 8) through FlashInfer's tinygemm_bf16 kernel to speed up decode-phase QKV/O, MLP up/down, shared MoE, and latent MoE projections. Uses TRT-LLM's warp-specialized TMA + HMMA design tuned for latency-bound GEMVs. Requires SM90+ and FlashInfer; falls back to F.linear otherwise.

Changes
- Add flashinfer_tinygemm_bf16 to the FlashInfer lazy-import registry.
- Register torch.ops.vllm.tinygemm_bf16 via direct_register_custom_op (real + fake impls) for torch.compile fullgraph compatibility.
- Add _tinygemm_unquantized_gemm with a small-M BF16 predicate and multi-dim input handling; wire it into dispatch_unquantized_gemm() behind an import-time _TINYGEMM_AVAILABLE flag.

Test Plan
- lm_eval accuracy parity on a BF16 model.
- torch.compile fullgraph decode (no graph breaks).