@@ -256,33 +256,36 @@ class FusedMoeLauncher {
   Tensor num_non_exiting_ctas;
 
   void prepare_routing_common() {
+    int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
+    int32_t const totalNumExperts = args->num_experts + args->num_fused_shared_experts;
+
     // Allocate routing phase workspace tensors
-    num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device());
+    num_tokens_per_expert = alloc_tensor({totalNumExperts}, dl_int32, hidden_states.device());
     int32_t max_num_padded_tokens =
         tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
-            args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
+            args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);
 
     total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device());
 
     expanded_idx_to_permuted_idx =
-        alloc_tensor({args->num_tokens * args->top_k}, dl_int32, hidden_states.device());
+        alloc_tensor({args->num_tokens * totalExpertsPerToken}, dl_int32, hidden_states.device());
 
     permuted_idx_to_token_idx =
         alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device());
 
     expert_indexes =
-        alloc_tensor({args->num_tokens, args->top_k}, dl_int32, hidden_states.device());
+        alloc_tensor({args->num_tokens, totalExpertsPerToken}, dl_int32, hidden_states.device());
 
     // expert_weights allocation should be done by derived class since data type could vary
 
-    int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2);
+    int64_t const size_of_expert_count_histogram = std::max(totalNumExperts * 2, 256 * 2);
     expert_count_histogram = alloc_tensor({size_of_expert_count_histogram},
                                           dl_int32,  // 256 is the max number of threads per block
                                                      // and max number of experts
                                           hidden_states.device());
 
     int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
-        args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
+        args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);
 
     cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device());
 
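For intuition, here is a hedged worked example of how the routing workspace shapes above change; the concrete sizes are hypothetical, only the arithmetic mirrors the hunk:

// Assume num_tokens = 4, top_k = 8, num_experts = 256, num_fused_shared_experts = 1,
// so totalExpertsPerToken = 9 and totalNumExperts = 257.
//
//   num_tokens_per_expert        : {257}            (was {256})
//   expanded_idx_to_permuted_idx : {4 * 9} = {36}   (was {32})
//   expert_indexes               : {4, 9}           (was {4, 8})
//   expert_count_histogram       : {max(257 * 2, 256 * 2)} = {514}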
@@ -334,14 +337,17 @@ class FusedMoeLauncher {
         this->activation_type, this->use_shuffled_weight, this->weight_layout);
     }
 
+    int32_t const effectiveTopK = args->top_k + args->num_fused_shared_experts;
+    int32_t const effectiveLocalExperts = args->local_num_experts + args->num_fused_shared_experts;
+
     if (moe_tactic == -1) {
-      moe_tactic = moe_runner->getDefaultValidConfigIndex(
-          args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts,
-          args->num_tokens);
+      moe_tactic = moe_runner->getDefaultValidConfigIndex(effectiveTopK, args->hidden_size,
+                                                          args->intermediate_size,
+                                                          effectiveLocalExperts, args->num_tokens);
     }
     auto valid_cfgs =
-        moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size,
-                                          args->local_num_experts, args->num_tokens);
+        moe_runner->getValidConfigIndices(effectiveTopK, args->hidden_size, args->intermediate_size,
+                                          effectiveLocalExperts, args->num_tokens);
     auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic);
     FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic,
                      " for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ",
@@ -377,8 +383,9 @@ class FusedMoeLauncher {
 
     routing_runner.run(
         args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
-        args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
-        args->routed_scaling_factor, static_cast<int*>(expert_indexes.data_ptr()),
+        args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
+        args->local_num_experts, args->routed_scaling_factor,
+        static_cast<int*>(expert_indexes.data_ptr()),
         static_cast<int*>(expert_count_histogram.data_ptr()),
         static_cast<int*>(total_num_padded_tokens.data_ptr()),
         static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -882,12 +889,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
     auto const routing_bias_dtype =
         routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
     mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
+    int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
     // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
     bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0;
     if (!has_precomputed_weights) {
       // Allocate expert_weights buffer for routing output
-      FusedMoeLauncher::expert_weights =
-          alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
+      FusedMoeLauncher::expert_weights = alloc_tensor({args->num_tokens, totalExpertsPerToken},
+                                                      dl_bfloat16, hidden_states.device());
       workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr();
     } else {
       workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
@@ -918,12 +926,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
     TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
     TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";
 
+    int64_t const totalLocalExperts = args->local_num_experts + args->num_fused_shared_experts;
     if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
       TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
           << "gemm1_weights_scale must be float.";
       TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
-      TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
-          << "gemm1_weights_scale has incorrect shape.";
+      TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), totalLocalExperts)
+          << "gemm1_weights_scale has incorrect dim 0.";
       TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
           << "intermediate_size must be a multiple of 128.";
       TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1),
@@ -943,8 +952,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
       TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
           << "gemm2_weights_scale must be float.";
       TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
-      TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
-          << "gemm2_weights_scale has incorrect shape.";
+      TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), totalLocalExperts)
+          << "gemm2_weights_scale has incorrect dim 0.";
       TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
           << "gemm2_weights_scale has incorrect shape.";
       TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
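As a hedged illustration of what the relaxed dim-0 checks accept, with DeepSeek-V3-like sizes that are assumptions rather than values from this commit:

// Assume hidden_size = 7168, intermediate_size = 2048,
// local_num_experts = 32, num_fused_shared_experts = 1  ->  totalLocalExperts = 33.
//
// gemm2_weights_scale must then be a 3D float32 tensor of shape
//   [totalLocalExperts, hidden_size / 128, intermediate_size / 128] = [33, 56, 16],
// i.e. dim 0 now also counts the fused shared experts.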
@@ -1054,8 +1063,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
     // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes
     routing_runner.run(
         use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens,
-        args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset,
-        args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes,
+        args->num_experts, args->top_k, args->num_fused_shared_experts, args->n_group,
+        args->topk_group, args->local_expert_offset, args->local_num_experts,
+        args->routed_scaling_factor, workspace.routing_expert_indexes,
         static_cast<int*>(expert_count_histogram.data_ptr()),
         static_cast<int*>(total_num_padded_tokens.data_ptr()),
         static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -1517,8 +1527,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
 
     routing_runner.run(
         args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
-        args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
-        args->routed_scaling_factor, static_cast<int*>(expert_indices.data_ptr()),
+        args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
+        args->local_num_experts, args->routed_scaling_factor,
+        static_cast<int*>(expert_indices.data_ptr()),
         static_cast<int*>(expert_count_histogram.data_ptr()),
         static_cast<int*>(total_num_padded_tokens.data_ptr()),
         static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
@@ -1746,10 +1757,11 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
     Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
     TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
     TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
-    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
-    int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
-    int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
-    bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
+    Optional<int64_t> num_fused_shared_experts, Optional<int64_t> n_group,
+    Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
+    int64_t local_num_experts, Optional<double> routed_scaling_factor, int64_t routing_method_type,
+    bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl,
+    Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
 
@@ -1810,9 +1822,13 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
   auto const num_tokens = hidden_states.size(0);
   auto const hidden_size = hidden_states.size(1);
 
+  int64_t const nFusedShared = num_fused_shared_experts.value_or(0);
+  int64_t const totalExpertsPerToken = top_k + nFusedShared;
+  int64_t const totalLocalExperts = local_num_experts + nFusedShared;
+
   auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type);
-  std::set<int32_t> selected_tile_nums =
-      computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts);
+  std::set<int32_t> selected_tile_nums = computeSelectedTileN(
+      supported_tile_nums, num_tokens, totalExpertsPerToken, totalLocalExperts);
 
   // Create a map of launchers for each tile size
   std::unordered_map<int32_t, std::unique_ptr<Fp8BlockScaleLauncher>> launchers_map;
@@ -1822,6 +1838,7 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
     auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
     args->num_tokens = num_tokens;
     args->num_experts = num_experts;
+    args->num_fused_shared_experts = nFusedShared;
     args->hidden_size = hidden_size;
     args->hidden_size_output = args->hidden_size;
     args->top_k = top_k;
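Tying the plumbing together, below is a minimal standalone sketch of the derived counts this commit threads through routing, buffer sizing, and tactic selection; it is an illustration with assumed values, not code from the commit:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical configuration, for illustration only.
  int64_t const top_k = 8, num_experts = 256, local_num_experts = 32;
  int64_t const num_fused_shared_experts = 1;  // shared experts fused into the routed MoE path

  // The commit consistently replaces the raw counts with these sums:
  int64_t const totalExpertsPerToken = top_k + num_fused_shared_experts;           // routing slots per token
  int64_t const totalNumExperts = num_experts + num_fused_shared_experts;          // per-expert buffers, histogram
  int64_t const totalLocalExperts = local_num_experts + num_fused_shared_experts;  // weight-scale dim 0, tactics

  std::printf("per-token=%lld global=%lld local=%lld\n",
              static_cast<long long>(totalExpertsPerToken),
              static_cast<long long>(totalNumExperts),
              static_cast<long long>(totalLocalExperts));  // per-token=9 global=257 local=33
  return 0;
}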