Skip to content

Commit 21cf038

Browse files
committed
Allow non-DeepSeekV3 routing with one group
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
1 parent cdbb2c3 commit 21cf038

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,15 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
765765

766766
FusedMoeLauncher::check_routing_common();
767767

768-
if (args->n_group != 0) {
769-
TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
770-
RoutingMethodType::DeepSeekV3)
771-
<< "Routing kernel with groups implies DeepSeekV3 routing method.";
768+
if (static_cast<RoutingMethodType>(routing_method_type) != RoutingMethodType::DeepSeekV3) {
769+
TVM_FFI_ICHECK(args->n_group <= 1)
770+
<< "Current routing kernel (no groups) only supports n_group <= 1";
771+
TVM_FFI_ICHECK(args->topk_group <= 1)
772+
<< "Current routing kernel (no groups) only supports topk_group <= 1";
773+
}
774+
775+
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
776+
TVM_FFI_ICHECK(args->n_group != 0) << "n_group should not be zero for DeepSeekV3 routing";
772777
TVM_FFI_ICHECK(args->topk_group != 0) << "if n_group is given, topk_group must be given";
773778
TVM_FFI_ICHECK_EQ(args->num_experts % args->n_group, 0)
774779
<< "num_experts must be divisible by n_group";
@@ -790,6 +795,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
790795
TVM_FFI_ICHECK_EQ(args->top_k, 1)
791796
<< "Current routing kernel (no groups, Llama4) only supports top_k=1.";
792797
}
798+
793799
TVM_FFI_ICHECK_EQ(args->num_experts % 4, 0)
794800
<< "Routing kernel expects that num_experts must be divisible by 4";
795801
TVM_FFI_ICHECK_GT(args->num_experts, args->top_k) << "num_experts must be greater than top_k";
@@ -2004,9 +2010,8 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
20042010
}
20052011

20062012
TVM_FFI_LOG_AND_THROW(NotImplementedError)
2007-
<< "Unsupported data type combination for getValidConfigs: "
2008-
<< "dtype_act=" << static_cast<int>(dtype_act)
2009-
<< ", dtype_weights=" << static_cast<int>(dtype_weights)
2013+
<< "Unsupported data type combination for getValidConfigs: " << "dtype_act="
2014+
<< static_cast<int>(dtype_act) << ", dtype_weights=" << static_cast<int>(dtype_weights)
20102015
<< ", useDeepSeekFp8=" << useDeepSeekFp8;
20112016

20122017
// Unreachable code - added to suppress compiler warning

0 commit comments

Comments
 (0)