[ROCm][Perf] Fused GEMM + static FP8 output quantization #36810

andyluo7 wants to merge 2 commits into vllm-project:main
Conversation
Add a fused kernel that combines scaled_mm (FP8 GEMM) with static FP8 output quantization into a single hipBLASLt call, eliminating the BF16 intermediate DRAM round-trip. The implementation merges the output quantization scale into the GEMM's a_scales and calls torch._scaled_mm with out_dtype=fp8 directly. This works because hipBLASLt natively supports FP8 output dtypes (since ROCm 6.0, PR ROCm/hipBLASLt#602).

Benchmark on MI300X (PyTorch 2.6, ROCm 6.4, 8x MI300X):

| Shape (M,K,N)     | Unfused (µs) | Fused (µs) | Speedup |
|-------------------|--------------|------------|---------|
| (1,4096,4096)     | 45.6         | 28.9       | 1.58x   |
| (128,4096,4096)   | 45.7         | 31.6       | 1.45x   |
| (512,4096,4096)   | 46.3         | 30.5       | 1.51x   |
| (512,4096,11008)  | 94.8         | 61.4       | 1.54x   |
| (1024,4096,11008) | 161.9        | 100.1      | 1.62x   |
| (512,4096,14336)  | 112.5        | 68.9       | 1.63x   |
| (2048,4096,4096)  | 123.1        | 73.8       | 1.67x   |

Changes:

- New file: rocm_fused_gemm_fp8_quant.py — fused op implementation
- New file: gemm_quant_fusion.py — compilation pass (pattern matcher)
- Modified: pass_manager.py — register the new pass
- Modified: compilation.py — add fuse_gemm_quant config field
- Modified: vllm.py — add defaults (disabled by default)

The fusion is gated behind the fuse_gemm_quant config flag (default: False) and only activates on ROCm. No impact on CUDA or other platforms.

Signed-off-by: Andy Luo <andy.linluo@gmail.com>
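The scale-merging identity behind the fusion can be sanity-checked with plain Python floats. This sketch assumes static FP8 quantization divides by the output scale, and it ignores clamping/rounding to the FP8 grid; it is an illustration, not the PR's code:

```python
# Toy per-tensor scaled GEMM: C = (A @ B) * a_scale * b_scale
def scaled_mm(A, B, a_scale, b_scale):
    return [[sum(x * y for x, y in zip(row, col)) * a_scale * b_scale
             for col in zip(*B)] for row in A]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
a_scale, b_scale, out_scale = 0.5, 0.25, 0.125

# Unfused path: GEMM to high precision, then divide by the output scale
unfused = [[v / out_scale for v in row]
           for row in scaled_mm(A, B, a_scale, b_scale)]

# Fused path: fold out_scale into a_scale, skip the separate quant step
fused = scaled_mm(A, B, a_scale / out_scale, b_scale)

assert fused == unfused  # identical up to FP8 rounding, which both paths share
```

Because the quantization step is a per-tensor multiply, it commutes with the GEMM's epilogue scaling, which is why a single `torch._scaled_mm` call with a merged scale can produce the FP8 output directly.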
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. You can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces a performance optimization for ROCm by fusing a GEMM operation with static FP8 output quantization, which is a great improvement. The implementation leverages torch._scaled_mm with FP8 output, eliminating an intermediate memory round-trip. The changes are well-structured, with a new compilation pass, a corresponding kernel, and configuration flags. I have one high-severity comment regarding the robustness of the pattern matching, which could prevent the fusion from being applied in some cases. Please address it to ensure the optimization is effective.
```python
mm_result = auto_functionalized(
    SCALED_MM_OP,
    A=a,
    B=b,
    out_dtype=self.mm_out_dtype,
    As=a_scales,
    Bs=b_scales,
    bias=bias,
)
mm_out = mm_result[1]
```
The SCALED_MM_OP might return a tuple (e.g., (output, workspace)), similar to the underlying torch._scaled_mm. In the new fused kernel rocm_scaled_mm_static_fp8_quant_impl, you correctly handle this by checking isinstance(output, tuple). However, this pattern definition assumes mm_result[1] is always a single tensor. If SCALED_MM_OP can return a tuple, mm_out will become a tuple, and the pattern will fail to match because static_scaled_fp8_quant expects a tensor for its input argument. This would prevent the fusion from occurring in cases where SCALED_MM_OP returns a tuple. To ensure robustness, please verify that torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl.default is guaranteed to return a single tensor. If it's not, the pattern needs to be adjusted to handle a potential tuple return from SCALED_MM_OP.
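If defensive handling were wanted anyway, a small unwrap helper could normalize either return shape before pattern matching. This is a hypothetical sketch, not code from the PR (`unwrap_mm_output` is an invented name):

```python
def unwrap_mm_output(result):
    """Return the output whether the op returned a single tensor or an
    (output, extra) tuple, as some torch._scaled_mm variants have."""
    return result[0] if isinstance(result, tuple) else result

out = object()  # stand-in for a tensor object
assert unwrap_mm_output(out) is out
assert unwrap_mm_output((out, None)) is out
```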
Good catch on the robustness concern. Let me clarify:
- `rocm_per_tensor_float_w8a8_scaled_mm_impl` always returns a single `torch.Tensor` — see rocm.py L63. The function signature is `-> torch.Tensor`, and it handles the `torch._scaled_mm` tuple unwrapping internally.
- `auto_functionalized` normalizes returns — for an op with return type `torch.Tensor`, `auto_functionalized(op, ...)` returns `(token, output_tensor)`, so `mm_result[1]` is always the output tensor.
- This matches the pattern used by other vLLM fusion passes — e.g., `act_quant_fusion.py` uses the same `at[1]` indexing for `auto_functionalized` results.
That said, I will add a clarifying comment in the code to make this explicit for future readers.
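The `(token, output)` calling convention can be illustrated with a toy stand-in (this is a simulation for exposition, not torch's actual `auto_functionalized` implementation):

```python
# Toy model of the functionalized calling convention: for an op that
# returns a single tensor, the wrapped call yields (token, output),
# so index [1] is always the op's output.
def fake_auto_functionalized(op, **kwargs):
    token = object()  # opaque effect token, analogous to torch's
    return (token, op(**kwargs))

def scaled_mm_stub(A, B):
    # Plain-Python matmul standing in for the scaled_mm op
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

result = fake_auto_functionalized(scaled_mm_stub, A=[[1, 2]], B=[[3], [4]])
mm_out = result[1]  # always the op's output, never the token
```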
Address review comment: document that rocm_per_tensor_float_w8a8_scaled_mm_impl returns a single torch.Tensor, and auto_functionalized wraps it as (token, output_tensor), so [1] is always the output tensor. Signed-off-by: Andy Luo <andy.linluo@gmail.com>
Purpose
Fuse `scaled_mm` (FP8 GEMM → BF16) + `static_scaled_fp8_quant` (BF16 → FP8) into a single `torch._scaled_mm` call that outputs FP8 directly, eliminating the BF16 intermediate DRAM round-trip.

Related: #36689 (CUDA equivalent using CUTLASS epilogue)
How It Works
The key insight: merge `output_scale` into `a_scales` before the GEMM, then call `torch._scaled_mm(..., out_dtype=fp8)`. This works because hipBLASLt natively supports FP8 output dtypes since ROCm 6.0 (ROCm/hipBLASLt#602).

Benchmark Results
MI300X, PyTorch 2.6, ROCm 6.4:

| Shape (M,K,N)     | Unfused (µs) | Fused (µs) | Speedup |
|-------------------|--------------|------------|---------|
| (1,4096,4096)     | 45.6         | 28.9       | 1.58x   |
| (128,4096,4096)   | 45.7         | 31.6       | 1.45x   |
| (512,4096,4096)   | 46.3         | 30.5       | 1.51x   |
| (512,4096,11008)  | 94.8         | 61.4       | 1.54x   |
| (1024,4096,11008) | 161.9        | 100.1      | 1.62x   |
| (512,4096,14336)  | 112.5        | 68.9       | 1.63x   |
| (2048,4096,4096)  | 123.1        | 73.8       | 1.67x   |
The speedup comes from eliminating:

- the BF16 intermediate write to DRAM after the GEMM
- the separate quantization kernel that reads the intermediate back and writes FP8
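A rough back-of-envelope puts numbers on the eliminated traffic. The sketch below assumes BF16 is 2 bytes/element and ignores kernel-launch overhead and cache effects (the FP8 output write itself happens in both paths, so it cancels out):

```python
# Extra DRAM traffic of the unfused path for the benchmarked
# (M, N) = (2048, 4096) output tile.
M, N = 2048, 4096
bf16_write = M * N * 2   # GEMM writes the BF16 intermediate to DRAM
bf16_read = M * N * 2    # quant kernel reads it back
extra_bytes = bf16_write + bf16_read
print(extra_bytes / 2**20)  # → 32.0 (MiB of avoidable traffic)
```

The observed per-shape savings are larger than pure bandwidth time would suggest, consistent with the removed kernel launch also contributing.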
Changes
- `vllm/model_executor/kernels/linear/scaled_mm/rocm_fused_gemm_fp8_quant.py` — fused op implementation
- `vllm/compilation/passes/fusion/gemm_quant_fusion.py` — compilation pass (Inductor pattern matcher)
- `pass_manager.py` — register the new pass
- `compilation.py` — add `fuse_gemm_quant` config field
- `vllm.py` — add defaults (disabled by default at all compilation levels)

Safety
- Disabled by default: `fuse_gemm_quant: False` at all compilation levels
- `__post_init__` disables the pass on non-ROCm platforms
- Relies only on the existing `torch._scaled_mm` API

Test Plan
Microbenchmark inside a ROCm Docker container. Enable via compilation config `fuse_gemm_quant: true`.

Limitations
- Requires `torch.compile` to be active
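For completeness: the test plan enables the fusion via compilation config. Assuming the flag sits under `pass_config` alongside the existing fusion flags (that placement is a guess; the PR only names the flag), the enabling config might look like:

```json
{
  "pass_config": {
    "fuse_gemm_quant": true
  }
}
```

passed, for example, through vLLM's `--compilation-config` CLI option as a JSON string.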