-
-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[ROCm][Perf] Fused GEMM + static FP8 output quantization #36810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
+280
−0
Closed
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """ | ||
| Fusion pass: GEMM (scaled_mm) + static FP8 quantization. | ||
|
|
||
| Matches the graph pattern where a scaled matrix multiply produces BF16/FP16 | ||
| output that is immediately quantized to FP8 via static_scaled_fp8_quant, | ||
| and replaces it with a single fused kernel. | ||
|
|
||
| On ROCm: uses torch._scaled_mm with FP8 output dtype via hipBLASLt. | ||
| """ | ||
|
|
||
| import torch | ||
| from torch._higher_order_ops.auto_functionalize import auto_functionalized | ||
| from torch._inductor.pattern_matcher import ( | ||
| PatternMatcherPass, | ||
| fwd_only, | ||
| register_replacement, | ||
| ) | ||
|
|
||
| from vllm.config import VllmConfig | ||
| from vllm.logger import init_logger | ||
| from vllm.platforms import current_platform | ||
|
|
||
| from ..inductor_pass import enable_fake_mode | ||
| from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| FP8_DTYPE = current_platform.fp8_dtype() | ||
|
|
||
| # Static quant op (same on all platforms) | ||
| STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default | ||
|
|
||
| # Platform-specific scaled_mm and fused ops | ||
| SCALED_MM_OP = None | ||
| FUSED_OP = None | ||
|
|
||
| if current_platform.is_rocm(): | ||
| # Ensure the fused op is registered | ||
| import vllm.model_executor.kernels.linear.scaled_mm.rocm_fused_gemm_fp8_quant # noqa: F401, E501 | ||
|
|
||
| FUSED_OP = torch.ops.vllm.rocm_scaled_mm_static_fp8_quant.default | ||
|
|
||
| if hasattr(torch.ops.vllm, "rocm_per_tensor_float_w8a8_scaled_mm_impl"): | ||
| SCALED_MM_OP = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl.default | ||
|
|
||
|
|
||
| class GemmStaticFP8QuantPattern: | ||
| """ | ||
| Matches: scaled_mm(a, b, out_dtype, As, Bs, bias) → BF16/FP16 | ||
| + static_scaled_fp8_quant(result, input, scale, group_shape) → FP8 | ||
|
|
||
| Replaces with: fused_op(a, b, As, Bs, output_scale, bias) → FP8 | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| mm_out_dtype: torch.dtype, | ||
| device: torch.device, | ||
| ) -> None: | ||
| self.mm_out_dtype = mm_out_dtype | ||
| self.device = device | ||
|
|
||
| def _empty(self, *shape: int, dtype: torch.dtype) -> torch.Tensor: | ||
| return torch.empty(*shape, dtype=dtype, device=self.device) | ||
|
|
||
| def pattern( | ||
| self, | ||
| a: torch.Tensor, | ||
| b: torch.Tensor, | ||
| a_scales: torch.Tensor, | ||
| b_scales: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| output_scale: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| # Step 1: scaled_mm → BF16/FP16 | ||
| mm_result = auto_functionalized( | ||
| SCALED_MM_OP, | ||
| A=a, | ||
| B=b, | ||
| out_dtype=self.mm_out_dtype, | ||
| As=a_scales, | ||
| Bs=b_scales, | ||
| bias=bias, | ||
| ) | ||
| mm_out = mm_result[1] | ||
|
|
||
| # Step 2: static_scaled_fp8_quant → FP8 | ||
| quant_result = auto_functionalized( | ||
| STATIC_FP8_QUANT_OP, | ||
| result=self._empty(1, 1, dtype=FP8_DTYPE), | ||
| input=mm_out, | ||
| scale=output_scale, | ||
| group_shape=None, | ||
| ) | ||
| return quant_result[1] | ||
|
|
||
| def replacement( | ||
| self, | ||
| a: torch.Tensor, | ||
| b: torch.Tensor, | ||
| a_scales: torch.Tensor, | ||
| b_scales: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| output_scale: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| fused_result = auto_functionalized( | ||
| FUSED_OP, | ||
| a=a, | ||
| b=b, | ||
| a_scales=a_scales, | ||
| b_scales=b_scales, | ||
| output_scale=output_scale, | ||
| bias=bias, | ||
| ) | ||
| return fused_result[1] | ||
|
|
||
| def register(self, pm_pass: PatternMatcherPass) -> None: | ||
| inputs = [ | ||
| self._empty(1, 1, dtype=FP8_DTYPE), # a | ||
| self._empty(1, 1, dtype=FP8_DTYPE), # b | ||
| self._empty(1, 1, dtype=torch.float32), # a_scales | ||
| self._empty(1, 1, dtype=torch.float32), # b_scales | ||
| self._empty(1, dtype=torch.float32), # bias | ||
| self._empty(1, dtype=torch.float32), # output_scale | ||
| ] | ||
|
|
||
| register_replacement( | ||
| self.pattern, | ||
| self.replacement, | ||
| inputs, | ||
| fwd_only, | ||
| pm_pass, | ||
| ) | ||
|
|
||
|
|
||
| class GemmQuantFusionPass(VllmPatternMatcherPass): | ||
| """ | ||
| Compilation pass that fuses GEMM + static FP8 quantization. | ||
|
|
||
| Supported platforms: | ||
| - ROCm (MI300X+): via torch._scaled_mm with FP8 output dtype | ||
| (hipBLASLt natively supports FP8 output since ROCm 6.0) | ||
| """ | ||
|
|
||
| @enable_fake_mode | ||
| def __init__(self, config: VllmConfig) -> None: | ||
| super().__init__(config) | ||
| self.patterns = PatternMatcherPass(pass_name="gemm_quant_fusion_pass") | ||
|
|
||
| if SCALED_MM_OP is None or FUSED_OP is None: | ||
| logger.debug( | ||
| "GEMM + FP8 quant fusion: no fused op available " | ||
| "for current platform, skipping" | ||
| ) | ||
| return | ||
|
|
||
| for out_dtype in (torch.bfloat16, torch.float16): | ||
| GemmStaticFP8QuantPattern(out_dtype, self.device).register(self.patterns) | ||
|
|
||
| self.dump_patterns(config, self.patterns) | ||
|
|
||
| @VllmInductorPass.time_and_log | ||
| def __call__(self, graph: torch.fx.Graph) -> None: | ||
| self.matched_count = self.patterns.apply(graph) | ||
| logger.debug( | ||
| "GemmQuantFusion: replaced %s patterns", | ||
| self.matched_count, | ||
| ) | ||
|
|
||
| def uuid(self) -> str: | ||
| return VllmInductorPass.hash_source(self, GemmStaticFP8QuantPattern) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
vllm/model_executor/kernels/linear/scaled_mm/rocm_fused_gemm_fp8_quant.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """ | ||
| ROCm fused GEMM + static FP8 output quantization. | ||
|
|
||
| Fuses scaled_mm (FP8→BF16/FP16) + static_scaled_fp8_quant (→FP8) into a | ||
| single scaled_mm call that outputs FP8 directly, eliminating the BF16/FP16 | ||
| intermediate DRAM round-trip. | ||
|
|
||
| The key insight: merge output_scale into a_scales before the GEMM, then | ||
| call torch._scaled_mm(..., out_dtype=fp8). This works because hipBLASLt | ||
| natively supports FP8 output dtype since ROCm 6.0. | ||
|
|
||
| Benchmark on MI300X (PyTorch 2.6, ROCm 6.4): | ||
| Shape (512,4096,4096): unfused 46.3µs → fused 30.5µs (1.51x) | ||
| Shape (512,4096,11008): unfused 94.8µs → fused 61.4µs (1.54x) | ||
| Shape (512,4096,14336): unfused 112.5µs → fused 68.9µs (1.63x) | ||
| Shape (2048,4096,4096): unfused 123.1µs → fused 73.8µs (1.67x) | ||
| """ | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.platforms import current_platform | ||
| from vllm.utils.torch_utils import direct_register_custom_op | ||
|
|
||
|
|
||
| def rocm_scaled_mm_static_fp8_quant_impl( | ||
| a: torch.Tensor, | ||
| b: torch.Tensor, | ||
| a_scales: torch.Tensor, | ||
| b_scales: torch.Tensor, | ||
| output_scale: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Fused FP8 GEMM + static FP8 output quantization. | ||
|
|
||
| Equivalent to: | ||
| mm_out = scaled_mm(a, b, a_scales, b_scales, out_dtype=bf16) | ||
| fp8_out = static_scaled_fp8_quant(mm_out, output_scale) | ||
|
|
||
| But done in a single hipBLASLt call by merging output_scale into | ||
| a_scales and requesting FP8 output directly. | ||
| """ | ||
| out_dtype = current_platform.fp8_dtype() | ||
|
|
||
| # Merge output_scale into a_scales: combined = a_scales / output_scale | ||
| # This is equivalent to: output = (a @ b * a_scales * b_scales) / output_scale | ||
| combined_a_scales = a_scales * output_scale.reciprocal() | ||
|
|
||
| output = torch._scaled_mm( | ||
| a, | ||
| b, | ||
| out_dtype=out_dtype, | ||
| scale_a=combined_a_scales, | ||
| scale_b=b_scales, | ||
| bias=bias, | ||
| ) | ||
| if isinstance(output, tuple): | ||
| output = output[0] | ||
| return output | ||
|
|
||
|
|
||
| def rocm_scaled_mm_static_fp8_quant_fake( | ||
| a: torch.Tensor, | ||
| b: torch.Tensor, | ||
| a_scales: torch.Tensor, | ||
| b_scales: torch.Tensor, | ||
| output_scale: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| out_dtype = current_platform.fp8_dtype() | ||
| return a.new_empty((*a.shape[:-1], b.shape[1]), dtype=out_dtype) | ||
|
|
||
|
|
||
| if current_platform.is_rocm(): | ||
| direct_register_custom_op( | ||
| op_name="rocm_scaled_mm_static_fp8_quant", | ||
| op_func=rocm_scaled_mm_static_fp8_quant_impl, | ||
| fake_impl=rocm_scaled_mm_static_fp8_quant_fake, | ||
| ) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
SCALED_MM_OPmight return a tuple (e.g.,(output, workspace)), similar to the underlyingtorch._scaled_mm. In the new fused kernelrocm_scaled_mm_static_fp8_quant_impl, you correctly handle this by checkingisinstance(output, tuple). However, this pattern definition assumesmm_result[1]is always a single tensor. IfSCALED_MM_OPcan return a tuple,mm_outwill become a tuple, and the pattern will fail to match becausestatic_scaled_fp8_quantexpects a tensor for itsinputargument. This would prevent the fusion from occurring in cases whereSCALED_MM_OPreturns a tuple. To ensure robustness, please verify thattorch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl.defaultis guaranteed to return a single tensor. If it's not, the pattern needs to be adjusted to handle a potential tuple return fromSCALED_MM_OP.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch on the robustness concern. Let me clarify:
rocm_per_tensor_float_w8a8_scaled_mm_implalways returns a singletorch.Tensor— see rocm.py L63. The function signature is-> torch.Tensorand it handles thetorch._scaled_mmtuple unwrapping internally.auto_functionalizednormalizes returns — for an op with return typetorch.Tensor,auto_functionalized(op, ...)returns(token, output_tensor), somm_result[1]is always the output tensor.This matches the pattern used by other vLLM fusion passes — e.g.,
act_quant_fusion.pyuses the sameat[1]indexing forauto_functionalizedresults.That said, I will add a clarifying comment in the code to make this explicit for future readers.