Conversation
There was a problem hiding this comment.
Code Review
This pull request optimizes the fused MoE kernel by converting N and K into constant expressions and introducing a block_k_diviable flag to skip masking when the K dimension is block-aligned. Review feedback suggests fixing a typo in the variable name, ensuring consistency in tl.load parameters by including the other=0.0 argument, and using B.size(2) for the divisibility check to align with the definition of K used in the kernel call.
| if block_k_diviable: | ||
| a = tl.load( | ||
| a_ptrs, | ||
| mask=token_mask[:, None], | ||
| other=0.0, | ||
| ) | ||
| b = tl.load(b_ptrs) | ||
| else: | ||
| a = tl.load( | ||
| a_ptrs, | ||
| mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), | ||
| other=0.0, | ||
| ) | ||
| b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) |
There was a problem hiding this comment.
The optimization for block_k_diviable (typo intended to match existing code) correctly avoids the K dimension mask when K is a multiple of BLOCK_SIZE_K. However, the load for b in the if block_k_diviable branch (line 513) omits the other=0.0 parameter, whereas the else branch (line 520) includes it. While functionally safe due to the modulo % N in offs_bn and the final output masking, it is better to maintain consistency or explicitly document why other is omitted here. More importantly, the typo diviable instead of divisible is propagated; while it matches fused_moe_kernel_gptq_awq, it would be better to fix it in both places if possible, though strictly speaking, I should only suggest changes to the modified lines.
| naive_block_assignment=(sorted_token_ids is None), | ||
| HAS_BIAS=HAS_BIAS, | ||
| BLOCK_SIZE_K=BLOCK_SIZE_K, | ||
| block_k_diviable=A.size(1) % BLOCK_SIZE_K == 0, |
There was a problem hiding this comment.
The calculation of block_k_diviable uses A.size(1), which is the input feature dimension K. While correct, it is slightly inconsistent with the kernel call where K is explicitly passed as B.size(2). It would be cleaner to use the same source for K to avoid any confusion, although they are expected to be equal.
| block_k_diviable=A.size(1) % BLOCK_SIZE_K == 0, | |
| block_k_diviable=B.size(2) % BLOCK_SIZE_K == 0, |
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.