Skip to content

Commit 259d279

Browse files
committed
init
1 parent 26ef055 commit 259d279

7 files changed

Lines changed: 258 additions & 81 deletions

File tree

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -256,33 +256,36 @@ class FusedMoeLauncher {
256256
Tensor num_non_exiting_ctas;
257257

258258
void prepare_routing_common() {
259+
int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
260+
int32_t const totalNumExperts = args->num_experts + args->num_fused_shared_experts;
261+
259262
// Allocate routing phase workspace tensors
260-
num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device());
263+
num_tokens_per_expert = alloc_tensor({totalNumExperts}, dl_int32, hidden_states.device());
261264
int32_t max_num_padded_tokens =
262265
tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
263-
args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
266+
args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);
264267

265268
total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device());
266269

267270
expanded_idx_to_permuted_idx =
268-
alloc_tensor({args->num_tokens * args->top_k}, dl_int32, hidden_states.device());
271+
alloc_tensor({args->num_tokens * totalExpertsPerToken}, dl_int32, hidden_states.device());
269272

270273
permuted_idx_to_token_idx =
271274
alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device());
272275

273276
expert_indexes =
274-
alloc_tensor({args->num_tokens, args->top_k}, dl_int32, hidden_states.device());
277+
alloc_tensor({args->num_tokens, totalExpertsPerToken}, dl_int32, hidden_states.device());
275278

276279
// expert_weights allocation should be done by derived class since data type could vary
277280

278-
int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2);
281+
int64_t const size_of_expert_count_histogram = std::max(totalNumExperts * 2, 256 * 2);
279282
expert_count_histogram = alloc_tensor({size_of_expert_count_histogram},
280283
dl_int32, // 256 is the max number of threads per block
281284
// and max number of experts
282285
hidden_states.device());
283286

284287
int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
285-
args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
288+
args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);
286289

287290
cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device());
288291

@@ -334,14 +337,17 @@ class FusedMoeLauncher {
334337
this->activation_type, this->use_shuffled_weight, this->weight_layout);
335338
}
336339

