@@ -256,33 +256,36 @@ class FusedMoeLauncher {
256256 Tensor num_non_exiting_ctas;
257257
258258 void prepare_routing_common () {
259+ int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts ;
260+ int32_t const totalNumExperts = args->num_experts + args->num_fused_shared_experts ;
261+
259262 // Allocate routing phase workspace tensors
260- num_tokens_per_expert = alloc_tensor ({args-> num_experts }, dl_int32, hidden_states.device ());
263+ num_tokens_per_expert = alloc_tensor ({totalNumExperts }, dl_int32, hidden_states.device ());
261264 int32_t max_num_padded_tokens =
262265 tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount (
263- args->num_tokens , args-> top_k , args-> num_experts , tile_tokens_dim);
266+ args->num_tokens , totalExpertsPerToken, totalNumExperts , tile_tokens_dim);
264267
265268 total_num_padded_tokens = alloc_tensor ({1 }, dl_int32, hidden_states.device ());
266269
267270 expanded_idx_to_permuted_idx =
268- alloc_tensor ({args->num_tokens * args-> top_k }, dl_int32, hidden_states.device ());
271+ alloc_tensor ({args->num_tokens * totalExpertsPerToken }, dl_int32, hidden_states.device ());
269272
270273 permuted_idx_to_token_idx =
271274 alloc_tensor ({max_num_padded_tokens}, dl_int32, hidden_states.device ());
272275
273276 expert_indexes =
274- alloc_tensor ({args->num_tokens , args-> top_k }, dl_int32, hidden_states.device ());
277+ alloc_tensor ({args->num_tokens , totalExpertsPerToken }, dl_int32, hidden_states.device ());
275278
276279 // expert_weights allocation should be done by derived class since data type could vary
277280
278- int64_t const size_of_expert_count_histogram = std::max (args-> num_experts * 2 , 256 * 2 );
281+ int64_t const size_of_expert_count_histogram = std::max (totalNumExperts * 2 , 256 * 2 );
279282 expert_count_histogram = alloc_tensor ({size_of_expert_count_histogram},
280283 dl_int32, // 256 is the max number of threads per block
281284 // and max number of experts
282285 hidden_states.device ());
283286
284287 int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim (
285- args->num_tokens , args-> top_k , args-> num_experts , tile_tokens_dim);
288+ args->num_tokens , totalExpertsPerToken, totalNumExperts , tile_tokens_dim);
286289
287290 cta_idx_xy_to_batch_idx = alloc_tensor ({max_num_ctas}, dl_int32, hidden_states.device ());
288291
@@ -334,14 +337,17 @@ class FusedMoeLauncher {
334337 this ->activation_type , this ->use_shuffled_weight , this ->weight_layout );
335338 }
336339
340+ int32_t const effectiveTopK = args->top_k + args->num_fused_shared_experts ;
341+ int32_t const effectiveLocalExperts = args->local_num_experts + args->num_fused_shared_experts ;
342+
337343 if (moe_tactic == -1 ) {
338- moe_tactic = moe_runner->getDefaultValidConfigIndex (
339- args-> top_k , args-> hidden_size , args-> intermediate_size , args->local_num_experts ,
340- args->num_tokens );
344+ moe_tactic = moe_runner->getDefaultValidConfigIndex (effectiveTopK, args-> hidden_size ,
345+ args->intermediate_size ,
346+ effectiveLocalExperts, args->num_tokens );
341347 }
342348 auto valid_cfgs =
343- moe_runner->getValidConfigIndices (args-> top_k , args->hidden_size , args->intermediate_size ,
344- args-> local_num_experts , args->num_tokens );
349+ moe_runner->getValidConfigIndices (effectiveTopK , args->hidden_size , args->intermediate_size ,
350+ effectiveLocalExperts , args->num_tokens );
345351 auto valid_it = std::find (valid_cfgs.begin (), valid_cfgs.end (), moe_tactic);
346352 FLASHINFER_CHECK (valid_it != valid_cfgs.end (), " Invalid MoE tactic " , moe_tactic,
347353 " for tile_N=" , tile_tokens_dim, " . Number of valid tactics for this tile is " ,
@@ -377,8 +383,8 @@ class FusedMoeLauncher {
377383
378384 routing_runner.run (
379385 args->routing_logits , args->routing_bias , args->num_tokens , args->num_experts , args->top_k ,
380- args->n_group , args->topk_group , args->local_expert_offset , args->local_num_experts ,
381- args->routed_scaling_factor , workspace.routing_expert_indexes ,
386+ args->num_fused_shared_experts , args->n_group , args->topk_group , args->local_expert_offset ,
387+ args->local_num_experts , args-> routed_scaling_factor , workspace.routing_expert_indexes ,
382388 static_cast <int *>(expert_count_histogram.data_ptr ()),
383389 static_cast <int *>(total_num_padded_tokens.data_ptr ()),
384390 static_cast <int *>(expanded_idx_to_permuted_idx.data_ptr ()),
@@ -910,12 +916,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
910916 auto const routing_bias_dtype =
911917 routing_bias.has_value () ? routing_bias.value ().dtype () : dl_bfloat16;
912918 mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
919+ int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts ;
913920 // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
914921 bool has_precomputed_weights = expert_weights.ndim () == 2 && expert_weights.size (0 ) > 0 ;
915922 if (!has_precomputed_weights) {
916923 // Allocate expert_weights buffer for routing output
917- FusedMoeLauncher::expert_weights =
918- alloc_tensor ({args-> num_tokens , args-> top_k }, dl_bfloat16, hidden_states.device ());
924+ FusedMoeLauncher::expert_weights = alloc_tensor ({args-> num_tokens , totalExpertsPerToken},
925+ dl_bfloat16, hidden_states.device ());
919926 workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr ();
920927 } else {
921928 workspace.expert_weights = const_cast <void *>(expert_weights.data_ptr ());
@@ -946,12 +953,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
946953 TVM_FFI_ICHECK_EQ (gemm1_weights.dtype (), dl_float8_e4m3fn) << " gemm1_weights must be fp8." ;
947954 TVM_FFI_ICHECK_EQ (gemm2_weights.dtype (), dl_float8_e4m3fn) << " gemm2_weights must be fp8." ;
948955
956+ int64_t const totalLocalExperts = args->local_num_experts + args->num_fused_shared_experts ;
949957 if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
950958 TVM_FFI_ICHECK_EQ (gemm1_weights_scale.dtype (), dl_float32)
951959 << " gemm1_weights_scale must be float." ;
952960 TVM_FFI_ICHECK_EQ (gemm1_weights_scale.ndim (), 3 ) << " gemm1_weights_scale must be 3D." ;
953- TVM_FFI_ICHECK_EQ (gemm1_weights_scale.size (0 ), args-> local_num_experts )
954- << " gemm1_weights_scale has incorrect shape ." ;
961+ TVM_FFI_ICHECK_EQ (gemm1_weights_scale.size (0 ), totalLocalExperts )
962+ << " gemm1_weights_scale has incorrect dim 0 ." ;
955963 TVM_FFI_ICHECK_EQ (args->intermediate_size % 128 , 0 )
956964 << " intermediate_size must be a multiple of 128." ;
957965 TVM_FFI_ICHECK_EQ (gemm1_weights_scale.size (1 ),
@@ -971,8 +979,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
971979 TVM_FFI_ICHECK_EQ (gemm2_weights_scale.dtype (), dl_float32)
972980 << " gemm2_weights_scale must be float." ;
973981 TVM_FFI_ICHECK_EQ (gemm2_weights_scale.ndim (), 3 ) << " gemm2_weights_scale must be 3D." ;
974- TVM_FFI_ICHECK_EQ (gemm2_weights_scale.size (0 ), args-> local_num_experts )
975- << " gemm2_weights_scale has incorrect shape ." ;
982+ TVM_FFI_ICHECK_EQ (gemm2_weights_scale.size (0 ), totalLocalExperts )
983+ << " gemm2_weights_scale has incorrect dim 0 ." ;
976984 TVM_FFI_ICHECK_EQ (gemm2_weights_scale.size (1 ), args->hidden_size / 128 )
977985 << " gemm2_weights_scale has incorrect shape." ;
978986 TVM_FFI_ICHECK_EQ (gemm2_weights_scale.size (2 ), args->intermediate_size / 128 )
@@ -1082,8 +1090,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
10821090 // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes
10831091 routing_runner.run (
10841092 use_precomputed ? nullptr : args->routing_logits , args->routing_bias , args->num_tokens ,
1085- args->num_experts , args->top_k , args->n_group , args->topk_group , args->local_expert_offset ,
1086- args->local_num_experts , args->routed_scaling_factor , workspace.routing_expert_indexes ,
1093+ args->num_experts , args->top_k , args->num_fused_shared_experts , args->n_group ,
1094+ args->topk_group , args->local_expert_offset , args->local_num_experts ,
1095+ args->routed_scaling_factor , workspace.routing_expert_indexes ,
10871096 static_cast <int *>(expert_count_histogram.data_ptr ()),
10881097 static_cast <int *>(total_num_padded_tokens.data_ptr ()),
10891098 static_cast <int *>(expanded_idx_to_permuted_idx.data_ptr ()),
@@ -1545,8 +1554,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
15451554
15461555 routing_runner.run (
15471556 args->routing_logits , args->routing_bias , args->num_tokens , args->num_experts , args->top_k ,
1548- args->n_group , args->topk_group , args->local_expert_offset , args->local_num_experts ,
1549- args->routed_scaling_factor , static_cast <int *>(expert_indices.data_ptr ()),
1557+ args->num_fused_shared_experts , args->n_group , args->topk_group , args->local_expert_offset ,
1558+ args->local_num_experts , args->routed_scaling_factor ,
1559+ static_cast <int *>(expert_indices.data_ptr ()),
15501560 static_cast <int *>(expert_count_histogram.data_ptr ()),
15511561 static_cast <int *>(total_num_padded_tokens.data_ptr ()),
15521562 static_cast <int *>(expanded_idx_to_permuted_idx.data_ptr ()),
@@ -1779,10 +1789,11 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
17791789 Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
17801790 TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
17811791 TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
1782- Optional<int64_t > n_group, Optional<int64_t > topk_group, int64_t intermediate_size,
1783- int64_t local_expert_offset, int64_t local_num_experts, Optional<double > routed_scaling_factor,
1784- int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
1785- bool enable_pdl, Array<int64_t > config_index, Fp8QuantizationType quantization_type) {
1792+ Optional<int64_t > num_fused_shared_experts, Optional<int64_t > n_group,
1793+ Optional<int64_t > topk_group, int64_t intermediate_size, int64_t local_expert_offset,
1794+ int64_t local_num_experts, Optional<double > routed_scaling_factor, int64_t routing_method_type,
1795+ bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl,
1796+ Array<int64_t > config_index, Fp8QuantizationType quantization_type) {
17861797 // Basic type validation
17871798 auto dtype = hidden_states.dtype ();
17881799
@@ -1843,9 +1854,13 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
18431854 auto const num_tokens = hidden_states.size (0 );
18441855 auto const hidden_size = hidden_states.size (1 );
18451856
1857+ int64_t const nFusedShared = num_fused_shared_experts.value_or (0 );
1858+ int64_t const totalExpertsPerToken = top_k + nFusedShared;
1859+ int64_t const totalLocalExperts = local_num_experts + nFusedShared;
1860+
18461861 auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums (quantization_type);
1847- std::set<int32_t > selected_tile_nums =
1848- computeSelectedTileN ( supported_tile_nums, num_tokens, top_k, local_num_experts );
1862+ std::set<int32_t > selected_tile_nums = computeSelectedTileN (
1863+ supported_tile_nums, num_tokens, totalExpertsPerToken, totalLocalExperts );
18491864
18501865 // Create a map of launchers for each tile size
18511866 std::unordered_map<int32_t , std::unique_ptr<Fp8BlockScaleLauncher>> launchers_map;
@@ -1855,6 +1870,7 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
18551870 auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
18561871 args->num_tokens = num_tokens;
18571872 args->num_experts = num_experts;
1873+ args->num_fused_shared_experts = nFusedShared;
18581874 args->hidden_size = hidden_size;
18591875 args->hidden_size_output = args->hidden_size ;
18601876 args->top_k = top_k;
0 commit comments