Enable GDC for regular Triton MoE by calling mm_k from Lora
#30673
Conversation
Signed-off-by: Runkai Tao <[email protected]>
Code Review
This pull request refactors the MoE kernels to use a shared mm_k function, which enables Programmatic Dependent Launch (PDL, exposed through Triton's GDC primitives) for the regular Triton MoE kernel and adds quantization support. The changes are well-structured and the core logic for enabling GDC appears correct. I've identified significant code duplication in the new mm_k function that should be addressed to improve maintainability. Other than that, the changes look good.
def mm_k(
    a_ptr,
    b_ptr,
    ak_stride,
    bk_stride,
    token_mask,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    b_dtype: tl.constexpr,
    USE_GDC: tl.constexpr,
    base_k,
    a_scale_ptrs=None,
    b_scale_ptrs=None,
    stride_ask=0,
    stride_bsk=0,
    group_k=0,
    group_n=0,
    use_int8_w8a16: tl.constexpr = False,
    use_fp8_w8a8: tl.constexpr = False,
    use_int8_w8a8: tl.constexpr = False,
    compute_type: tl.constexpr = tl.float16,
    IS_PRIMARY: tl.constexpr = False,
):
    """
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
    B (k x n), iterate through the K dimension to compute the partial/complete
    matrix block product.
    If SPLIT_K == 1, the output m x n product is complete.
    If SPLIT_K > 1, the thread block computes partial outputs. The partial
    outputs are then atomically summed in the caller code.
    Args:
        a_ptr: Array of pointers, identifying rows of A
        b_ptr: Array of pointers, identifying columns of B
        ak_stride: K dimension stride of the A matrix
        bk_stride: K dimension stride of the B matrix
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without any
            masking.
        SPLIT_K: Parameter signifying parallelism in the K dimension.
        CAST_TYPE: if True, cast the values from the A matrix to the B
            matrix dtype.
        b_dtype: datatype of the B matrix
        USE_GDC: Whether to use PDL. True indicates use.
        base_k: Base offset along K dimension for current SPLIT_K group
    """
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Step size along K for each iteration
    STEP_K = BLOCK_K * SPLIT_K

    # Total number of iterations (compile-time constant)
    num_iters = tl.cdiv(K, STEP_K)

    for k in range(num_iters):
        # Current iteration's global K offset
        iter_k = k * STEP_K + base_k

        # Check if this iteration is completely valid (no masking needed)
        block_end = iter_k + BLOCK_K

        if EVEN_K:
            # K is divisible by BLOCK_K, no masking ever needed
            # pre-fetch lora weight
            tiled_b = tl.load(b_ptr)
            if USE_GDC and not IS_PRIMARY:
                tl.extra.cuda.gdc_wait()
            tiled_a = tl.load(a_ptr, mask=token_mask[:, None], other=0.0)
            if CAST_TYPE:
                tiled_a = tiled_a.to(b_dtype)
            if use_int8_w8a16:
                accumulator = tl.dot(tiled_a, tiled_b.to(compute_type), acc=accumulator)
            elif use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    offs_ks = iter_k // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                    accumulator += (
                        tl.dot(tiled_a, tiled_b) * a_scale[:, None] * b_scale[None, :]
                    )
                else:
                    if use_fp8_w8a8:
                        # acc used to enable fp8_fast_accum
                        accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
                    else:
                        accumulator += tl.dot(tiled_a, tiled_b)
            else:
                accumulator += tl.dot(tiled_a, tiled_b)
        else:
            # Check if we need element-wise masking
            if iter_k >= K:
                # Entire block out of range, skip
                pass
            elif block_end <= K:
                # Entire block in range, no masking needed (fast path)
                tiled_b = tl.load(b_ptr)
                if USE_GDC and not IS_PRIMARY:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(a_ptr, mask=token_mask[:, None], other=0.0)
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                if use_int8_w8a16:
                    accumulator = tl.dot(
                        tiled_a, tiled_b.to(compute_type), acc=accumulator
                    )
                elif use_fp8_w8a8 or use_int8_w8a8:
                    if group_k > 0 and group_n > 0:
                        offs_ks = iter_k // group_k
                        a_scale = tl.load(
                            a_scale_ptrs + offs_ks * stride_ask,
                            mask=token_mask,
                            other=0.0,
                        )
                        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                        accumulator += (
                            tl.dot(tiled_a, tiled_b)
                            * a_scale[:, None]
                            * b_scale[None, :]
                        )
                    else:
                        if use_fp8_w8a8:
                            # acc used to enable fp8_fast_accum
                            accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
                        else:
                            accumulator += tl.dot(tiled_a, tiled_b)
                else:
                    accumulator += tl.dot(tiled_a, tiled_b)
            else:
                # Partial block, need masking (only last iteration)
                k_offsets = tl.arange(0, BLOCK_K)
                mask = iter_k + k_offsets < K
                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
                if USE_GDC and not IS_PRIMARY:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(
                    a_ptr, mask=token_mask[:, None] & mask[None, :], other=0.0
                )
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                if use_int8_w8a16:
                    accumulator = tl.dot(
                        tiled_a, tiled_b.to(compute_type), acc=accumulator
                    )
                elif use_fp8_w8a8 or use_int8_w8a8:
                    if group_k > 0 and group_n > 0:
                        offs_ks = iter_k // group_k
                        a_scale = tl.load(
                            a_scale_ptrs + offs_ks * stride_ask,
                            mask=token_mask,
                            other=0.0,
                        )
                        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                        accumulator += (
                            tl.dot(tiled_a, tiled_b)
                            * a_scale[:, None]
                            * b_scale[None, :]
                        )
                    else:
                        if use_fp8_w8a8:
                            # acc used to enable fp8_fast_accum
                            accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
                        else:
                            accumulator += tl.dot(tiled_a, tiled_b)
                else:
                    accumulator += tl.dot(tiled_a, tiled_b)

        a_ptr += STEP_K * ak_stride
        b_ptr += STEP_K * bk_stride

    return accumulator
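For context on the USE_GDC / IS_PRIMARY path above: with programmatic dependent launch, the next kernel in a stream may begin before the previous one finishes, so the dependent kernel first does work that does not depend on the previous kernel (here, prefetching the B tile) and only calls gdc_wait() right before loading data the previous kernel produces. The following standalone sketch of that pairing is an illustration, not vLLM code; it assumes a Triton version that exposes tl.extra.cuda.gdc_wait() and tl.extra.cuda.gdc_launch_dependents(), and that both kernels are launched with PDL enabled on the host side.

import triton
import triton.language as tl


@triton.jit
def producer_kernel(out_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    if USE_GDC:
        # Signal that the dependent (secondary) grid may start launching;
        # work after this point can overlap with the secondary's prelude.
        tl.extra.cuda.gdc_launch_dependents()
    tl.store(out_ptr + offs, offs.to(tl.float32), mask=offs < n)


@triton.jit
def consumer_kernel(in_ptr, out_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # ... prefetch / setup work that does not read the producer's output ...
    if USE_GDC:
        # Block until the producer grid has completed and flushed its writes.
        tl.extra.cuda.gdc_wait()
    x = tl.load(in_ptr + offs, mask=offs < n, other=0.0)
    tl.store(out_ptr + offs, x * 2.0, mask=offs < n)

In mm_k, the IS_PRIMARY flag simply skips the wait for the kernel that has no upstream producer to wait on.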
The matrix multiplication logic, including quantization handling, is duplicated three times within this function: once for the EVEN_K case, once for the block_end <= K case, and once for the partial block case. This makes the code difficult to maintain and prone to errors, as any change needs to be applied in all three places.
To improve this, I suggest refactoring the accumulation logic. One way is to restructure the loop to first determine the masks and load tiled_a and tiled_b, and then have a single, non-duplicated block for the accumulation logic.
An alternative is to extract the accumulation logic into a separate triton.jit helper function. This function would take the loaded tiled_a and tiled_b tensors and perform the dot product and accumulation based on the quantization settings.
This refactoring would significantly reduce code duplication and improve maintainability without impacting performance, as the branching on EVEN_K (a constexpr) would still be optimized at compile time.
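As an illustration of the second suggestion, here is a minimal sketch of such a helper. The name _accumulate and its exact parameter list are hypothetical; it just lifts the dot/accumulate branch that mm_k currently repeats in its three paths into a single @triton.jit function.

import triton
import triton.language as tl


@triton.jit
def _accumulate(  # hypothetical helper, not part of the PR
    accumulator,
    tiled_a,
    tiled_b,
    a_scale_ptrs,
    b_scale_ptrs,
    stride_ask,
    stride_bsk,
    token_mask,
    iter_k,
    group_k,
    group_n,
    use_int8_w8a16: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    compute_type: tl.constexpr,
):
    # Single copy of the accumulation logic that mm_k repeats in its
    # EVEN_K, full-block, and partial-block branches.
    if use_int8_w8a16:
        accumulator = tl.dot(tiled_a, tiled_b.to(compute_type), acc=accumulator)
    elif use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:
            # Group-wise scales: one scale per group_k slice of K.
            offs_ks = iter_k // group_k
            a_scale = tl.load(
                a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
            )
            b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
            accumulator += (
                tl.dot(tiled_a, tiled_b) * a_scale[:, None] * b_scale[None, :]
            )
        else:
            if use_fp8_w8a8:
                # acc kwarg keeps fp8 fast accumulation
                accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
            else:
                accumulator += tl.dot(tiled_a, tiled_b)
    else:
        accumulator += tl.dot(tiled_a, tiled_b)
    return accumulator

Each branch of mm_k would then reduce to loading tiled_a and tiled_b with the appropriate masks and calling accumulator = _accumulate(...); since all the flags are tl.constexpr, dead branches are still eliminated at compile time.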
Signed-off-by: Runkai Tao <[email protected]>
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
accumulator = mm_k(
this is so much cleaner!
Purpose
- mm_k was previously called only in the lora shrink and expand kernels
- The regular fused MoE Triton kernel now calls mm_k, enabling GDC
- Support for EVEN_K in fused moe lora

Test Plan
Test Result
After retuning the regular fused MoE kernel with GDC enabled, we observe a small performance gain for the Maverick model.
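For readers who want to reproduce a kernel-level comparison like this, a minimal sketch of one way to time the op with and without GDC is shown below, assuming triton.testing.do_bench is available; run_moe is a hypothetical callable wrapping the fused MoE op under test, not part of this PR.

import triton


def bench(fn, warmup=25, rep=100):
    # do_bench runs fn repeatedly and returns a runtime estimate in milliseconds.
    return triton.testing.do_bench(fn, warmup=warmup, rep=rep)


# Hypothetical usage:
# ms_base = bench(lambda: run_moe(use_gdc=False))
# ms_gdc = bench(lambda: run_moe(use_gdc=True))
# print(f"speedup: {ms_base / ms_gdc:.3f}x")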