-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
fused_moe_kernel opt #38679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
fused_moe_kernel opt #38679
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -325,8 +325,6 @@ def fused_moe_kernel( | |||||
| expert_ids_ptr, | ||||||
| num_tokens_post_padded_ptr, | ||||||
| # Matrix dimensions | ||||||
| N, | ||||||
| K, | ||||||
| EM, | ||||||
| num_valid_tokens, | ||||||
| # The stride variables represent how much to increase the ptr by when | ||||||
|
|
@@ -365,6 +363,9 @@ def fused_moe_kernel( | |||||
| use_int8_w8a16: tl.constexpr, | ||||||
| per_channel_quant: tl.constexpr, | ||||||
| HAS_BIAS: tl.constexpr, | ||||||
| block_k_diviable: tl.constexpr, | ||||||
| N: tl.constexpr, | ||||||
| K: tl.constexpr, | ||||||
| ): | ||||||
| """ | ||||||
| Implements the fused computation for a Mixture of Experts (MOE) using | ||||||
|
|
@@ -503,12 +504,20 @@ def fused_moe_kernel( | |||||
| for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): | ||||||
| # Load the next block of A and B, generate a mask by checking the | ||||||
| # K dimension. | ||||||
| a = tl.load( | ||||||
| a_ptrs, | ||||||
| mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), | ||||||
| other=0.0, | ||||||
| ) | ||||||
| b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) | ||||||
| if block_k_diviable: | ||||||
| a = tl.load( | ||||||
| a_ptrs, | ||||||
| mask=token_mask[:, None], | ||||||
| other=0.0, | ||||||
| ) | ||||||
| b = tl.load(b_ptrs) | ||||||
| else: | ||||||
| a = tl.load( | ||||||
| a_ptrs, | ||||||
| mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), | ||||||
| other=0.0, | ||||||
| ) | ||||||
| b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) | ||||||
| # We accumulate along the K dimension. | ||||||
| if use_int8_w8a16: | ||||||
| accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) | ||||||
|
|
@@ -802,8 +811,6 @@ def invoke_fused_moe_triton_kernel( | |||||
| sorted_token_ids, | ||||||
| expert_ids, | ||||||
| num_tokens_post_padded, | ||||||
| B.size(1), | ||||||
| B.size(2), | ||||||
| EM, | ||||||
| num_tokens, | ||||||
| A.stride(0), | ||||||
|
|
@@ -822,6 +829,8 @@ def invoke_fused_moe_triton_kernel( | |||||
| B_bias.stride(1) if B_bias is not None else 0, | ||||||
| 0 if block_shape is None else block_shape[0], | ||||||
| 0 if block_shape is None else block_shape[1], | ||||||
| N=B.size(1), | ||||||
| K=B.size(2), | ||||||
| MUL_ROUTED_WEIGHT=mul_routed_weight, | ||||||
| top_k=top_k, | ||||||
| compute_type=compute_type, | ||||||
|
|
@@ -832,6 +841,7 @@ def invoke_fused_moe_triton_kernel( | |||||
| naive_block_assignment=(sorted_token_ids is None), | ||||||
| HAS_BIAS=HAS_BIAS, | ||||||
| BLOCK_SIZE_K=BLOCK_SIZE_K, | ||||||
| block_k_diviable=A.size(1) % BLOCK_SIZE_K == 0, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| **config, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The optimization for `block_k_diviable` (typo intended to match existing code) correctly avoids the `K`-dimension mask when `K` is a multiple of `BLOCK_SIZE_K`. However, the load for `b` in the `if block_k_diviable` branch (line 513) omits the `other=0.0` parameter, whereas the `else` branch (line 520) includes it. While functionally safe due to the modulo `% N` in `offs_bn` and the final output masking, it is better to maintain consistency or explicitly document why `other` is omitted here. More importantly, the typo `diviable` instead of `divisible` is propagated; while it matches `fused_moe_kernel_gptq_awq`, it would be better to fix it in both places if possible, though strictly speaking, I should only suggest changes to the modified lines.