Skip to content

Commit 9609b1f

Browse files
authored
Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8 (vllm-project#35053)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent a0c7081 commit 9609b1f

3 files changed

Lines changed: 230 additions & 11 deletions

File tree

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
MXFP8_VALUE_DTYPE,
7171
Mxfp8LinearBackend,
7272
Mxfp8LinearOp,
73+
swizzle_mxfp8_scale,
7374
)
7475
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
7576
apply_nvfp4_linear,
@@ -1689,9 +1690,9 @@ def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
16891690
"Dynamic quantization is not supported."
16901691
)
16911692

1692-
backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION
1693-
self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend)
1694-
logger.info_once("Using %s backend for MXFP8 GEMM", backend.value)
1693+
self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
1694+
self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
1695+
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)
16951696

16961697
def create_weights(
16971698
self,
@@ -1749,7 +1750,38 @@ def create_weights(
17491750
)
17501751
layer.register_parameter("weight_scale", weight_scale)
17511752

1753+
def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
    """Finalize weights for the emulation backend (scales stay row-major 2D).

    Trims ``layer.weight_scale`` to ``[N, K // MXFP8_BLOCK_SIZE]`` so that it
    lines up with the ``[N, K]`` weight, then re-registers both tensors on
    ``layer`` as contiguous parameters with ``requires_grad=False``. No
    swizzling is applied on this path.

    Args:
        layer: Module whose ``weight`` and ``weight_scale`` are replaced
            in place.
    """
    weight = layer.weight.data  # [N, K]
    N, K = weight.shape
    scale_k = K // MXFP8_BLOCK_SIZE

    # The stored scale tensor may be padded beyond the weight's extent; keep
    # only the rows/columns that belong to real weight blocks.
    weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()

    layer.weight = Parameter(weight.contiguous(), requires_grad=False)
    layer.weight_scale = Parameter(weight_scale, requires_grad=False)
1764+
1765+
def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
    """Finalize weights for the FlashInfer CUTLASS backend (swizzled scales).

    Trims the 2D ``layer.weight_scale`` to ``[N, K // MXFP8_BLOCK_SIZE]``,
    passes it through ``swizzle_mxfp8_scale`` and re-registers the weight and
    the swizzled scale on ``layer`` as contiguous parameters with
    ``requires_grad=False``.

    Args:
        layer: Module whose ``weight`` and ``weight_scale`` are replaced
            in place.
    """
    weight = layer.weight.data  # [N, K]
    N, K = weight.shape

    # Scales arrive as a 2D tensor; drop any padding so that exactly one
    # scale remains per MXFP8_BLOCK_SIZE weight elements along K.
    scale_k = K // MXFP8_BLOCK_SIZE
    weight_scale_2d = layer.weight_scale.data[:N, :scale_k].contiguous()

    # Swizzle the weight scales
    weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)

    layer.weight = Parameter(weight.contiguous(), requires_grad=False)
    layer.weight_scale = Parameter(
        weight_scale_swizzled.contiguous(), requires_grad=False
    )
1782+
17521783
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
1784+
# Validate weight tensor
17531785
if layer.weight.ndim != 2:
17541786
raise ValueError(
17551787
f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
@@ -1763,15 +1795,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
17631795
f"quantized with MXFP8."
17641796
)
17651797

1766-
weight = layer.weight.data # [N, K]
1767-
N, K = weight.shape
1768-
scale_k = K // MXFP8_BLOCK_SIZE
1798+
# Validate weight scale tensor (should be 2D, not swizzled)
1799+
assert layer.weight_scale.ndim == 2, (
1800+
f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D"
1801+
)
1802+
assert layer.weight_scale.dtype == MXFP8_SCALE_DTYPE, (
1803+
f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE},"
1804+
f" got {layer.weight_scale.dtype}"
1805+
)
17691806

1770-
# Slice weight_scale to match weight dimensions (handles padding)
1771-
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
1807+
if self.backend == Mxfp8LinearBackend.EMULATION:
1808+
# Swizzled layout is not used
1809+
self._process_weights_after_loading_scale_2d(layer)
1810+
return
17721811

1773-
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
1774-
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
1812+
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
1813+
# Swizzled layout is required for Flashinfer CUTLASS
1814+
self._process_weights_after_loading_scale_1d(layer)
17751815

