22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44
5+ from collections .abc import Sequence
6+
57import torch
68
79from vllm import _custom_ops as ops
@@ -150,6 +152,12 @@ def apply_weights(
150152
151153
152154class CutlassFP8ScaledMMLinearKernel (FP8ScaledMMLinearKernel ):
155+ def __init__ (
156+ self , c : FP8ScaledMMLinearLayerConfig , layer_param_names : Sequence [str ]
157+ ) -> None :
158+ self .logical_output_size : int | None = None
159+ super ().__init__ (c , layer_param_names )
160+
153161 @classmethod
154162 def is_supported (
155163 cls , compute_capability : int | None = None
@@ -176,6 +184,33 @@ def _pad_to_alignment(
176184 pad_spec [- (2 * dim + 1 )] = pad_size
177185 return torch .nn .functional .pad (x , pad_spec , value = value )
178186
187+ def process_weights_after_loading (self , layer : torch .nn .Module ) -> None :
188+ weight_name , weight_scale_name , _ , _ = self .layer_param_names
189+ weight = getattr (layer , weight_name )
190+
191+ # keep the logical output width so runtime can slice away static padding.
192+ self .logical_output_size = weight .shape [1 ]
193+
194+ pad_k = (16 - weight .shape [0 ] % 16 ) % 16
195+ pad_n = (16 - weight .shape [1 ] % 16 ) % 16
196+ if pad_k == 0 and pad_n == 0 :
197+ return
198+
199+ # B is column-major [K, N]
200+ padded_weight = torch .nn .functional .pad (
201+ weight .t ().contiguous (),
202+ (0 , pad_k , 0 , pad_n ),
203+ ).t ()
204+ replace_parameter (layer , weight_name , padded_weight .data )
205+
206+ weight_scale = getattr (layer , weight_scale_name , None )
207+ if weight_scale is not None and pad_n > 0 and weight_scale .numel () > 1 :
208+ flat_scale = weight_scale .reshape (- 1 )
209+ padded_scale = self ._pad_to_alignment (
210+ flat_scale , dim = 0 , alignment = 16 , value = 1.0
211+ ).view (- 1 , * weight_scale .shape [1 :])
212+ replace_parameter (layer , weight_scale_name , padded_scale .data )
213+
179214 def apply_scaled_mm (
180215 self ,
181216 * ,
@@ -187,39 +222,25 @@ def apply_scaled_mm(
187222 bias : torch .Tensor | None ,
188223 output_shape : list ,
189224 ) -> torch .Tensor :
190- # Per-tensor/Per-channel padding to use Cutlass instead of Triton.
191- K , N = B .shape
192- pad_k = (16 - K % 16 ) % 16
193- pad_n = (16 - N % 16 ) % 16
194-
195- if pad_k > 0 or pad_n > 0 :
196- # B is column-major [K, N]. Transpose to row-major [N, K],
197- # pad both dims in one call, then transpose back so the
198- # result keeps column-major layout with stride (1, K_padded).
199- B = torch .nn .functional .pad (B .t ().contiguous (), (0 , pad_k , 0 , pad_n )).t ()
200-
201- if pad_k > 0 :
202- A = self ._pad_to_alignment (A , dim = 1 , alignment = 16 )
203- if pad_n > 0 :
204- if bias is not None :
205- bias = self ._pad_to_alignment (bias , dim = 0 , alignment = 16 )
206- # Bs is per-tensor (numel==1) or per-channel (numel==N)
207- # in this kernel class — never 2D block-wise.
208- if Bs .numel () > 1 :
209- Bs = self ._pad_to_alignment (
210- Bs .view (- 1 ), dim = 0 , alignment = 16 , value = 1.0
211- )
212- if Bs .dim () == 1 and B .shape [1 ] > 1 :
213- Bs = Bs .view (- 1 , 1 )
225+ padded_k , padded_n = B .shape
226+ output_size = self .logical_output_size
227+ assert output_size is not None
228+ pad_k = padded_k - A .shape [1 ]
229+ pad_n = padded_n - output_size
230+
231+ if pad_k > 0 :
232+ A = self ._pad_to_alignment (A , dim = 1 , alignment = 16 )
233+ if pad_n > 0 and bias is not None :
234+ bias = self ._pad_to_alignment (bias , dim = 0 , alignment = 16 )
214235
215236 output = ops .cutlass_scaled_mm (
216237 A , B , out_dtype = out_dtype , scale_a = As , scale_b = Bs , bias = bias
217238 )
218239
219240 if pad_n > 0 :
220- output = output [..., :N ].contiguous ()
241+ output = output [..., :output_size ].contiguous ()
221242
222- return output .view (* output_shape )
243+ return output .view (* output_shape [: - 1 ], output_size )
223244
224245
225246class CutlassFp8BlockScaledMMKernel (Fp8BlockScaledMMLinearKernel ):
0 commit comments