Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _determine_shared_experts_order(
return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED

should_run_shared_in_aux_stream = (
current_platform.is_cuda()
current_platform.is_cuda_alike()
and not self._use_dp_chunking
and self._stream is not None
and hidden_states.shape[0]
Expand Down
21 changes: 11 additions & 10 deletions vllm/model_executor/models/deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will break the main checkpoint which has different dtypes for these

# Always included; the fallback check in the loading loop handles
# checkpoints that already have fused wk_weights_proj tensors.
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)

expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
self,
Expand Down Expand Up @@ -297,10 +298,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
name_mapped = name.replace(weight_name, param_name)

# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# QKV fusion and indexer fusion are optional — fall back to
# direct weight loading when the mapped name doesn't exist.
if (
param_name == "fused_qkv_a_proj"
param_name in ("fused_qkv_a_proj", "wk_weights_proj")
) and name_mapped not in params_dict:
continue
else:
Expand Down
73 changes: 29 additions & 44 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,36 +644,20 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
)
if self.is_fp4_ckpt:
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj does not get quantized,
# so we run both with quant_config=None
# wk may be upcasted from the default quant;
# experiments show fusion is always faster unless WK proj is in FP4,
# which is not the case for all known quants.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
else:
self.wk = ReplicatedLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# FP4 checkpoints don't quantize weights_proj, so use quant_config=None.
# Other quantized checkpoints (e.g. GLM-5-FP8) may quantize the fused
# tensor, so pass quant_config to create weight_scale_inv parameters.
# Checkpoints with separate wk/weights_proj tensors are handled by the
# stacked_params_mapping in load_weights.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
[self.head_dim, self.n_head],
bias=False,
quant_config=None if self.is_fp4_ckpt else quant_config,
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5

Expand Down Expand Up @@ -714,14 +698,9 @@ def forward(
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
if self.is_fp4_ckpt:
# Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights = kw[:, self.head_dim :]
else:
k, _ = self.wk(hidden_states)
weights, _ = self.weights_proj(hidden_states)
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights = kw[:, self.head_dim :]

k = self.k_norm(k)
k_pe, k_nope = torch.split(
Expand Down Expand Up @@ -1469,8 +1448,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
if self.is_v32:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj).
# For checkpoints with separate wk/weights_proj tensors, this mapping
# loads them into the fused MergedColumnParallelLinear shards.
# For checkpoints that already have fused wk_weights_proj (e.g.
# GLM-5-FP8), the substring match is a false positive and the
# fallback check in the loading loop skips it gracefully.
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
Expand Down Expand Up @@ -1528,11 +1512,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
name_mapped = name.replace(weight_name, param_name)

# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# if go with fusion option, then update name
# QKV fusion and indexer fusion are optional — fall back to
# direct weight loading when the mapped name doesn't exist
# (e.g. fused checkpoints where "wk" falsely matches
# "wk_weights_proj", or when QKV fusion is disabled).
if (
param_name == "fused_qkv_a_proj"
param_name in ("fused_qkv_a_proj", "wk_weights_proj")
) and name_mapped not in params_dict:
continue
else:
Expand Down
26 changes: 23 additions & 3 deletions vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
import torch

from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton

logger = init_logger(__name__)

_AITER_MQA_SMALL_HEADS_WARNED = False

if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops

Expand Down Expand Up @@ -322,17 +327,31 @@ def rocm_fp8_paged_mqa_logits(
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
global _AITER_MQA_SMALL_HEADS_WARNED
from vllm._aiter_ops import rocm_aiter_ops

batch_size, next_n, heads, _ = q_fp8.shape

# AITER's deepgemm_fp8_paged_mqa_logits_stage1 computes TileQCount
# from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 → 8 heads)
# TileQCount becomes 0, causing ZeroDivisionError.
# Tracked: https://github.com/ROCm/aiter/issues/2563
aiter_paged_mqa_logits_module = None
if rocm_aiter_ops.is_enabled():
if rocm_aiter_ops.is_enabled() and heads >= 16:
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
elif rocm_aiter_ops.is_enabled() and not _AITER_MQA_SMALL_HEADS_WARNED:
logger.warning(
"AITER paged MQA logits kernel does not support %d heads "
"(requires >= 16). Falling back to PyTorch reference. "
"See https://github.com/ROCm/aiter/issues/2563",
heads,
)
_AITER_MQA_SMALL_HEADS_WARNED = True

if aiter_paged_mqa_logits_module is not None:
deepgemm_fp8_paged_mqa_logits_stage1 = (
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
)
batch_size, next_n, heads, _ = q_fp8.shape
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
float("-inf"),
Expand Down Expand Up @@ -449,8 +468,9 @@ def rocm_fp8_mqa_logits(
# path after aiter merge this kernel into main
from vllm._aiter_ops import rocm_aiter_ops

heads = q.shape[1]
aiter_mqa_logits_module = None
if rocm_aiter_ops.is_enabled():
if rocm_aiter_ops.is_enabled() and heads >= 16:
aiter_mqa_logits_module = mqa_logits_module()

if aiter_mqa_logits_module is not None:
Expand Down
Loading