[Perf][DSv4] Add cuteDSL generic LL Blockwise FP8 GEMM #43214
LopezCastroRoberto wants to merge 1 commit into
Code Review
This pull request introduces the LLFp8BlockScaledMMKernel, a low-latency FP8 block-scaled matrix multiplication kernel implemented using CuTe DSL. The kernel is designed for small batch sizes (M <= 16) and utilizes warp-specialized instructions with asynchronous copies. Feedback indicates a critical issue where the kernel hardcodes the E4M3 FP8 format but lacks a check in can_implement to prevent its use when E5M2 is enabled, which would lead to incorrect computations.
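For readers unfamiliar with block-scaled GEMM, the following is a minimal pure-Python reference of the arithmetic such a kernel computes. It is only a sketch: the function name, the scale layout (one activation scale per row and K-block, one weight scale per K-block and N-block), and the tiny block sizes are assumptions for illustration and are not taken from the PR.

```python
def block_scaled_matmul(a_q, a_scale, b_q, b_scale, block_k, block_n):
    """Reference block-scaled matmul on plain Python lists.

    a_q:     M x K quantized activations (stored here as plain numbers)
    a_scale: M x ceil(K / block_k) per-row, per-K-block scales (assumed layout)
    b_q:     K x N quantized weights
    b_scale: ceil(K / block_k) x ceil(N / block_n) per-block scales (assumed layout)
    """
    m_dim, k_dim, n_dim = len(a_q), len(b_q), len(b_q[0])
    out = [[0.0] * n_dim for _ in range(m_dim)]
    for m in range(m_dim):
        for n in range(n_dim):
            acc = 0.0
            for kb in range(0, k_dim, block_k):
                # Accumulate the unscaled partial product over one K-block.
                partial = sum(
                    a_q[m][k] * b_q[k][n]
                    for k in range(kb, min(kb + block_k, k_dim))
                )
                # Apply both scales once per K-block.
                acc += partial * a_scale[m][kb // block_k] * b_scale[kb // block_k][n // block_n]
            out[m][n] = acc
    return out


# Toy example: M=1, K=4, N=2 with block_k=2 and block_n=2.
a_q = [[1, 2, 3, 4]]
a_scale = [[0.5, 2.0]]
b_q = [[1, 0], [0, 1], [1, 1], [2, 0]]
b_scale = [[1.0], [0.25]]
result = block_scaled_matmul(a_q, a_scale, b_q, b_scale, block_k=2, block_n=2)
```

A real kernel keeps the operands in FP8 and applies the scales to the accumulator after each K-block, which is why the element format assumed by the MMA instruction must match the format of the stored data.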
    @classmethod
    def can_implement(cls, config):
        return super().can_implement(config)
The underlying cuteDSL kernel (_ll_fp8_block_kernels.py) uses a hardcoded mma.sync instruction for the E4M3 FP8 format. However, this kernel can be selected even when the E5M2 format is enabled (via VLLM_USE_DEEP_GEMM_E5M2), which would lead to incorrect computation.
To prevent this, can_implement should check if E5M2 is being used and reject the configuration if so.
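    @classmethod
    def can_implement(cls, config):
        from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

        # Defer to the base class first and propagate its rejection reason.
        can_implement_base, reason = super().can_implement(config)
        if not can_implement_base:
            return can_implement_base, reason
        # The helper is named for E8M0 while the reason string refers to E5M2;
        # confirm which format check is intended before applying this suggestion.
        if is_deep_gemm_e8m0_used():
            return False, "LLFp8BlockScaledMMKernel only supports E4M3, not E5M2."
        return True, None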
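The guard pattern in the suggestion can be exercised in isolation. The sketch below uses hypothetical stand-in classes and a plain dict config (none of these names come from vLLM); it only illustrates the `(bool, reason)` contract, where a subclass first defers to the base check and then adds its own format restriction.

```python
class BaseKernel:
    """Hypothetical stand-in for the base kernel class."""

    @classmethod
    def can_implement(cls, config):
        # Illustrative base check only: require a positive M dimension.
        if config.get("m", 0) <= 0:
            return False, "M must be positive."
        return True, None


class E4M3OnlyKernel(BaseKernel):
    """Hypothetical kernel whose instructions assume the E4M3 format."""

    @classmethod
    def can_implement(cls, config):
        ok, reason = super().can_implement(config)
        if not ok:
            # Propagate the base-class rejection unchanged.
            return ok, reason
        # "fp8_format" is an assumed config key, defaulting to E4M3.
        if config.get("fp8_format", "e4m3") != "e4m3":
            return False, "E4M3OnlyKernel only supports E4M3."
        return True, None


accepted = E4M3OnlyKernel.can_implement({"m": 8, "fp8_format": "e4m3"})
rejected_format = E4M3OnlyKernel.can_implement({"m": 8, "fp8_format": "e5m2"})
rejected_base = E4M3OnlyKernel.can_implement({"m": 0})
```

Returning the base result first keeps the rejection reasons ordered from most general to most specific, so a caller that logs the reason sees why a kernel was skipped.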
No description provided.