Skip to content

Commit 2255bca

Browse files
committed
init
1 parent 2bb3e9e commit 2255bca

7 files changed

Lines changed: 257 additions & 81 deletions

File tree

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 45 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -256,33 +256,36 @@ class FusedMoeLauncher {
256256
Tensor num_non_exiting_ctas;
257257

258258
void prepare_routing_common() {
259+
int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
260+
int32_t const totalNumExperts = args->num_experts + args->num_fused_shared_experts;
261+
259262
// Allocate routing phase workspace tensors
260-
num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device());
263+
num_tokens_per_expert = alloc_tensor({totalNumExperts}, dl_int32, hidden_states.device());
261264
int32_t max_num_padded_tokens =
262265
tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
263-
args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
266+
args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);
264267

265268
total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device());
266269

267270
expanded_idx_to_permuted_idx =
268-
alloc_tensor({args->num_tokens * args->top_k}, dl_int32, hidden_states.device());
271+
alloc_tensor({args->num_tokens * totalExpertsPerToken}, dl_int32, hidden_states.device());
269272

270273
permuted_idx_to_token_idx =
271274
alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device());
272275

273276
expert_indexes =
274-
alloc_tensor({args->num_tokens, args->top_k}, dl_int32, hidden_states.device());
277+
alloc_tensor({args->num_tokens, totalExpertsPerToken}, dl_int32, hidden_states.device());
275278

276279
// expert_weights allocation should be done by derived class since data type could vary
277280

278-
int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2);
281+
int64_t const size_of_expert_count_histogram = std::max(totalNumExperts * 2, 256 * 2);
279282
expert_count_histogram = alloc_tensor({size_of_expert_count_histogram},
280283
dl_int32, // 256 is the max number of threads per block
281284
// and max number of experts
282285
hidden_states.device());
283286

284287
int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
285-
args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
288+
args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);
286289

287290
cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device());
288291

@@ -334,14 +337,17 @@ class FusedMoeLauncher {
334337
this->activation_type, this->use_shuffled_weight, this->weight_layout);
335338
}
336339

