
Conversation


@RunkaiTao RunkaiTao commented Dec 15, 2025

Purpose

  1. Refactor the K-dim iterations in both the regular fused MoE and the fused MoE LoRA kernels to use mm_k, previously used by the LoRA shrink and expand kernels.
  2. Enable GDC in the regular Triton MoE kernel.
  3. Add quantization support to mm_k.
  4. Enable EVEN_K in fused MoE LoRA (see the kernel-side sketch after this list).
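As background for items 2 and 4, here is a minimal toy sketch (not the PR's code) of the two kernel-side patterns involved: `EVEN_K` is decided on the host and passed as a `tl.constexpr`, so the masked tail-load branch compiles away when `K` divides evenly into the block size, and `tl.extra.cuda.gdc_wait()` is the GDC primitive a dependent kernel issues before consuming a producer kernel's output (host-side PDL launch flags are omitted). All names below are illustrative.

```python
import triton
import triton.language as tl


@triton.jit
def toy_sum_kernel(
    x_ptr,
    out_ptr,
    K,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    USE_GDC: tl.constexpr,
):
    # Toy kernel: accumulates strided BLOCK_K slices of x into out.
    offs = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_K,), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        idx = k * BLOCK_K + offs
        if USE_GDC:
            # Wait for the producer kernel's writes to become visible before
            # loading (only meaningful when launched with PDL enabled).
            tl.extra.cuda.gdc_wait()
        if EVEN_K:
            # Compile-time specialization: no boundary mask is emitted.
            x = tl.load(x_ptr + idx)
        else:
            x = tl.load(x_ptr + idx, mask=idx < K, other=0.0)
        acc += x
    tl.store(out_ptr + offs, acc)


def launch_toy(x, out, K, BLOCK_K=128, use_gdc=False):
    # Host side: EVEN_K is a per-shape constant, so Triton compiles a
    # specialized kernel variant without the masked loads when it is True.
    even_k = K % BLOCK_K == 0
    toy_sum_kernel[(1,)](x, out, K, BLOCK_K=BLOCK_K, EVEN_K=even_k, USE_GDC=use_gdc)
```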

Test Plan

  1. Pass the unit tests for the LoRA shrink and expand kernels, fused MoE LoRA, and the fused MoE kernel.
  2. Measure the speedup after enabling GDC (a timing sketch follows this list).
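A minimal sketch of the kind of kernel-time comparison reported below. It assumes `triton.testing.do_bench`; `run_v1` / `run_v2` are hypothetical zero-argument wrappers around the v1 / v2 kernel launches (the actual numbers in the next section come from `benchmarks/kernels/benchmark_moe.py`), and speedup is expressed as relative time reduction, (v1 - v2) / v1 * 100.

```python
from triton.testing import do_bench


def compare_kernel_times(run_v1, run_v2):
    """Time two zero-argument callables that launch the v1 / v2 kernels.

    run_v1 / run_v2 are hypothetical wrappers (e.g. closures over the fused
    MoE launch before / after this PR); do_bench returns a runtime in
    milliseconds.
    """
    t1_us = do_bench(run_v1) * 1e3  # ms -> us
    t2_us = do_bench(run_v2) * 1e3
    speedup_pct = (t1_us - t2_us) / t1_us * 100.0
    return t1_us, t2_us, speedup_pct
```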

Test Result

After retuning the regular fused MoE with GDC enabled, we obtain a small performance gain for the Maverick model:

```shell
python3 benchmarks/kernels/benchmark_moe.py --model meta-llama/Llama-4-Maverick-17B-128E-Instruct --tp-size 8
```
| Batch size | Kernel time v1 (us) | Kernel time v2 (us) | Speedup (%) |
|-----------:|--------------------:|--------------------:|------------:|
| 1 | 33.05 | 27.43 | 17.00 |
| 2 | 35.82 | 36.45 | -1.76 |
| 4 | 49.27 | 49.20 | 0.14 |
| 8 | 76.39 | 76.39 | 0.00 |
| 16 | 136.45 | 131.07 | 3.94 |
| 24 | 180.71 | 180.60 | 0.06 |
| 32 | 237.79 | 228.94 | 3.72 |
| 48 | 309.78 | 311.03 | -0.10 |
| 64 | 382.93 | 382.50 | 0.11 |
| 96 | 513.61 | 513.58 | 0.01 |
| 128 | 601.64 | 601.69 | -0.01 |
| 256 | 827.99 | 812.03 | 1.93 |
| 512 | 923.94 | 923.92 | 0.00 |
| 1024 | 959.62 | 959.69 | -0.01 |
| 1536 | 998.23 | 989.96 | 0.83 |
| 2048 | 1017.80 | 1007.44 | 1.02 |
| 3072 | 1063.08 | 1057.63 | 0.51 |
| 4096 | 1111.64 | 1108.50 | 0.28 |
---
Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors the MoE kernels to use a shared mm_k function, which enables Programmatic Dependent Launch (GDC) for regular Triton MoE and adds quantization support. The changes are well-structured and the core logic for enabling GDC seems correct. I've identified a significant code duplication in the new mm_k function that should be addressed to improve maintainability. Other than that, the changes look good.

Comment on lines 347 to 527
```python
def mm_k(
    a_ptr,
    b_ptr,
    ak_stride,
    bk_stride,
    token_mask,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    b_dtype: tl.constexpr,
    USE_GDC: tl.constexpr,
    base_k,
    a_scale_ptrs=None,
    b_scale_ptrs=None,
    stride_ask=0,
    stride_bsk=0,
    group_k=0,
    group_n=0,
    use_int8_w8a16: tl.constexpr = False,
    use_fp8_w8a8: tl.constexpr = False,
    use_int8_w8a8: tl.constexpr = False,
    compute_type: tl.constexpr = tl.float16,
    IS_PRIMARY: tl.constexpr = False,
):
    """
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
    B (k x n), iterate through the K dimension to compute the partial/complete
    matrix block product.
    If SPLIT_K == 1, the output m x n product is complete.
    If SPLIT_K > 1, the thread block computes partial outputs. The partial
    outputs are then atomically summed in the caller code.

    Args:
        a_ptr: Array of pointers, identifying rows of A
        b_ptr: Array of pointers, identifying columns of B
        ak_stride: K dimension stride of the A matrix
        bk_stride: K dimension stride of the B matrix
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without any
            masking.
        SPLIT_K: Parameter signifying parallelism in the K dimension.
        CAST_TYPE: if True, cast the values from the A matrix to the B
            matrix dtype.
        b_dtype: datatype of the B matrix
        USE_GDC: Whether to use PDL. True indicates use.
        base_k: Base offset along K dimension for current SPLIT_K group
    """
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Step size along K for each iteration
    STEP_K = BLOCK_K * SPLIT_K

    # Total number of iterations (compile-time constant)
    num_iters = tl.cdiv(K, STEP_K)

    for k in range(num_iters):
        # Current iteration's global K offset
        iter_k = k * STEP_K + base_k

        # Check if this iteration is completely valid (no masking needed)
        block_end = iter_k + BLOCK_K

        if EVEN_K:
            # K is divisible by BLOCK_K, no masking ever needed
            # pre-fetch lora weight
            tiled_b = tl.load(b_ptr)
            if USE_GDC and not IS_PRIMARY:
                tl.extra.cuda.gdc_wait()
            tiled_a = tl.load(a_ptr, mask=token_mask[:, None], other=0.0)
            if CAST_TYPE:
                tiled_a = tiled_a.to(b_dtype)
            if use_int8_w8a16:
                accumulator = tl.dot(tiled_a, tiled_b.to(compute_type), acc=accumulator)
            elif use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    offs_ks = iter_k // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                    accumulator += (
                        tl.dot(tiled_a, tiled_b) * a_scale[:, None] * b_scale[None, :]
                    )
                else:
                    if use_fp8_w8a8:
                        # acc used to enable fp8_fast_accum
                        accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
                    else:
                        accumulator += tl.dot(tiled_a, tiled_b)
            else:
                accumulator += tl.dot(tiled_a, tiled_b)
        else:
            # Check if we need element-wise masking
            if iter_k >= K:
                # Entire block out of range, skip
                pass
            elif block_end <= K:
                # Entire block in range, no masking needed (fast path)
                tiled_b = tl.load(b_ptr)
                if USE_GDC and not IS_PRIMARY:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(a_ptr, mask=token_mask[:, None], other=0.0)
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                if use_int8_w8a16:
                    accumulator = tl.dot(
                        tiled_a, tiled_b.to(compute_type), acc=accumulator
                    )
                elif use_fp8_w8a8 or use_int8_w8a8:
                    if group_k > 0 and group_n > 0:
                        offs_ks = iter_k // group_k
                        a_scale = tl.load(
                            a_scale_ptrs + offs_ks * stride_ask,
                            mask=token_mask,
                            other=0.0,
                        )
                        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                        accumulator += (
                            tl.dot(tiled_a, tiled_b)
                            * a_scale[:, None]
                            * b_scale[None, :]
                        )
                    else:
                        if use_fp8_w8a8:
                            # acc used to enable fp8_fast_accum
                            accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
                        else:
                            accumulator += tl.dot(tiled_a, tiled_b)
                else:
                    accumulator += tl.dot(tiled_a, tiled_b)
            else:
                # Partial block, need masking (only last iteration)
                k_offsets = tl.arange(0, BLOCK_K)
                mask = iter_k + k_offsets < K
                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
                if USE_GDC and not IS_PRIMARY:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(
                    a_ptr, mask=token_mask[:, None] & mask[None, :], other=0.0
                )
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                if use_int8_w8a16:
                    accumulator = tl.dot(
                        tiled_a, tiled_b.to(compute_type), acc=accumulator
                    )
                elif use_fp8_w8a8 or use_int8_w8a8:
                    if group_k > 0 and group_n > 0:
                        offs_ks = iter_k // group_k
                        a_scale = tl.load(
                            a_scale_ptrs + offs_ks * stride_ask,
                            mask=token_mask,
                            other=0.0,
                        )
                        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                        accumulator += (
                            tl.dot(tiled_a, tiled_b)
                            * a_scale[:, None]
                            * b_scale[None, :]
                        )
                    else:
                        if use_fp8_w8a8:
                            # acc used to enable fp8_fast_accum
                            accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
                        else:
                            accumulator += tl.dot(tiled_a, tiled_b)
                else:
                    accumulator += tl.dot(tiled_a, tiled_b)

        a_ptr += STEP_K * ak_stride
        b_ptr += STEP_K * bk_stride

    return accumulator
```

Severity: high

The matrix multiplication logic, including quantization handling, is duplicated three times within this function: once for the EVEN_K case, once for the block_end <= K case, and once for the partial block case. This makes the code difficult to maintain and prone to errors, as any change needs to be applied in all three places.

To improve this, I suggest refactoring the accumulation logic. One way is to restructure the loop to first determine the masks and load tiled_a and tiled_b, and then have a single, non-duplicated block for the accumulation logic.

An alternative is to extract the accumulation logic into a separate triton.jit helper function. This function would take the loaded tiled_a and tiled_b tensors and perform the dot product and accumulation based on the quantization settings.

This refactoring would significantly reduce code duplication and improve maintainability without impacting performance, as the branching on EVEN_K (a constexpr) would still be optimized at compile time.
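
As a hedged illustration of the second suggestion (not code from this PR), the quantization-aware accumulation could live in its own triton.jit helper that mm_k calls from each of its three load paths; the helper name `_accumulate` and its exact argument list are hypothetical:

```python
import triton
import triton.language as tl


@triton.jit
def _accumulate(
    accumulator,
    tiled_a,
    tiled_b,
    iter_k,
    token_mask,
    a_scale_ptrs,
    b_scale_ptrs,
    stride_ask,
    stride_bsk,
    group_k,
    group_n,
    use_int8_w8a16: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    compute_type: tl.constexpr,
):
    # Single copy of the dot-product / quantization logic, shared by the
    # EVEN_K, full-block, and partial-block paths of mm_k.
    if use_int8_w8a16:
        accumulator = tl.dot(tiled_a, tiled_b.to(compute_type), acc=accumulator)
    elif use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:
            offs_ks = iter_k // group_k
            a_scale = tl.load(
                a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
            )
            b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
            accumulator += (
                tl.dot(tiled_a, tiled_b) * a_scale[:, None] * b_scale[None, :]
            )
        else:
            if use_fp8_w8a8:
                # acc used to enable fp8_fast_accum
                accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
            else:
                accumulator += tl.dot(tiled_a, tiled_b)
    else:
        accumulator += tl.dot(tiled_a, tiled_b)
    return accumulator
```

Each branch of mm_k would then reduce to its two tl.load calls followed by `accumulator = _accumulate(accumulator, tiled_a, tiled_b, iter_k, ...)`.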

Signed-off-by: Runkai Tao <[email protected]>
Signed-off-by: Runkai Tao <[email protected]>
Signed-off-by: Runkai Tao <[email protected]>
@RunkaiTao RunkaiTao marked this pull request as ready for review December 15, 2025 20:59
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

```python
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
accumulator = mm_k(
```
Contributor

this is so much cleaner!
