Skip to content
46 changes: 24 additions & 22 deletions vllm/lora/ops/triton_ops/fused_moe_lora_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe.utils import mm_k
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op

Expand Down Expand Up @@ -91,6 +92,7 @@ def _fused_moe_lora_kernel(
USE_GDC: tl.constexpr,
launch_pdl: tl.constexpr,
IS_PRIMARY: tl.constexpr,
EVEN_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
Expand All @@ -105,7 +107,6 @@ def _fused_moe_lora_kernel(
# Early exit for the no moe lora case.
return
max_loras = tl.num_programs(axis=2)
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

# calculate pid_m,pid_n
pid_sk = pid % SPLIT_K
Expand Down Expand Up @@ -156,25 +157,24 @@ def _fused_moe_lora_kernel(
+ offs_bn[None, :] * stride_bn
)

# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0,
)
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
accumulator = mm_k(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is so much cleaner!

a_ptrs,
b_ptrs,
stride_ak,
stride_bk,
token_mask,
K,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
EVEN_K,
SPLIT_K=SPLIT_K,
CAST_TYPE=None,
b_dtype=b_ptr.dtype.element_ty,
USE_GDC=USE_GDC,
IS_PRIMARY=IS_PRIMARY,
base_k=pid_sk * BLOCK_SIZE_K,
)

if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
Expand Down Expand Up @@ -242,7 +242,7 @@ def _fused_moe_lora_shrink(
}

b_ptr = _get_ptr(lora_a_stacked, device)

EVEN_K = K % (block_size_k * split_k) == 0
grid = lambda META: (
split_k
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
Expand Down Expand Up @@ -282,6 +282,7 @@ def _fused_moe_lora_shrink(
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
IS_PRIMARY=True,
EVEN_K=EVEN_K,
**shrink_config,
)

Expand Down Expand Up @@ -348,7 +349,7 @@ def _fused_moe_lora_expand(
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
}

EVEN_K = K % (block_size_k * split_k) == 0
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
Expand Down Expand Up @@ -386,6 +387,7 @@ def _fused_moe_lora_expand(
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
IS_PRIMARY=False,
EVEN_K=EVEN_K,
**expand_config,
)
for i in range(num_slices):
Expand Down
105 changes: 3 additions & 102 deletions vllm/lora/ops/triton_ops/kernel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,109 +4,10 @@
Utilities for Punica kernel construction.
"""

from vllm.model_executor.layers.fused_moe.utils import mm_k
from vllm.triton_utils import tl, triton


@triton.jit
def mm_k(
a_ptr,
b_ptr,
ak_stride,
bk_stride,
offset_k,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
CAST_TYPE: tl.constexpr,
b_dtype: tl.constexpr,
USE_GDC: tl.constexpr,
base_k,
):
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
B (k x n), iterate, through the K dimension to compute the partial/complete
matrix block product.
If SPLIT_K == 1, the output m x n product is complete.
If SPLIT_K > 1, the thread block computes partial outputs. The partial
outputs are then atomically summed in the caller code.
Args:
a_ptr: Array of pointers, identifying rows of A
b_ptr: Array of pointers, identifying columns of B
ak_stride: K dimension stride of the A matrix
bk_stride: K dimension stride of the B matrix
K: Length of the K dimension
BLOCK_M: M dimension of the output block m x n
BLOCK_N: N dimension of the output block m x n
BLOCK_K: K dimension atom
EVEN_K: True if the blocks of A and B can be loaded without any
masking.
SPLIT_K: Parameter signifying parallelism in the K dimension.
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
b_dtype: datatype of the B matrix
USE_GDC: Whether to use PDL. True indicates use.
base_k: Base offset along K dimension for current SPLIT_K group
"""
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

# Step size along K for each iteration
STEP_K = BLOCK_K * SPLIT_K

# Total number of iterations (compile-time constant)
num_iters = tl.cdiv(K, STEP_K)

for k in range(num_iters):
# Current iteration's global K offset
iter_k = k * STEP_K + base_k

# Check if this iteration is completely valid (no masking needed)
block_end = iter_k + BLOCK_K

if EVEN_K:
# K is divisible by BLOCK_K, no masking ever needed
# pre-fetch lora weight
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Check if we need element-wise masking
if iter_k >= K:
# Entire block out of range, skip
pass
elif block_end <= K:
# Entire block in range, no masking needed (fast path)
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Partial block, need masking (only last iteration)
k_offsets = tl.arange(0, BLOCK_K)
mask = iter_k + k_offsets < K
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)

a_ptr += STEP_K * ak_stride
b_ptr += STEP_K * bk_stride

return accumulator


@triton.jit
def do_expand_kernel(
pid_n,
Expand Down Expand Up @@ -199,7 +100,7 @@ def do_expand_kernel(
b_ptr,
input_d2_stride,
cur_lora_d2_stride,
offset_k,
tl.full((BLOCK_M,), 1, tl.int1),
K,
BLOCK_M,
BLOCK_N,
Expand Down Expand Up @@ -306,7 +207,7 @@ def do_shrink_kernel(
b_ptr,
input_d1_stride,
lora_d2_stride,
offset_k,
tl.full((BLOCK_M,), 1, tl.int1),
K,
BLOCK_M,
BLOCK_N,
Expand Down
Loading