340+
int32_t const effectiveTopK = args->top_k + args->num_fused_shared_experts;
341+
int32_t const effectiveLocalExperts = args->local_num_experts + args->num_fused_shared_experts;
342+
337343
if (moe_tactic == -1) {
338-
moe_tactic = moe_runner->getDefaultValidConfigIndex(
339-
args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts,
340-
args->num_tokens);
344+
moe_tactic = moe_runner->getDefaultValidConfigIndex(effectiveTopK, args->hidden_size,
345+
args->intermediate_size,
346+
effectiveLocalExperts, args->num_tokens);
341347
}
342348
auto valid_cfgs =
343-
moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size,
344-
args->local_num_experts, args->num_tokens);
349+
moe_runner->getValidConfigIndices(effectiveTopK, args->hidden_size, args->intermediate_size,
350+
effectiveLocalExperts, args->num_tokens);
345351
auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic);
346352
FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic,
347353
" for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ",
@@ -377,8 +383,9 @@ class FusedMoeLauncher {
377383

378384
routing_runner.run(
379385
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
380-
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
381-
args->routed_scaling_factor, static_cast<int*>(expert_indexes.data_ptr()),
386+
args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
387+
args->local_num_experts, args->routed_scaling_factor,
388+
static_cast<int*>(expert_indexes.data_ptr()),
382389
static_cast<int*>(expert_count_histogram.data_ptr()),
383390
static_cast<int*>(total_num_padded_tokens.data_ptr()),
384391
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -882,12 +889,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
882889
auto const routing_bias_dtype =
883890
routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
884891
mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
892+
int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
885893
// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
886894
bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0;
887895
if (!has_precomputed_weights) {
888896
// Allocate expert_weights buffer for routing output
889-
FusedMoeLauncher::expert_weights =
890-
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
897+
FusedMoeLauncher::expert_weights = alloc_tensor({args->num_tokens, totalExpertsPerToken},
898+
dl_bfloat16, hidden_states.device());
891899
workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr();
892900
} else {
893901
workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
@@ -918,12 +926,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
918926
TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
919927
TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";
920928

929+
int64_t const totalLocalExperts = args->local_num_experts + args->num_fused_shared_experts;
921930
if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
922931
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
923932
<< "gemm1_weights_scale must be float.";
924933
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
925-
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
926-
<< "gemm1_weights_scale has incorrect shape.";
934+
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), totalLocalExperts)
935+
<< "gemm1_weights_scale has incorrect dim 0.";
927936
TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
928937
<< "intermediate_size must be a multiple of 128.";
929938
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1),
@@ -943,8 +952,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
943952
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
944953
<< "gemm2_weights_scale must be float.";
945954
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
946-
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
947-
<< "gemm2_weights_scale has incorrect shape.";
955+
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), totalLocalExperts)
956+
<< "gemm2_weights_scale has incorrect dim 0.";
948957
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
949958
<< "gemm2_weights_scale has incorrect shape.";
950959
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
@@ -1054,8 +1063,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
10541063
// routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes
10551064
routing_runner.run(
10561065
use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens,
1057-
args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset,
1058-
args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes,
1066+
args->num_experts, args->top_k, args->num_fused_shared_experts, args->n_group,
1067+
args->topk_group, args->local_expert_offset, args->local_num_experts,
1068+
args->routed_scaling_factor, workspace.routing_expert_indexes,
10591069
static_cast<int*>(expert_count_histogram.data_ptr()),
10601070
static_cast<int*>(total_num_padded_tokens.data_ptr()),
10611071
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -1517,8 +1527,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
15171527

15181528
routing_runner.run(
15191529
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
1520-
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
1521-
args->routed_scaling_factor, static_cast<int*>(expert_indices.data_ptr()),
1530+
args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
1531+
args->local_num_experts, args->routed_scaling_factor,
1532+
static_cast<int*>(expert_indices.data_ptr()),
15221533
static_cast<int*>(expert_count_histogram.data_ptr()),
15231534
static_cast<int*>(total_num_padded_tokens.data_ptr()),
15241535
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -1746,10 +1757,11 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
17461757
Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
17471758
TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
17481759
TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
1749-
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
1750-
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
1751-
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
1752-
bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
1760+
Optional<int64_t> num_fused_shared_experts, Optional<int64_t> n_group,
1761+
Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
1762+
int64_t local_num_experts, Optional<double> routed_scaling_factor, int64_t routing_method_type,
1763+
bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl,
1764+
Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
17531765
// Basic type validation
17541766
auto dtype = hidden_states.dtype();
17551767

@@ -1810,9 +1822,13 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
18101822
auto const num_tokens = hidden_states.size(0);
18111823
auto const hidden_size = hidden_states.size(1);
18121824

1825+
int64_t const nFusedShared = num_fused_shared_experts.value_or(0);
1826+
int64_t const totalExpertsPerToken = top_k + nFusedShared;
1827+
int64_t const totalLocalExperts = local_num_experts + nFusedShared;
1828+
18131829
auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type);
1814-
std::set<int32_t> selected_tile_nums =
1815-
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts);
1830+
std::set<int32_t> selected_tile_nums = computeSelectedTileN(
1831+
supported_tile_nums, num_tokens, totalExpertsPerToken, totalLocalExperts);
18161832

18171833
// Create a map of launchers for each tile size
18181834
std::unordered_map<int32_t, std::unique_ptr<Fp8BlockScaleLauncher>> launchers_map;
@@ -1822,6 +1838,7 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
18221838
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
18231839
args->num_tokens = num_tokens;
18241840
args->num_experts = num_experts;
1841+
args->num_fused_shared_experts = nFusedShared;
18251842
args->hidden_size = hidden_size;
18261843
args->hidden_size_output = args->hidden_size;
18271844
args->top_k = top_k;

csrc/trtllm_fused_moe_routing_deepseek.cu

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,17 +250,28 @@ __global__ void routingMainKernel(KernelParams params) {
250250
auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm};
251251

252252
// write expert idx out already
253-
auto idxTopK = blockIdx.x * params.mTopK + laneIdx;
253+
auto idxTopK = blockIdx.x * params.mTotalExpertsPerToken + laneIdx;
254+
auto idxShared = blockIdx.x * params.mTotalExpertsPerToken + params.mTopK + laneIdx;
254255
if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) {
255256
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore),
256257
static_cast<int16_t>(expertIdx)};
257258
params.mPtrTopKPacked[idxTopK] = packedScore;
258259
}
259260

