forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgemm_quant_fusion.py
More file actions
173 lines (142 loc) · 5.23 KB
/
gemm_quant_fusion.py
File metadata and controls
173 lines (142 loc) · 5.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fusion pass: GEMM (scaled_mm) + static FP8 quantization.
Matches the graph pattern where a scaled matrix multiply produces BF16/FP16
output that is immediately quantized to FP8 via static_scaled_fp8_quant,
and replaces it with a single fused kernel.
On ROCm: uses torch._scaled_mm with FP8 output dtype via hipBLASLt.
"""
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (
PatternMatcherPass,
fwd_only,
register_replacement,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)

# Platform-native FP8 dtype, as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Static quant op (same on all platforms)
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default

# Platform-specific scaled_mm and fused ops.
# Both remain None on unsupported platforms; GemmQuantFusionPass treats
# that as "no fusion available" and registers no patterns.
SCALED_MM_OP = None
FUSED_OP = None
if current_platform.is_rocm():
    # Ensure the fused op is registered
    import vllm.model_executor.kernels.linear.scaled_mm.rocm_fused_gemm_fp8_quant  # noqa: F401, E501
    FUSED_OP = torch.ops.vllm.rocm_scaled_mm_static_fp8_quant.default
    # NOTE(review): the hasattr guard suggests this op is absent in some
    # builds -- confirm which vLLM/ROCm versions provide it.
    if hasattr(torch.ops.vllm, "rocm_per_tensor_float_w8a8_scaled_mm_impl"):
        SCALED_MM_OP = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl.default
class GemmStaticFP8QuantPattern:
    """
    Matches: scaled_mm(a, b, out_dtype, As, Bs, bias) → BF16/FP16
    + static_scaled_fp8_quant(result, input, scale, group_shape) → FP8
    Replaces with: fused_op(a, b, As, Bs, output_scale, bias) → FP8

    One instance handles one GEMM output dtype; ``register`` installs the
    pattern/replacement pair into an Inductor ``PatternMatcherPass``.
    """

    def __init__(
        self,
        mm_out_dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        # Intermediate GEMM output dtype this instance matches
        # (bfloat16 or float16 -- see GemmQuantFusionPass.__init__).
        self.mm_out_dtype = mm_out_dtype
        # Device on which the example (tracing) tensors are allocated.
        self.device = device

    def _empty(self, *shape: int, dtype: torch.dtype) -> torch.Tensor:
        # Helper for building example tensors used only for pattern tracing;
        # contents are irrelevant, only shape/dtype/device matter.
        return torch.empty(*shape, dtype=dtype, device=self.device)

    def pattern(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
        a_scales: torch.Tensor,
        b_scales: torch.Tensor,
        bias: torch.Tensor,
        output_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Graph shape to search for: scaled_mm followed by static FP8 quant.

        Traced symbolically by register_replacement; the exact op identities,
        kwarg names, and result indexing must mirror what the compiled graph
        contains, so do not restructure these calls.
        """
        # Step 1: scaled_mm → BF16/FP16
        mm_result = auto_functionalized(
            SCALED_MM_OP,
            A=a,
            B=b,
            out_dtype=self.mm_out_dtype,
            As=a_scales,
            Bs=b_scales,
            bias=bias,
        )
        # auto_functionalized returns a tuple; index 1 is the op's output
        # here (same convention as the quant step below).
        mm_out = mm_result[1]
        # Step 2: static_scaled_fp8_quant → FP8
        quant_result = auto_functionalized(
            STATIC_FP8_QUANT_OP,
            # Placeholder result buffer; shape is symbolic during tracing.
            result=self._empty(1, 1, dtype=FP8_DTYPE),
            input=mm_out,
            scale=output_scale,
            group_shape=None,
        )
        return quant_result[1]

    def replacement(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
        a_scales: torch.Tensor,
        b_scales: torch.Tensor,
        bias: torch.Tensor,
        output_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Fused form emitted wherever ``pattern`` matches.

        Single kernel that performs the scaled GEMM and writes FP8 output
        directly, skipping the intermediate BF16/FP16 tensor.
        """
        fused_result = auto_functionalized(
            FUSED_OP,
            a=a,
            b=b,
            a_scales=a_scales,
            b_scales=b_scales,
            output_scale=output_scale,
            bias=bias,
        )
        return fused_result[1]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Trace pattern/replacement with example inputs and install them.

        The example tensors fix only dtype/device/rank for tracing; actual
        runtime shapes are matched symbolically.
        """
        inputs = [
            self._empty(1, 1, dtype=FP8_DTYPE),  # a
            self._empty(1, 1, dtype=FP8_DTYPE),  # b
            self._empty(1, 1, dtype=torch.float32),  # a_scales
            self._empty(1, 1, dtype=torch.float32),  # b_scales
            self._empty(1, dtype=torch.float32),  # bias
            self._empty(1, dtype=torch.float32),  # output_scale
        ]
        register_replacement(
            self.pattern,
            self.replacement,
            inputs,
            fwd_only,  # forward-only trace; this pass targets inference
            pm_pass,
        )
class GemmQuantFusionPass(VllmPatternMatcherPass):
    """
    Inductor compilation pass that fuses a scaled GEMM with the static FP8
    quantization of its output.

    Supported platforms:
    - ROCm (MI300X+): via torch._scaled_mm with FP8 output dtype
      (hipBLASLt natively supports FP8 output since ROCm 6.0)
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        self.patterns = PatternMatcherPass(pass_name="gemm_quant_fusion_pass")

        # Without both a platform scaled_mm and a fused kernel there is
        # nothing to rewrite -- leave the pass registered but empty.
        if SCALED_MM_OP is None or FUSED_OP is None:
            logger.debug(
                "GEMM + FP8 quant fusion: no fused op available "
                "for current platform, skipping"
            )
            return

        # Register one pattern per supported GEMM intermediate dtype.
        for dtype in (torch.bfloat16, torch.float16):
            pattern = GemmStaticFP8QuantPattern(dtype, self.device)
            pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Rewrite all matching GEMM+quant sites in ``graph`` in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("GemmQuantFusion: replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Cache key covering both this pass and the pattern class source."""
        return VllmInductorPass.hash_source(self, GemmStaticFP8QuantPattern)