Skip to content

Commit 53ff50f

Browse files
[Perf] Optimize CutlassFP8ScaledMMLinearKernel when padding needed by pre-weight processing, 13.5% TTFT improvement (vllm-project#42651)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
1 parent 363fc84 commit 53ff50f

1 file changed

Lines changed: 47 additions & 26 deletions

File tree

  • vllm/model_executor/kernels/linear/scaled_mm

vllm/model_executor/kernels/linear/scaled_mm/cutlass.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44

5+
from collections.abc import Sequence
6+
57
import torch
68

79
from vllm import _custom_ops as ops
@@ -150,6 +152,12 @@ def apply_weights(
150152

151153

152154
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
155+
def __init__(
    self,
    c: FP8ScaledMMLinearLayerConfig,
    layer_param_names: Sequence[str],
) -> None:
    """Create the kernel with no recorded logical output width yet.

    Args:
        c: Layer configuration, forwarded unchanged to the base class.
        layer_param_names: Names of the layer parameters this kernel
            reads, forwarded unchanged to the base class.
    """
    # Unpadded output width (N) of the weight. It stays ``None`` until
    # ``process_weights_after_loading`` fills it in; ``apply_scaled_mm``
    # uses it to slice static padding off the GEMM result.
    # The attribute is assigned before the base initializer runs, so it
    # already exists by the time any base-class setup code executes.
    self.logical_output_size: int | None = None
    super().__init__(c, layer_param_names)
160+
153161
@classmethod
154162
def is_supported(
155163
cls, compute_capability: int | None = None
@@ -176,6 +184,33 @@ def _pad_to_alignment(
176184
pad_spec[-(2 * dim + 1)] = pad_size
177185
return torch.nn.functional.pad(x, pad_spec, value=value)
178186

187+
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Pad the weight (and per-channel scale) to 16-element alignment once.

    Moving the padding here means ``apply_scaled_mm`` no longer has to
    re-pad the weight on every forward call; it only pads the activation
    and bias and slices the output back to the logical width.

    The method is safe to call more than once on the same layer: the
    logical output width is captured only on the first call, so a later
    call that sees the already-padded weight cannot overwrite it with the
    padded width.

    Args:
        layer: Module holding the weight and weight-scale parameters named
            by the first two entries of ``self.layer_param_names``. The
            parameters are replaced in place on the layer when padding is
            required.
    """
    alignment = 16

    weight_name, weight_scale_name, _, _ = self.layer_param_names
    weight = getattr(layer, weight_name)

    # Keep the logical output width so runtime can slice away static
    # padding. Only record it while unset: on a repeated call the weight
    # may already be padded, and its shape[1] would then be the padded
    # width rather than the logical one.
    if self.logical_output_size is None:
        self.logical_output_size = weight.shape[1]

    pad_k = (alignment - weight.shape[0] % alignment) % alignment
    pad_n = (alignment - weight.shape[1] % alignment) % alignment
    if pad_k == 0 and pad_n == 0:
        # Already aligned in both dimensions; nothing to replace.
        return

    # B is column-major [K, N]. Transpose to [N, K], pad both dims in a
    # single call (last dim K by pad_k, first dim N by pad_n), then
    # transpose back so the padded weight is again viewed as [K, N].
    padded_weight = torch.nn.functional.pad(
        weight.t().contiguous(),
        (0, pad_k, 0, pad_n),
    ).t()
    replace_parameter(layer, weight_name, padded_weight.data)

    # A scale with more than one element has to grow along with N; a
    # single-element scale needs no padding. Padded entries use 1.0.
    weight_scale = getattr(layer, weight_scale_name, None)
    if weight_scale is not None and pad_n > 0 and weight_scale.numel() > 1:
        flat_scale = weight_scale.reshape(-1)
        # Pad the flattened scale, then restore the original trailing
        # dimensions (e.g. a trailing singleton dim, if present).
        padded_scale = self._pad_to_alignment(
            flat_scale, dim=0, alignment=alignment, value=1.0
        ).view(-1, *weight_scale.shape[1:])
        replace_parameter(layer, weight_scale_name, padded_scale.data)
213+
179214
def apply_scaled_mm(
180215
self,
181216
*,
@@ -187,39 +222,25 @@ def apply_scaled_mm(
187222
bias: torch.Tensor | None,
188223
output_shape: list,
189224
) -> torch.Tensor:
190-
# Per-tensor/Per-channel padding to use Cutlass instead of Triton.
191-
K, N = B.shape
192-
pad_k = (16 - K % 16) % 16
193-
pad_n = (16 - N % 16) % 16
194-
195-
if pad_k > 0 or pad_n > 0:
196-
# B is column-major [K, N]. Transpose to row-major [N, K],
197-
# pad both dims in one call, then transpose back so the
198-
# result keeps column-major layout with stride (1, K_padded).
199-
B = torch.nn.functional.pad(B.t().contiguous(), (0, pad_k, 0, pad_n)).t()
200-
201-
if pad_k > 0:
202-
A = self._pad_to_alignment(A, dim=1, alignment=16)
203-
if pad_n > 0:
204-
if bias is not None:
205-
bias = self._pad_to_alignment(bias, dim=0, alignment=16)
206-
# Bs is per-tensor (numel==1) or per-channel (numel==N)
207-
# in this kernel class — never 2D block-wise.
208-
if Bs.numel() > 1:
209-
Bs = self._pad_to_alignment(
210-
Bs.view(-1), dim=0, alignment=16, value=1.0
211-
)
212-
if Bs.dim() == 1 and B.shape[1] > 1:
213-
Bs = Bs.view(-1, 1)
225+
padded_k, padded_n = B.shape
226+
output_size = self.logical_output_size
227+
assert output_size is not None
228+
pad_k = padded_k - A.shape[1]
229+
pad_n = padded_n - output_size
230+
231+
if pad_k > 0:
232+
A = self._pad_to_alignment(A, dim=1, alignment=16)
233+
if pad_n > 0 and bias is not None:
234+
bias = self._pad_to_alignment(bias, dim=0, alignment=16)
214235

215236
output = ops.cutlass_scaled_mm(
216237
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
217238
)
218239

219240
if pad_n > 0:
220-
output = output[..., :N].contiguous()
241+
output = output[..., :output_size].contiguous()
221242

222-
return output.view(*output_shape)
243+
return output.view(*output_shape[:-1], output_size)
223244

224245

225246
class CutlassFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):

0 commit comments

Comments
 (0)