Skip to content

Commit a0b8102

Browse files
ssubhanjali and claude
committed
[Bugfix] Pad Marlin FP8 MoE weight dims to tile alignment under TP > 1
The Marlin kernel requires size_n % 64 == 0 (tile_n_size) and size_k % 16 == 0 (tile_k_size). When tensor-parallel sharding splits MoE weights across GPUs, per-rank dimensions can violate these constraints and cause a crash at model load time on any GPU that falls back to the Marlin FP8 MoE path (CC < 9.0: L40S, A100, A10G).

Example — Nemotron Nano 3 at TP=4 (intermediate_size=1856):
- w13 gate+up: size_n = 464 per shard → 464 % 64 = 16 ✗
- w2 down: size_k = 232 per shard → 232 % 16 = 8 ✗

This error is never triggered on Hopper+ (CC >= 9.0) because vLLM selects native FP8 MoE kernels (CUTLASS/Triton) on those GPUs and never enters the Marlin path.

Fix:
- Define MARLIN_TILE_N=64, MARLIN_TILE_K=16 and a _pad_to_marlin_tile() helper in marlin_utils_fp8.py
- repack_weight(): pad size_n/size_k to tile boundaries before calling gptq_marlin_repack
- permute_scales(): pad scales to match the padded size_n
- fused_marlin_moe.py _fused_marlin_moe(): import the tile constants, compute padded sizes, use them for the GEMM calls, trim w13 padding before activation, and pad the intermediate output before the w2 GEMM

Padding with zeros is mathematically a no-op: zero weights and zero inputs contribute nothing to GEMM outputs. For already-aligned dimensions all padding amounts are zero and no operations are performed.

Tested on B200 with VLLM_TEST_FORCE_FP8_MARLIN=1 using Nemotron Nano 3 weight shapes (E=2, K=1024, N=232, W13_N=464).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent a9e532a commit a0b8102

File tree

2 files changed

+66
-8
lines changed

2 files changed

