Skip to content

Commit 45dfceb

Browse files
committed
fix routed
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Made-with: Cursor
1 parent 47aae4c commit 45dfceb

2 files changed

Lines changed: 14 additions & 1 deletion

File tree

csrc/fused_moe/trtllm_backend/routingRenormalize/launchBlockKernel.cu

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,18 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa
103103
}
104104
}
105105
} // end if (validToken)
106+
} else if (params.mPtrTopKPacked != nullptr) {
107+
if (validToken) {
108+
if (laneIdx < params.mTopK) {
109+
int offset = warpIdx * MaxNumExperts +
110+
static_cast<int>(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx);
111+
smemKIdx[offset] = static_cast<int8_t>(laneIdx);
112+
if (params.mPtrTopKWeights != nullptr) {
113+
params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] =
114+
static_cast<OutputT>(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].score);
115+
}
116+
}
117+
}
106118
}
107119
__syncthreads();
108120

csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_renormalize.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,8 @@ void run(Data const& data, void* stream) {
5353
FLASHINFER_CHECK(data.mNumExperts % 4 == 0,
5454
"Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
5555

56-
bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens;
56+
bool const useSingleBlock =
57+
data.mNumTokens <= BlockKernelMaxNumTokens && data.mPtrTopKPacked == nullptr;
5758

5859
bool const useSingleCluster =
5960
data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr)

0 commit comments

Comments
 (0)