17761816
def apply(
17771817
self,

vllm/model_executor/layers/quantization/utils/mxfp8_utils.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import torch
77

88
from vllm.logger import init_logger
9+
from vllm.utils import flashinfer as vllm_flashinfer
910
from vllm.utils.torch_utils import direct_register_custom_op
1011

1112
logger = init_logger(__name__)
1213

1314

1415
class Mxfp8LinearBackend(Enum):
    """Backends that an MXFP8 linear op can dispatch to."""

    # Emulated GEMM path.
    EMULATION = "emulation"
    # GEMM through FlashInfer's mm_mxfp8 with the CUTLASS backend.
    FLASHINFER_CUTLASS = "flashinfer-cutlass"
1618

1719

1820
# MXFP8 constants
@@ -21,6 +23,30 @@ class Mxfp8LinearBackend(Enum):
2123
MXFP8_BLOCK_SIZE = 32
2224

2325

26+
def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
    """Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout.

    The scales are zero-padded to a whole number of 128-row by 4-column
    tiles, viewed as ``(m_tiles, 4, 32, k_tiles, 4)``, axes 1 and 3 are
    swapped, and the result is flattened.

    Args:
        sf: 2D scale tensor of shape ``[M, K // MXFP8_BLOCK_SIZE]``.
        M: Number of rows of the tensor the scales belong to.
        K: Number of (unscaled) columns of the tensor the scales belong to.

    Returns:
        A contiguous 1D tensor holding the padded, swizzled scales, with the
        same dtype and device as ``sf``.

    Raises:
        ValueError: If ``sf`` does not have shape
            ``[M, K // MXFP8_BLOCK_SIZE]``.
    """
    scaling_vector_size = MXFP8_BLOCK_SIZE  # 32 for MXFP8
    factor = scaling_vector_size * 4  # 128

    scale_cols = K // scaling_vector_size
    # Check the shape up front: the slice assignment below would silently
    # broadcast a wrongly shaped tensor such as (M, 1) or (1, scale_cols).
    if tuple(sf.shape) != (M, scale_cols):
        raise ValueError(
            f"Expected MXFP8 scales of shape ({M}, {scale_cols}) for "
            f"M={M}, K={K}, got {tuple(sf.shape)}"
        )

    num_m_tiles = (M + 127) // 128
    num_k_tiles = (K + factor - 1) // factor

    m_padded = num_m_tiles * 128
    k_scale_padded = num_k_tiles * 4

    sf_padded = torch.zeros(
        (m_padded, k_scale_padded), dtype=sf.dtype, device=sf.device
    )
    sf_padded[:M, :scale_cols] = sf

    sf_reshaped = sf_padded.view(num_m_tiles, 4, 32, num_k_tiles, 4)

    # (m_tiles, 4, 32, k_tiles, 4) -> (m_tiles, k_tiles, 32, 4, 4)
    sf_swizzled = sf_reshaped.transpose(1, 3)

    return sf_swizzled.contiguous().view(-1)
48+
49+
2450
def _mxfp8_e4m3_quantize_impl(
2551
x: torch.Tensor, is_sf_swizzled_layout: bool = False
2652
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -108,7 +134,7 @@ def __init__(self, backend: Mxfp8LinearBackend):
108134

109135
self.backend = backend
110136

111-
def apply(
137+
def _apply_emulation(
112138
self,
113139
input: torch.Tensor,
114140
weight: torch.Tensor,
@@ -132,3 +158,79 @@ def apply(
132158

133159
output = torch.nn.functional.linear(input, weight_bf16, bias)
134160
return output.to(out_dtype)
161+
162+
def _apply_flashinfer_cutlass(
163+
self,
164+
input: torch.Tensor,
165+
weight: torch.Tensor,
166+
weight_scale: torch.Tensor,
167+
out_dtype: torch.dtype,
168+
bias: torch.Tensor | None = None,
169+
) -> torch.Tensor:
170+
N, K = weight.shape
171+
172+
input_shape = input.shape
173+
input_2d = input.view(-1, K)
174+
M_orig = input_2d.shape[0]
175+
176+
# Minimum dimension size for F8_128x4 block scaling layout
177+
min_dim = 128
178+
179+
assert min_dim <= K, (
180+
f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
181+
f"in_features is too small for mm_mxfp8."
182+
)
183+
assert K % MXFP8_BLOCK_SIZE == 0, (
184+
f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
185+
)
186+
assert min_dim <= N, (
187+
f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
188+
f"out_features is too small for mm_mxfp8."
189+
)
190+
191+
M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim
192+
if M_padded != M_orig:
193+
pad_rows = M_padded - M_orig
194+
input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows))
195+
196+
input_mxfp8, input_scale = mxfp8_e4m3_quantize(
197+
input_2d,
198+
is_sf_swizzled_layout=True, # Swizzled for best accuracy
199+
)
200+
201+
if not weight.is_contiguous():
202+
weight = weight.contiguous()
203+
204+
output = vllm_flashinfer.mm_mxfp8(
205+
input_mxfp8,
206+
weight.t(),
207+
input_scale,
208+
weight_scale,
209+
out_dtype=out_dtype,
210+
backend="cutlass",
211+
)
212+
213+
if M_padded != M_orig:
214+
output = output[:M_orig, :]
215+
216+
if bias is not None:
217+
output = output + bias
218+
219+
output_shape = (*input_shape[:-1], N)
220+
return output.view(output_shape)
221+
222+
def apply(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the MXFP8 linear op on the backend chosen at construction time.

    Args:
        input: Activations passed through to the backend implementation.
        weight: Quantized weight passed through to the backend.
        weight_scale: Weight scales passed through to the backend.
        out_dtype: Requested output dtype.
        bias: Optional bias passed through to the backend.

    Returns:
        The output tensor produced by the selected backend.

    Raises:
        ValueError: If ``self.backend`` is not a supported backend.
    """
    if self.backend == Mxfp8LinearBackend.EMULATION:
        return self._apply_emulation(input, weight, weight_scale, out_dtype, bias)

    if self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
        return self._apply_flashinfer_cutlass(
            input, weight, weight_scale, out_dtype, bias
        )

    # Fail loudly rather than falling through to a kernel that expects a
    # different scale layout (an assert would vanish under `python -O`).
    raise ValueError(f"Unsupported MXFP8 linear backend: {self.backend}")

vllm/utils/flashinfer.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,83 @@ def flashinfer_nvfp4_quantize_fake(
553553
rounded_m, rounded_n, dtype=torch.uint8, device=a.device
554554
)
555555

556+
@torch.library.custom_op(
    "vllm::mm_mxfp8",
    mutates_args=[],
    device_types="cuda",
)
def mm_mxfp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """Custom-op wrapper around FlashInfer's ``mm_mxfp8``.

    Registered as ``vllm::mm_mxfp8`` for CUDA. All arguments are forwarded
    to FlashInfer unchanged, with ``out=None`` so that FlashInfer returns the
    result tensor.

    Args:
        A: Left operand.
        B: Right operand.
        A_scale: Block scales for ``A``.
        B_scale: Block scales for ``B``.
        out_dtype: Requested output dtype.
        backend: FlashInfer backend name.

    Returns:
        The tensor returned by FlashInfer's ``mm_mxfp8``.
    """
    # Imported lazily so this module can be imported without FlashInfer.
    from flashinfer import mm_mxfp8 as mm_mxfp8_

    return mm_mxfp8_(
        A,
        B,
        A_scale,
        B_scale,
        out=None,
        out_dtype=out_dtype,
        backend=backend,
    )
580+
581+
@torch.library.register_fake(
    "vllm::mm_mxfp8",
)
def mm_mxfp8_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """Fake implementation of ``vllm::mm_mxfp8``.

    Returns an uninitialized tensor with the output's shape, dtype and
    device without running the kernel. The scale and backend arguments are
    accepted only to mirror the real op's signature.
    """
    # A is [m, k], B is [k, n] -> output [m, n]
    return torch.empty(A.shape[0], B.shape[1], dtype=out_dtype, device=A.device)
594+
595+
596+
def flashinfer_mm_mxfp8(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """MXFP8 MM helper - mirrors flashinfer_scaled_fp4_mm API.

    Takes non-transposed weights and handles transpose internally.

    CRITICAL: mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for optimal
    performance and accuracy. Both input and weight scales should be in
    swizzled format from FlashInfer's mxfp8_quantize(is_sf_swizzled_layout=True).

    Args:
        a: Activations of shape ``[M, K]``.
        b: Non-transposed weight of shape ``[N, K]``.
        block_scale_a: Block scales for ``a``.
        block_scale_b: Block scales for ``b``; must be 1D.
        out_dtype: Requested output dtype.
        backend: Backend name forwarded to ``mm_mxfp8``.

    Returns:
        The ``[M, N]`` result of ``mm_mxfp8``.

    Raises:
        ValueError: If ``a`` or ``b`` is not 2D, their K dimensions differ,
            or ``block_scale_b`` is not 1D.
    """
    # a shape [M, K]
    # b shape [N, K] (transposed to [K, N] right before the kernel call)
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(
            "mm_mxfp8 expects 2D operands; "
            f"got a.ndim={a.ndim}, b.ndim={b.ndim}"
        )
    if a.shape[1] != b.shape[1]:  # K dimension must match
        raise ValueError(
            "mm_mxfp8 expects matching K dimensions; "
            f"got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)}"
        )

    if block_scale_b.ndim != 1:
        raise ValueError(
            "mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; "
            f"got shape={tuple(block_scale_b.shape)}"
        )

    # Output tensor [M, N]
    return mm_mxfp8(
        a,
        b.t(),  # Transpose weight: [N, K] -> [K, N]
        block_scale_a,
        block_scale_b,
        out_dtype,
        backend=backend,
    )
632+
556633

557634
def flashinfer_scaled_fp4_mm(
558635
a: torch.Tensor,

0 commit comments

Comments
 (0)