Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,6 @@ def fused_moe_kernel(
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
Expand Down Expand Up @@ -365,6 +363,9 @@ def fused_moe_kernel(
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
HAS_BIAS: tl.constexpr,
block_k_diviable: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
Expand Down Expand Up @@ -503,12 +504,20 @@ def fused_moe_kernel(
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
if block_k_diviable:
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs)
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
Comment on lines +507 to +520
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The optimization for block_k_diviable (typo intended to match existing code) correctly avoids the K dimension mask when K is a multiple of BLOCK_SIZE_K. However, the load for b in the if block_k_diviable branch (line 513) omits the other=0.0 parameter, whereas the else branch (line 520) includes it. While functionally safe due to the modulo % N in offs_bn and the final output masking, it is better to maintain consistency or explicitly document why other is omitted here. More importantly, the typo diviable instead of divisible is propagated; while it matches fused_moe_kernel_gptq_awq, it would be better to fix it in both places if possible, though strictly speaking, I should only suggest changes to the modified lines.

# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
Expand Down Expand Up @@ -802,8 +811,6 @@ def invoke_fused_moe_triton_kernel(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
B.size(2),
EM,
num_tokens,
A.stride(0),
Expand All @@ -822,6 +829,8 @@ def invoke_fused_moe_triton_kernel(
B_bias.stride(1) if B_bias is not None else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
N=B.size(1),
K=B.size(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
Expand All @@ -832,6 +841,7 @@ def invoke_fused_moe_triton_kernel(
naive_block_assignment=(sorted_token_ids is None),
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
block_k_diviable=A.size(1) % BLOCK_SIZE_K == 0,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The calculation of block_k_diviable uses A.size(1), which is the input feature dimension K. While correct, it is slightly inconsistent with the kernel call where K is explicitly passed as B.size(2). It would be cleaner to use the same source for K to avoid any confusion, although they are expected to be equal.

Suggested change
block_k_diviable=A.size(1) % BLOCK_SIZE_K == 0,
block_k_diviable=B.size(2) % BLOCK_SIZE_K == 0,

**config,
)

Expand Down
Loading