261+
if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKPacked != nullptr) {
262+
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(1.0F),
263+
static_cast<int16_t>(params.mNumExperts + laneIdx)};
264+
params.mPtrTopKPacked[idxShared] = packedScore;
265+
}
266+
260267
if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr &&
261268
params.mPtrTopKIds == nullptr) {
262269
params.mPtrTopKWeights[idxTopK] = finalScore;
263270
}
271+
272+
if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKWeights != nullptr) {
273+
params.mPtrTopKWeights[idxShared] = static_cast<OutputT>(1.0F);
274+
}
264275
}
265276
}
266277
}
@@ -557,6 +568,11 @@ void runImpl(Data& data, void* stream) {
557568
FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups,
558569
"Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups,
559570
data.mNumLimitedGroups);
571+
572+
int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts;
573+
int const topK = data.mTopK + data.mNumFusedSharedExperts;
574+
int const numThreadsHist = getMaxNumExperts(numExperts);
575+
560576
// Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK
561577
if (data.mNumExperts <= NumKimiK2Experts) {
562578
FLASHINFER_CHECK(
@@ -569,6 +585,9 @@ void runImpl(Data& data, void* stream) {
569585
"When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d",
570586
MaxSupportedTopExperts, data.mTopK);
571587
}
588+
FLASHINFER_CHECK(topK <= MaxSupportedTopExperts,
589+
"Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts,
590+
topK);
572591
FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d",
573592
data.mTopK);
574593
FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize,
@@ -594,14 +613,19 @@ void runImpl(Data& data, void* stream) {
594613
data.mNumExperts / data.mNumExpertGroups <= WarpSize,
595614
"Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d",
596615
data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups);
616+
617+
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
618+
"Number of fused shared experts (%d) must be less than warp size.",
619+
data.mNumFusedSharedExperts);
597620
}
598621
FLASHINFER_CHECK(data.mNumExperts % 4 == 0,
599622
"Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
600623

601624
int const numBlocks = data.mNumTokens;
602-
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
603625

604-
bool const useSingleCluster = data.mNumTokens <= 1024;
626+
int numThreadsPerCluster = numThreadsHist * NumBlocksPerCluster;
627+
bool const useSingleCluster =
628+
data.mNumTokens <= 1024 && data.mNumTokens * topK <= numThreadsPerCluster;
605629
if (!useSingleCluster) {
606630
// Reset the global histograms (not used in single-cluster code path).
607631
// Cover both for the cooperative and two-kernel code paths.
@@ -625,7 +649,7 @@ void runImpl(Data& data, void* stream) {
625649
int const numBlocksCoop = 128;
626650

627651
// Maximum number of tokens supported by the kernel using a cooperative launch.
628-
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
652+
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / topK;
629653
if (data.mPtrTopKIds == nullptr) {
630654
int const numThreadsMain =
631655
max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts));
@@ -641,6 +665,12 @@ void runImpl(Data& data, void* stream) {
641665
stream, data.mNumExpertGroups > 1);
642666
}
643667

668+
if (data.mNumFusedSharedExperts > 0) {
669+
data.mNumExperts += data.mNumFusedSharedExperts;
670+
data.mTopK += data.mNumFusedSharedExperts;
671+
data.mNumLocalExperts += data.mNumFusedSharedExperts;
672+
}
673+
644674
if (data.mPtrPermutedIdxSize != nullptr) {
645675
if (useSingleCluster) {
646676
LAUNCH_ROUTING_DEEPSEEK(data,
@@ -655,7 +685,7 @@ void runImpl(Data& data, void* stream) {
655685
/*smemSize=*/0, // No dynamic smem
656686
stream, data.mNumExpertGroups > 1);
657687
} else {
658-
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
688+
const int32_t expandedIdxSize = data.mNumTokens * topK;
659689
const int32_t histogramEltsPerBlock = 8 * numThreadsHist;
660690
const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist;
661691

0 commit comments

Comments (0)