
[ROCm][Perf] Fused GEMM + static FP8 output quantization#36810

Open
andyluo7 wants to merge 2 commits into vllm-project:main from andyluo7:rocm-fused-gemm-fp8-quant

Conversation

@andyluo7

Purpose

Fuse scaled_mm (FP8 GEMM → BF16) + static_scaled_fp8_quant (BF16 → FP8) into a single torch._scaled_mm call that outputs FP8 directly, eliminating the BF16 intermediate DRAM round-trip.

Related: #36689 (CUDA equivalent using CUTLASS epilogue)

How It Works

The key insight: merge output_scale into a_scales before the GEMM, then call torch._scaled_mm(..., out_dtype=fp8). This works because hipBLASLt natively supports FP8 output dtype since ROCm 6.0 (ROCm/hipBLASLt#602).

# Before (2 kernels + BF16 round-trip):
mm_out = scaled_mm(a, b, a_scales, b_scales, out_dtype=bf16)
fp8_out = static_scaled_fp8_quant(mm_out, output_scale)

# After (1 kernel, no intermediate):
combined_a_scales = a_scales * output_scale.reciprocal()
fp8_out = scaled_mm(a, b, combined_a_scales, b_scales, out_dtype=fp8)
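Because all scales here are per-tensor scalars, they commute with the matmul, so folding `output_scale.reciprocal()` into `a_scales` is exact in real arithmetic (FP8 rounding aside). A minimal numeric sketch of the algebra, with `acc` standing in for one raw `a @ b` accumulator value:

```python
def unfused(acc: float, a_s: float, b_s: float, out_s: float) -> float:
    # Step 1: scaled_mm epilogue dequantizes the accumulator (the BF16 value).
    mm = acc * a_s * b_s
    # Step 2: static_scaled_fp8_quant divides by output_scale before the FP8 cast.
    return mm * (1.0 / out_s)

def fused(acc: float, a_s: float, b_s: float, out_s: float) -> float:
    # Fold the reciprocal of output_scale into a_scales: one epilogue, no intermediate.
    return acc * (a_s * (1.0 / out_s)) * b_s
```

Note that the fused path rounds once (accumulator → FP8) instead of twice (accumulator → BF16 → FP8), so numerics can differ slightly from the unfused path.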

Benchmark Results

MI300X, PyTorch 2.6, ROCm 6.4:

| Shape (M, K, N)     | Unfused (µs) | Fused (µs) | Speedup |
|---------------------|--------------|------------|---------|
| (1, 4096, 4096)     | 45.6         | 28.9       | 1.58x   |
| (128, 4096, 4096)   | 45.7         | 31.6       | 1.45x   |
| (512, 4096, 4096)   | 46.3         | 30.5       | 1.51x   |
| (512, 4096, 11008)  | 94.8         | 61.4       | 1.54x   |
| (1024, 4096, 11008) | 161.9        | 100.1      | 1.62x   |
| (512, 4096, 14336)  | 112.5        | 68.9       | 1.63x   |
| (2048, 4096, 4096)  | 123.1        | 73.8       | 1.67x   |

The speedup comes from eliminating:

  1. One kernel launch
  2. BF16 intermediate write to DRAM (M×N × 2 bytes)
  3. BF16 intermediate read from DRAM (M×N × 2 bytes)
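For a sense of scale, the eliminated round-trip can be sized with a quick calculation:

```python
def bf16_roundtrip_bytes(m: int, n: int) -> int:
    """DRAM traffic eliminated: one BF16 write plus one BF16 read
    of the M x N intermediate tensor."""
    bytes_per_elem = 2            # BF16 is 2 bytes per element
    return 2 * m * n * bytes_per_elem  # write + read

# The (2048, 4096, 4096) benchmark shape round-trips ~33.5 MB through DRAM.
print(bf16_roundtrip_bytes(2048, 4096))  # 33554432
```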

Changes

  • New: vllm/model_executor/kernels/linear/scaled_mm/rocm_fused_gemm_fp8_quant.py — Fused op implementation
  • New: vllm/compilation/passes/fusion/gemm_quant_fusion.py — Compilation pass (Inductor pattern matcher)
  • Modified: pass_manager.py — Register the new pass
  • Modified: compilation.py — Add fuse_gemm_quant config field
  • Modified: vllm.py — Add defaults (disabled by default at all compilation levels)

Safety

  • Default disabled: fuse_gemm_quant: False at all compilation levels
  • ROCm only: Platform guard in __post_init__ disables on non-ROCm
  • No CUDA impact: Zero changes to CUTLASS/CUDA code paths
  • No C/C++ changes: Pure Python, uses existing torch._scaled_mm API

Test Plan

Microbenchmark run inside a ROCm Docker container. Enable via the compilation config field fuse_gemm_quant: true.
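A hypothetical invocation sketch, assuming fuse_gemm_quant is accepted at the top level of compilation_config as this PR adds it (the exact config surface may differ in the final revision; the model name is illustrative):

```python
from vllm import LLM

# Assumption: the PR's new fuse_gemm_quant flag lives directly in
# compilation_config; adjust to the merged config layout if it differs.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative FP8-quantized target
    quantization="fp8",
    compilation_config={"fuse_gemm_quant": True},
)
```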

Limitations

  • Only supports static per-tensor output quantization (not dynamic or per-token)
  • Requires ROCm 6.0+ (hipBLASLt FP8 output support)
  • Compilation pass requires torch.compile to be active
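The first limitation can be checked cheaply from scale shapes alone. A hypothetical eligibility guard (names are illustrative, not the PR's actual code; it conservatively requires per-tensor input scales as well as a per-tensor output scale):

```python
def is_static_per_tensor(scale_shape: tuple) -> bool:
    """True when a scale is a single value, e.g. shape (), (1,), or (1, 1)."""
    n = 1
    for dim in scale_shape:
        n *= dim
    return n == 1

def fusion_eligible(a_scale_shape, b_scale_shape, out_scale_shape) -> bool:
    # Per-token scales (e.g. shape (M, 1)) disqualify the pattern.
    return all(
        is_static_per_tensor(s)
        for s in (a_scale_shape, b_scale_shape, out_scale_shape)
    )
```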

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the rocm Related to AMD ROCm label Mar 11, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 11, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for ROCm by fusing a GEMM operation with static FP8 output quantization, which is a great improvement. The implementation leverages torch._scaled_mm with FP8 output, eliminating an intermediate memory round-trip. The changes are well-structured, with a new compilation pass, a corresponding kernel, and configuration flags. I have one high-severity comment regarding the robustness of the pattern matching, which could prevent the fusion from being applied in some cases. Please address it to ensure the optimization is effective.

Comment on lines +78 to +87
mm_result = auto_functionalized(
    SCALED_MM_OP,
    A=a,
    B=b,
    out_dtype=self.mm_out_dtype,
    As=a_scales,
    Bs=b_scales,
    bias=bias,
)
mm_out = mm_result[1]
Contributor


Severity: high

The SCALED_MM_OP might return a tuple (e.g., (output, workspace)), similar to the underlying torch._scaled_mm. In the new fused kernel rocm_scaled_mm_static_fp8_quant_impl, you correctly handle this by checking isinstance(output, tuple). However, this pattern definition assumes mm_result[1] is always a single tensor. If SCALED_MM_OP can return a tuple, mm_out will become a tuple, and the pattern will fail to match because static_scaled_fp8_quant expects a tensor for its input argument. This would prevent the fusion from occurring in cases where SCALED_MM_OP returns a tuple. To ensure robustness, please verify that torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl.default is guaranteed to return a single tensor. If it's not, the pattern needs to be adjusted to handle a potential tuple return from SCALED_MM_OP.

Author


Good catch on the robustness concern. Let me clarify:

  1. rocm_per_tensor_float_w8a8_scaled_mm_impl always returns a single torch.Tensor — see rocm.py L63. The function signature is -> torch.Tensor and it handles the torch._scaled_mm tuple unwrapping internally.

  2. auto_functionalized normalizes returns — for an op with return type torch.Tensor, auto_functionalized(op, ...) returns (token, output_tensor), so mm_result[1] is always the output tensor.

  3. This matches the pattern used by other vLLM fusion passes — e.g., act_quant_fusion.py uses the same at[1] indexing for auto_functionalized results.

That said, I will add a clarifying comment in the code to make this explicit for future readers.
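The (token, output) convention from point 2 can be illustrated with a stub (the real auto_functionalized lives in torch's higher-order-op machinery; this only mimics the return shape, not the functionalization itself):

```python
def stub_auto_functionalized(op, **kwargs):
    # For an op whose return type is a single Tensor, auto_functionalized
    # returns (token, output); the token is an opaque effect marker.
    token = object()
    return (token, op(**kwargs))

def scaled_mm_stub(A, B):
    return A * B  # stands in for the real GEMM

mm_result = stub_auto_functionalized(scaled_mm_stub, A=3, B=14)
mm_out = mm_result[1]  # index 1 is always the op's output
```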

Address review comment: document that rocm_per_tensor_float_w8a8_scaled_mm_impl
returns a single torch.Tensor, and auto_functionalized wraps it as
(token, output_tensor), so [1] is always the output tensor.

Signed-off-by: Andy Luo <andy.linluo@gmail.com>