+66
-8
lines changed

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
marlin_moe_intermediate_size,
3636
marlin_quant_input,
3737
)
38+
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
39+
MARLIN_TILE_K,
40+
MARLIN_TILE_N,
41+
)
3842
from vllm.model_executor.layers.quantization.utils.quant_utils import (
3943
QuantKey,
4044
kFp8Static128BlockSym,
@@ -88,12 +92,16 @@ def _fused_marlin_moe(
8892
M, K = hidden_states.size()
8993
N = marlin_moe_intermediate_size(w1, w2)
9094
w13_num_shards = 2 if activation.is_gated else 1
95+
_w13_n = w13_num_shards * N
96+
# Compute the same tile-aligned padded sizes used at weight-load time.
97+
_w13_n_padded = _w13_n + ((-_w13_n) % MARLIN_TILE_N) # for w13 GEMM size_n
98+
_N_padded = N + ((-N) % MARLIN_TILE_K) # for w2 GEMM size_k
9199
if workspace is None:
92100
workspace = marlin_make_workspace_new(hidden_states.device, 4)
93101

94102
if intermediate_cache13 is None:
95103
intermediate_cache13 = torch.empty(
96-
(M * num_topk * max(w13_num_shards * N, K),),
104+
(M * num_topk * max(_w13_n_padded, K),),
97105
device=hidden_states.device,
98106
dtype=hidden_states.dtype,
99107
)
@@ -106,7 +114,7 @@ def _fused_marlin_moe(
106114
)
107115

108116
intermediate_cache1 = _resize_cache(
109-
intermediate_cache13, (M * num_topk, w13_num_shards * N)
117+
intermediate_cache13, (M * num_topk, _w13_n_padded)
110118
)
111119

112120
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
@@ -143,17 +151,21 @@ def _fused_marlin_moe(
143151
mul_topk_weights=apply_router_weight_on_input,
144152
b_q_type=quant_type,
145153
size_m=M,
146-
size_n=w13_num_shards * N,
154+
size_n=_w13_n_padded, # padded to Marlin tile_n boundary
147155
size_k=K,
148156
is_k_full=is_k_full,
149157
use_atomic_add=False,
150158
use_fp32_reduce=True,
151159
is_zp_float=False,
152160
)
161+
# Trim w13 padding before activation (GEMM produced _w13_n_padded cols,
162+
# activation expects true _w13_n cols).
163+
if _w13_n_padded != _w13_n:
164+
intermediate_cache1 = intermediate_cache1[:, :_w13_n].contiguous()
153165
activation_func(
154166
activation,
155167
intermediate_cache2,
156-
intermediate_cache1.view(-1, w13_num_shards * N),
168+
intermediate_cache1.view(-1, _w13_n),
157169
)
158170

159171
if output is None:
@@ -174,6 +186,13 @@ def _fused_marlin_moe(
174186
intermediate_cache2, input_dtype
175187
)
176188

189+
# Pad activation output to _N_padded so w2 GEMM size_k is tile-aligned.
190+
# Extra columns are zero; the matching zero-padding in w2's repacked weights
191+
# ensures they contribute nothing to the output.
192+
if _N_padded != N:
193+
intermediate_cache2 = torch.nn.functional.pad(
194+
intermediate_cache2, (0, _N_padded - N)
195+
)
177196
output = ops.moe_wna16_marlin_gemm(
178197
intermediate_cache2,
179198
output,
@@ -196,7 +215,7 @@ def _fused_marlin_moe(
196215
b_q_type=quant_type,
197216
size_m=M * num_topk,
198217
size_n=K,
199-
size_k=N,
218+
size_k=_N_padded, # padded to Marlin tile_k boundary
200219
is_k_full=is_k_full,
201220
use_atomic_add=False,
202221
use_fp32_reduce=True,

vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44

55
import torch
6+
import torch.nn.functional as F
67

78
import vllm._custom_ops as ops
89
from vllm.logger import init_logger
@@ -20,6 +21,21 @@
2021

2122
logger = init_logger(__name__)
2223

24+
# Marlin kernel tile alignment requirements.
25+
MARLIN_TILE_N = 64 # size_n must be divisible by this
26+
MARLIN_TILE_K = 16 # size_k must be divisible by this
27+
28+
29+
def _pad_to_marlin_tile(size_n: int, size_k: int) -> tuple[int, int, int, int]:
30+
"""Return (padded_size_n, padded_size_k, pad_n, pad_k).
31+
32+
Computes the smallest tile-aligned sizes >= size_n and size_k.
33+
pad_n / pad_k are zero when the dimension is already aligned.
34+
"""
35+
pad_n = (-size_n) % MARLIN_TILE_N
36+
pad_k = (-size_k) % MARLIN_TILE_K
37+
return size_n + pad_n, size_k + pad_k, pad_n, pad_k
38+
2339

2440
def is_fp8_marlin_supported():
2541
return current_platform.has_device_capability(75)
@@ -247,12 +263,28 @@ def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
247263

248264
assert weight.shape == (e, size_n, size_k)
249265

266+
# Pad size_n and size_k to Marlin tile boundaries so gptq_marlin_repack
267+
# does not crash when TP sharding produces non-aligned per-rank dimensions:
268+
# tile_n_size = 64 (affects w13 gate+up, e.g. 464 → 512)
269+
# tile_k_size = 16 (affects w2 down-proj, e.g. 232 → 240)
270+
_padded_size_n, _padded_size_k, _pad_n, _pad_k = _pad_to_marlin_tile(
271+
size_n, size_k
272+
)
250273
for i in range(e):
251274
qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
275+
# pad K before transposing: qweight shape is (size_n, size_k//4)
276+
if _pad_k > 0:
277+
qweight = F.pad(qweight, (0, _pad_k // 4))
252278
qweight = qweight.T.contiguous()
253-
279+
# pad N after transposing: qweight shape is (padded_size_k//4, size_n)
280+
if _pad_n > 0:
281+
qweight = F.pad(qweight, (0, _pad_n))
254282
marlin_qweight = ops.gptq_marlin_repack(
255-
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
283+
b_q_weight=qweight,
284+
perm=perm,
285+
size_k=_padded_size_k,
286+
size_n=_padded_size_n,
287+
num_bits=8,
256288
)
257289
tensor_list.append(marlin_qweight)
258290

@@ -302,9 +334,16 @@ def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
302334
# size_n may not divisible by block_size[0]
303335
scales = scales[..., :size_n].contiguous()
304336

337+
_padded_size_n, _padded_size_k, _pad_n, _ = _pad_to_marlin_tile(size_n, size_k)
305338
for i in range(e):
339+
_s = scales[i]
340+
if _pad_n > 0:
341+
_s = F.pad(_s, (0, _pad_n))
306342
marlin_scales = marlin_permute_scales(
307-
s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
343+
s=_s,
344+
size_k=_padded_size_k,
345+
size_n=_padded_size_n,
346+
group_size=group_size,
308347
)
309348
tensor_list.append(marlin_scales)
310349

0 commit comments

Comments
 (0)