@@ -765,10 +765,15 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
765765
766766 FusedMoeLauncher::check_routing_common ();
767767
768- if (args->n_group != 0 ) {
769- TVM_FFI_ICHECK (static_cast <RoutingMethodType>(routing_method_type) ==
770- RoutingMethodType::DeepSeekV3)
771- << " Routing kernel with groups implies DeepSeekV3 routing method." ;
768+ if (static_cast <RoutingMethodType>(routing_method_type) != RoutingMethodType::DeepSeekV3) {
769+ TVM_FFI_ICHECK (args->n_group <= 1 )
770+ << " Current routing kernel (no groups) only supports n_group <= 1" ;
771+ TVM_FFI_ICHECK (args->topk_group <= 1 )
772+ << " Current routing kernel (no groups) only supports topk_group <= 1" ;
773+ }
774+
775+ if (static_cast <RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
776+ TVM_FFI_ICHECK (args->n_group != 0 ) << " n_group should not be zero for DeepSeekV3 routing" ;
772777 TVM_FFI_ICHECK (args->topk_group != 0 ) << " if n_group is given, topk_group must be given" ;
773778 TVM_FFI_ICHECK_EQ (args->num_experts % args->n_group , 0 )
774779 << " num_experts must be divisible by n_group" ;
@@ -790,6 +795,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
790795 TVM_FFI_ICHECK_EQ (args->top_k , 1 )
791796 << " Current routing kernel (no groups, Llama4) only supports top_k=1." ;
792797 }
798+
793799 TVM_FFI_ICHECK_EQ (args->num_experts % 4 , 0 )
794800 << " Routing kernel expects that num_experts must be divisible by 4" ;
795801 TVM_FFI_ICHECK_GT (args->num_experts , args->top_k ) << " num_experts must be greater than top_k" ;
@@ -2004,9 +2010,8 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
20042010 }
20052011
20062012 TVM_FFI_LOG_AND_THROW (NotImplementedError)
2007- << " Unsupported data type combination for getValidConfigs: "
2008- << " dtype_act=" << static_cast <int >(dtype_act)
2009- << " , dtype_weights=" << static_cast <int >(dtype_weights)
2013+ << " Unsupported data type combination for getValidConfigs: " << " dtype_act="
2014+ << static_cast <int >(dtype_act) << " , dtype_weights=" << static_cast <int >(dtype_weights)
20102015 << " , useDeepSeekFp8=" << useDeepSeekFp8;
20112016
20122017 // Unreachable code - added to suppress compiler warning
0 commit comments