diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d5b8feb3c9b9..0005a84ecde9 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -325,8 +325,6 @@ def fused_moe_kernel(
     expert_ids_ptr,
     num_tokens_post_padded_ptr,
     # Matrix dimensions
-    N,
-    K,
     EM,
     num_valid_tokens,
     # The stride variables represent how much to increase the ptr by when
@@ -365,6 +363,9 @@ def fused_moe_kernel(
     use_int8_w8a16: tl.constexpr,
     per_channel_quant: tl.constexpr,
     HAS_BIAS: tl.constexpr,
+    block_k_divisible: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -503,12 +504,20 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if block_k_divisible:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
@@ -802,8 +811,6 @@ def invoke_fused_moe_triton_kernel(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
-        B.size(1),
-        B.size(2),
         EM,
         num_tokens,
         A.stride(0),
@@ -822,6 +829,8 @@ def invoke_fused_moe_triton_kernel(
         B_bias.stride(1) if B_bias is not None else 0,
         0 if block_shape is None else block_shape[0],
         0 if block_shape is None else block_shape[1],
+        N=B.size(1),
+        K=B.size(2),
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,
@@ -832,6 +841,7 @@ def invoke_fused_moe_triton_kernel(
         naive_block_assignment=(sorted_token_ids is None),
         HAS_BIAS=HAS_BIAS,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
+        block_k_divisible=A.size(1) % BLOCK_SIZE_K == 0,
         **config,
     )