340+
int32_t const effectiveTopK = args->top_k + args->num_fused_shared_experts;
341+
int32_t const effectiveLocalExperts = args->local_num_experts + args->num_fused_shared_experts;
342+
337343
if (moe_tactic == -1) {
338-
moe_tactic = moe_runner->getDefaultValidConfigIndex(
339-
args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts,
340-
args->num_tokens);
344+
moe_tactic = moe_runner->getDefaultValidConfigIndex(effectiveTopK, args->hidden_size,
345+
args->intermediate_size,
346+
effectiveLocalExperts, args->num_tokens);
341347
}
342348
auto valid_cfgs =
343-
moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size,
344-
args->local_num_experts, args->num_tokens);
349+
moe_runner->getValidConfigIndices(effectiveTopK, args->hidden_size, args->intermediate_size,
350+
effectiveLocalExperts, args->num_tokens);
345351
auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic);
346352
FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic,
347353
" for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ",
@@ -377,8 +383,8 @@ class FusedMoeLauncher {
377383

378384
routing_runner.run(
379385
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
380-
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
381-
args->routed_scaling_factor, workspace.routing_expert_indexes,
386+
args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
387+
args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes,
382388
static_cast<int*>(expert_count_histogram.data_ptr()),
383389
static_cast<int*>(total_num_padded_tokens.data_ptr()),
384390
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -910,12 +916,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
910916
auto const routing_bias_dtype =
911917
routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
912918
mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
919+
int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
913920
// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
914921
bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0;
915922
if (!has_precomputed_weights) {
916923
// Allocate expert_weights buffer for routing output
917-
FusedMoeLauncher::expert_weights =
918-
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
924+
FusedMoeLauncher::expert_weights = alloc_tensor({args->num_tokens, totalExpertsPerToken},
925+
dl_bfloat16, hidden_states.device());
919926
workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr();
920927
} else {
921928
workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
@@ -946,12 +953,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
946953
TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
947954
TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";
948955

956+
int64_t const totalLocalExperts = args->local_num_experts + args->num_fused_shared_experts;
949957
if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
950958
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
951959
<< "gemm1_weights_scale must be float.";
952960
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
953-
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
954-
<< "gemm1_weights_scale has incorrect shape.";
961+
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), totalLocalExperts)
962+
<< "gemm1_weights_scale has incorrect dim 0.";
955963
TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
956964
<< "intermediate_size must be a multiple of 128.";
957965
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1),
@@ -971,8 +979,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
971979
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
972980
<< "gemm2_weights_scale must be float.";
973981
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
974-
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
975-
<< "gemm2_weights_scale has incorrect shape.";
982+
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), totalLocalExperts)
983+
<< "gemm2_weights_scale has incorrect dim 0.";
976984
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
977985
<< "gemm2_weights_scale has incorrect shape.";
978986
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
@@ -1082,8 +1090,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
10821090
// routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes
10831091
routing_runner.run(
10841092
use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens,
1085-
args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset,
1086-
args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes,
1093+
args->num_experts, args->top_k, args->num_fused_shared_experts, args->n_group,
1094+
args->topk_group, args->local_expert_offset, args->local_num_experts,
1095+
args->routed_scaling_factor, workspace.routing_expert_indexes,
10871096
static_cast<int*>(expert_count_histogram.data_ptr()),
10881097
static_cast<int*>(total_num_padded_tokens.data_ptr()),
10891098
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -1545,8 +1554,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
15451554

15461555
routing_runner.run(
15471556
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
1548-
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
1549-
args->routed_scaling_factor, static_cast<int*>(expert_indices.data_ptr()),
1557+
args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
1558+
args->local_num_experts, args->routed_scaling_factor,
1559+
static_cast<int*>(expert_indices.data_ptr()),
15501560
static_cast<int*>(expert_count_histogram.data_ptr()),
15511561
static_cast<int*>(total_num_padded_tokens.data_ptr()),
15521562
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -1779,10 +1789,11 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
17791789
Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
17801790
TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
17811791
TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
1782-
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
1783-
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
1784-
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
1785-
bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
1792+
Optional<int64_t> num_fused_shared_experts, Optional<int64_t> n_group,
1793+
Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
1794+
int64_t local_num_experts, Optional<double> routed_scaling_factor, int64_t routing_method_type,
1795+
bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl,
1796+
Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
17861797
// Basic type validation
17871798
auto dtype = hidden_states.dtype();
17881799

@@ -1843,9 +1854,13 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
18431854
auto const num_tokens = hidden_states.size(0);
18441855
auto const hidden_size = hidden_states.size(1);
18451856

1857+
int64_t const nFusedShared = num_fused_shared_experts.value_or(0);
1858+
int64_t const totalExpertsPerToken = top_k + nFusedShared;
1859+
int64_t const totalLocalExperts = local_num_experts + nFusedShared;
1860+
18461861
auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type);
1847-
std::set<int32_t> selected_tile_nums =
1848-
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts);
1862+
std::set<int32_t> selected_tile_nums = computeSelectedTileN(
1863+
supported_tile_nums, num_tokens, totalExpertsPerToken, totalLocalExperts);
18491864

18501865
// Create a map of launchers for each tile size
18511866
std::unordered_map<int32_t, std::unique_ptr<Fp8BlockScaleLauncher>> launchers_map;
@@ -1855,6 +1870,7 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
18551870
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
18561871
args->num_tokens = num_tokens;
18571872
args->num_experts = num_experts;
1873+
args->num_fused_shared_experts = nFusedShared;
18581874
args->hidden_size = hidden_size;
18591875
args->hidden_size_output = args->hidden_size;
18601876
args->top_k = top_k;

csrc/trtllm_fused_moe_routing_deepseek.cu

Lines changed: 35 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -250,17 +250,28 @@ __global__ void routingMainKernel(KernelParams params) {
250250
auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm};
251251

252252
// write expert idx out already
253-
auto idxTopK = blockIdx.x * params.mTopK + laneIdx;
253+
auto idxTopK = blockIdx.x * params.mTotalExpertsPerToken + laneIdx;
254+
auto idxShared = blockIdx.x * params.mTotalExpertsPerToken + params.mTopK + laneIdx;
254255
if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) {
255256
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore),
256257
static_cast<int16_t>(expertIdx)};
257258
params.mPtrTopKPacked[idxTopK] = packedScore;
258259
}
259260

