Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
marlin_moe_intermediate_size,
marlin_quant_input,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
MARLIN_TILE_K,
MARLIN_TILE_N,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Static128BlockSym,
Expand Down Expand Up @@ -88,12 +92,16 @@ def _fused_marlin_moe(
M, K = hidden_states.size()
N = marlin_moe_intermediate_size(w1, w2)
w13_num_shards = 2 if activation.is_gated else 1
_w13_n = w13_num_shards * N
# Compute the same tile-aligned padded sizes used at weight-load time.
_w13_n_padded = _w13_n + ((-_w13_n) % MARLIN_TILE_N) # for w13 GEMM size_n
_N_padded = N + ((-N) % MARLIN_TILE_K) # for w2 GEMM size_k
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)

if intermediate_cache13 is None:
intermediate_cache13 = torch.empty(
(M * num_topk * max(w13_num_shards * N, K),),
(M * num_topk * max(_w13_n_padded, K),),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
Expand All @@ -106,7 +114,7 @@ def _fused_marlin_moe(
)

intermediate_cache1 = _resize_cache(
intermediate_cache13, (M * num_topk, w13_num_shards * N)
intermediate_cache13, (M * num_topk, _w13_n_padded)
)

intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
Expand Down Expand Up @@ -143,17 +151,21 @@ def _fused_marlin_moe(
mul_topk_weights=apply_router_weight_on_input,
b_q_type=quant_type,
size_m=M,
size_n=w13_num_shards * N,
size_n=_w13_n_padded, # padded to Marlin tile_n boundary
size_k=K,
is_k_full=is_k_full,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
# Trim w13 padding before activation (GEMM produced _w13_n_padded cols,
# activation expects true _w13_n cols).
if _w13_n_padded != _w13_n:
intermediate_cache1 = intermediate_cache1[:, :_w13_n].contiguous()
activation_func(
activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
intermediate_cache1.view(-1, _w13_n),
)

if output is None:
Expand All @@ -174,6 +186,13 @@ def _fused_marlin_moe(
intermediate_cache2, input_dtype
)

# Pad activation output to _N_padded so w2 GEMM size_k is tile-aligned.
# Extra columns are zero; the matching zero-padding in w2's repacked weights
# ensures they contribute nothing to the output.
if _N_padded != N:
intermediate_cache2 = torch.nn.functional.pad(
intermediate_cache2, (0, _N_padded - N)
)
output = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
output,
Expand All @@ -196,7 +215,7 @@ def _fused_marlin_moe(
b_q_type=quant_type,
size_m=M * num_topk,
size_n=K,
size_k=N,
size_k=_N_padded, # padded to Marlin tile_k boundary
is_k_full=is_k_full,
use_atomic_add=False,
use_fp32_reduce=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


import torch
import torch.nn.functional as F

import vllm._custom_ops as ops
from vllm.logger import init_logger
Expand All @@ -20,6 +21,21 @@

logger = init_logger(__name__)

# Marlin kernel tile alignment requirements.
MARLIN_TILE_N = 64 # size_n must be divisible by this
MARLIN_TILE_K = 16 # size_k must be divisible by this


def _pad_to_marlin_tile(size_n: int, size_k: int) -> tuple[int, int, int, int]:
"""Return (padded_size_n, padded_size_k, pad_n, pad_k).

Computes the smallest tile-aligned sizes >= size_n and size_k.
pad_n / pad_k are zero when the dimension is already aligned.
"""
pad_n = (-size_n) % MARLIN_TILE_N
pad_k = (-size_k) % MARLIN_TILE_K
return size_n + pad_n, size_k + pad_k, pad_n, pad_k


def is_fp8_marlin_supported():
    """Return whether the current platform reports device capability >= 75
    (i.e. SM 7.5), the minimum required for the FP8 Marlin kernels.
    """
    return current_platform.has_device_capability(75)
Expand Down Expand Up @@ -247,12 +263,28 @@ def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:

assert weight.shape == (e, size_n, size_k)

# Pad size_n and size_k to Marlin tile boundaries so gptq_marlin_repack
# does not crash when TP sharding produces non-aligned per-rank dimensions:
# tile_n_size = 64 (affects w13 gate+up, e.g. 464 → 512)
# tile_k_size = 16 (affects w2 down-proj, e.g. 232 → 240)
_padded_size_n, _padded_size_k, _pad_n, _pad_k = _pad_to_marlin_tile(
size_n, size_k
)
for i in range(e):
qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
# pad K before transposing: qweight shape is (size_n, size_k//4)
if _pad_k > 0:
qweight = F.pad(qweight, (0, _pad_k // 4))
qweight = qweight.T.contiguous()

# pad N after transposing: qweight shape is (padded_size_k//4, size_n)
if _pad_n > 0:
qweight = F.pad(qweight, (0, _pad_n))
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
b_q_weight=qweight,
perm=perm,
size_k=_padded_size_k,
size_n=_padded_size_n,
num_bits=8,
)
tensor_list.append(marlin_qweight)

Expand Down Expand Up @@ -302,9 +334,16 @@ def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
# size_n may not divisible by block_size[0]
scales = scales[..., :size_n].contiguous()

_padded_size_n, _padded_size_k, _pad_n, _ = _pad_to_marlin_tile(size_n, size_k)
for i in range(e):
_s = scales[i]
if _pad_n > 0:
_s = F.pad(_s, (0, _pad_n))
marlin_scales = marlin_permute_scales(
s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
s=_s,
size_k=_padded_size_k,
size_n=_padded_size_n,
group_size=group_size,
)
tensor_list.append(marlin_scales)

Expand Down
Loading