261+
if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKPacked != nullptr) {
262+
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(1.0F),
263+
static_cast<int16_t>(params.mNumExperts + laneIdx)};
264+
params.mPtrTopKPacked[idxShared] = packedScore;
265+
}
266+
260267
if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr &&
261268
params.mPtrTopKIds == nullptr) {
262269
params.mPtrTopKWeights[idxTopK] = finalScore;
263270
}
271+
272+
if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKWeights != nullptr) {
273+
params.mPtrTopKWeights[idxShared] = static_cast<OutputT>(1.0F);
274+
}
264275
}
265276
}
266277
}
@@ -561,6 +572,11 @@ void runImpl(Data& data, void* stream) {
561572
FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups,
562573
"Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups,
563574
data.mNumLimitedGroups);
575+
576+
int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts;
577+
int const topK = data.mTopK + data.mNumFusedSharedExperts;
578+
int const numThreadsHist = getMaxNumExperts(numExperts);
579+
564580
// Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK
565581
if (data.mNumExperts <= NumKimiK2Experts) {
566582
FLASHINFER_CHECK(
@@ -573,6 +589,9 @@ void runImpl(Data& data, void* stream) {
573589
"When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d",
574590
MaxSupportedTopExperts, data.mTopK);
575591
}
592+
FLASHINFER_CHECK(topK <= MaxSupportedTopExperts,
593+
"Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts,
594+
topK);
576595
FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d",
577596
data.mTopK);
578597
FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize,
@@ -598,14 +617,19 @@ void runImpl(Data& data, void* stream) {
598617
data.mNumExperts / data.mNumExpertGroups <= WarpSize,
599618
"Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d",
600619
data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups);
620+
621+
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
622+
"Number of fused shared experts (%d) must be less than warp size.",
623+
data.mNumFusedSharedExperts);
601624
}
602625
FLASHINFER_CHECK(data.mNumExperts % 4 == 0,
603626
"Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
604627

605628
int const numBlocks = data.mNumTokens;
606-
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
607629

608-
bool const useSingleCluster = data.mNumTokens <= 1024;
630+
int numThreadsPerCluster = numThreadsHist * NumBlocksPerCluster;
631+
bool const useSingleCluster =
632+
data.mNumTokens <= 1024 && data.mNumTokens * topK <= numThreadsPerCluster;
609633
if (!useSingleCluster) {
610634
// Reset the global histograms (not used in single-cluster code path).
611635
// Cover both for the cooperative and two-kernel code paths.
@@ -629,7 +653,7 @@ void runImpl(Data& data, void* stream) {
629653
int const numBlocksCoop = 128;
630654

631655
// Maximum number of tokens supported by the kernel using a cooperative launch.
632-
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
656+
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / topK;
633657
if (data.mPtrTopKIds == nullptr) {
634658
int const numThreadsMain =
635659
max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts));
@@ -645,6 +669,12 @@ void runImpl(Data& data, void* stream) {
645669
stream, data.mNumExpertGroups > 1);
646670
}
647671

672+
if (data.mNumFusedSharedExperts > 0) {
673+
data.mNumExperts += data.mNumFusedSharedExperts;
674+
data.mTopK += data.mNumFusedSharedExperts;
675+
data.mNumLocalExperts += data.mNumFusedSharedExperts;
676+
}
677+
648678
if (data.mPtrPermutedIdxSize != nullptr) {
649679
if (useSingleCluster) {
650680
LAUNCH_ROUTING_DEEPSEEK(data,
@@ -659,7 +689,7 @@ void runImpl(Data& data, void* stream) {
659689
/*smemSize=*/0, // No dynamic smem
660690
stream, data.mNumExpertGroups > 1);
661691
} else {
662-
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
692+
const int32_t expandedIdxSize = data.mNumTokens * topK;
663693
const int32_t histogramEltsPerBlock = 8 * numThreadsHist;
664694
const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist;
665695

0 commit comments

Comments
 (0)