Skip to content

Commit 5f2751e

Browse files
committed
support Sigmoid (sigmoid+topk) routing function
Signed-off-by: EdalatiAli <aliedalati@cohere.com>
1 parent 97b6f0e commit 5f2751e

6 files changed

Lines changed: 147 additions & 5 deletions

File tree

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
887887
static_cast<RoutingMethodType>(routing_method_type) ==
888888
RoutingMethodType::RenormalizeNaive ||
889889
static_cast<RoutingMethodType>(routing_method_type) ==
890-
RoutingMethodType::SigmoidRenorm) {
890+
RoutingMethodType::SigmoidRenorm ||
891+
static_cast<RoutingMethodType>(routing_method_type) ==
892+
RoutingMethodType::Sigmoid) {
891893
TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0)
892894
<< "Current routing kernel (no groups) only supports top_k<=10 && top_k>0.";
893895
} else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {

csrc/trtllm_fused_moe_runner.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
138138
|| routingMethodType == RoutingMethodType::Renormalize /* TopK -> Softmax */
139139
|| routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK -> Renormalize */
140140
|| routingMethodType == RoutingMethodType::TopK /* TopK only (no softmax) */
141-
|| routingMethodType == RoutingMethodType::SigmoidRenorm /* Sigmoid -> TopK -> Renormalize */) {
141+
|| routingMethodType == RoutingMethodType::SigmoidRenorm /* Sigmoid -> TopK -> Renormalize */
142+
|| routingMethodType == RoutingMethodType::Sigmoid /* Sigmoid -> TopK */) {
142143
using namespace moe::dev::routing;
143144
routingCustom::Data routingData;
144145

@@ -164,6 +165,11 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
164165
routingData.mPreprocessType = RoutingPreprocessType::Sigmoid;
165166
routingData.mPostprocessType = RoutingPostprocessType::SumNormalize;
166167
routingData.mNormTopkProb = normTopkProb;
168+
} else if (routingMethodType == RoutingMethodType::Sigmoid) {
169+
// Sigmoid -> TopK (no renormalization)
170+
routingData.mPreprocessType = RoutingPreprocessType::Sigmoid;
171+
routingData.mPostprocessType = RoutingPostprocessType::SumNormalize;
172+
routingData.mNormTopkProb = false;
167173
} else if (routingMethodType == RoutingMethodType::Renormalize ||
168174
routingMethodType == RoutingMethodType::RenormalizeNaive) {
169175
// TopK -> Softmax (also used for RenormalizeNaive, see comment above)

flashinfer/fused_moe/core.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ class RoutingMethodType(IntEnum):
7373
TopK = (5,)
7474
# SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K weights)
7575
SigmoidRenorm = (6,)
76+
# Sigmoid: Sigmoid -> TopK (no renormalization)
77+
Sigmoid = (7,)
7678
# Unspecified
77-
Unspecified = (7,)
79+
Unspecified = (8,)
7880

7981

8082
# Copied from csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h
@@ -2302,6 +2304,8 @@ def trtllm_bf16_moe(
23022304
- 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts)
23032305
- 3: Llama4 (Top1 -> Sigmoid)
23042306
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
2307+
- 6: SigmoidRenorm (Sigmoid -> TopK -> Renormalize)
2308+
- 7: Sigmoid (Sigmoid -> TopK)
23052309
use_shuffled_weight: Whether to use shuffled weight layout for optimization (default: True).
23062310
weight_layout: Weight layout format (default: WeightLayout.BlockMajorK).
23072311
- 0: MajorK - K-major layout [Mn, K]
@@ -2397,6 +2401,8 @@ def trtllm_bf16_routed_moe(
23972401
- 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts)
23982402
- 3: Llama4 (Top1 -> Sigmoid)
23992403
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
2404+
- 6: SigmoidRenorm (Sigmoid -> TopK -> Renormalize)
2405+
- 7: Sigmoid (Sigmoid -> TopK)
24002406
use_shuffled_weight: Whether to use shuffled weight layout for optimization (default: True).
24012407
weight_layout: Weight layout format (default: WeightLayout.BlockMajorK).
24022408
- 0: MajorK - K-major layout [Mn, K]
@@ -2832,6 +2838,8 @@ def trtllm_fp4_block_scale_moe(
28322838
- 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts)
28332839
- 3: Llama4 (Top1 -> Sigmoid)
28342840
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
2841+
- 6: SigmoidRenorm (Sigmoid -> TopK -> Renormalize)
2842+
- 7: Sigmoid (Sigmoid -> TopK)
28352843
do_finalize (bool): Whether to finalize the output (default: False)
28362844
enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
28372845
activation_type (int): Type of activation function (default: 3 - Swiglu)
@@ -2967,6 +2975,8 @@ def trtllm_fp4_block_scale_routed_moe(
29672975
- 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts)
29682976
- 3: Llama4 (Top1 -> Sigmoid)
29692977
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
2978+
- 6: SigmoidRenorm (Sigmoid -> TopK -> Renormalize)
2979+
- 7: Sigmoid (Sigmoid -> TopK)
29702980
do_finalize (bool): Whether to finalize the output (default: False)
29712981
activation_type (int): Type of activation function (default: 3 - Swiglu)
29722982
- 3: Swiglu
@@ -3082,6 +3092,8 @@ def trtllm_mxint4_block_scale_moe(
30823092
- 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts)
30833093
- 3: Llama4 (Top1 -> Sigmoid)
30843094
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
3095+
- 6: SigmoidRenorm (Sigmoid -> TopK -> Renormalize)
3096+
- 7: Sigmoid (Sigmoid -> TopK)
30853097
do_finalize (bool): Whether to finalize the output (default: False)
30863098
enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
30873099
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)

include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,8 @@ struct PolicyTraits<NoOpPreprocess, SoftmaxPostprocess> {
453453
>;
454454
};
455455

456-
/// Sigmoid + SumNormalize (SigmoidRenorm: Sigmoid -> TopK -> Renormalize).
456+
/// Sigmoid + SumNormalize (SigmoidRenorm: Sigmoid -> TopK -> Renormalize,
457+
/// Sigmoid: Sigmoid -> TopK with normTopkProb=false).
457458
template <>
458459
struct PolicyTraits<SigmoidPreprocess, SumNormalizePostprocess> {
459460
using Pairs = TierList<Tier<128, 8>, // Small expert counts (≤128 experts)

include/flashinfer/trtllm/fused_moe/runner.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ enum class RoutingMethodType : int64_t {
5050
TopK = 5,
5151
// SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K weights)
5252
SigmoidRenorm = 6,
53+
// Sigmoid: Sigmoid -> TopK (no renormalization)
54+
Sigmoid = 7,
5355
// Unspecified
54-
Unspecified = 7,
56+
Unspecified = 8,
5557
};
5658

5759
inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize,
@@ -77,6 +79,8 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod
7779
return "TopK";
7880
case RoutingMethodType::SigmoidRenorm:
7981
return "SigmoidRenorm";
82+
case RoutingMethodType::Sigmoid:
83+
return "Sigmoid";
8084
default:
8185
return "InvalidRountingMethod"; // TODO throw error
8286
};

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2748,6 +2748,10 @@ def run_moe_test(
27482748
permute_info, scores = routing_reference_sigmoid_renorm(
27492749
expert_logits, top_k, num_experts, padding, norm_topk_prob=norm_topk_prob
27502750
)
2751+
elif routing_method_type == RoutingMethodType.Sigmoid:
2752+
permute_info, scores = routing_reference_sigmoid_renorm(
2753+
expert_logits, top_k, num_experts, padding, norm_topk_prob=False
2754+
)
27512755
elif routing_method_type == RoutingMethodType.Llama4:
27522756
permute_info, scores = routing_reference_no_aux(
27532757
expert_logits,
@@ -3242,6 +3246,119 @@ def test_sigmoid_renorm_routing(
32423246
)
32433247

32443248

3249+
# Test: Sigmoid routing (Sigmoid -> TopK, no renormalization)
3250+
@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
3251+
@pytest.mark.parametrize("hidden_size", [1024])
3252+
@pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384])
3253+
@pytest.mark.parametrize(
3254+
"moe_impl",
3255+
[
3256+
pytest.param(BF16Moe(), id="BF16xBF16"),
3257+
pytest.param(
3258+
FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK),
3259+
id="FP8_Block_DeepSeek",
3260+
),
3261+
pytest.param(
3262+
FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8),
3263+
id="FP8_Block_MxFp8",
3264+
),
3265+
pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
3266+
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
3267+
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
3268+
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
3269+
pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"),
3270+
],
3271+
)
3272+
@pytest.mark.parametrize(
3273+
"routing_config",
3274+
[
3275+
pytest.param(
3276+
{
3277+
"num_experts": 128,
3278+
"top_k": 8,
3279+
"padding": 8,
3280+
"n_groups": None,
3281+
"top_k_groups": None,
3282+
"routed_scaling": None,
3283+
"has_routing_bias": False,
3284+
"routing_method_type": RoutingMethodType.Sigmoid,
3285+
"compatible_moe_impls": [
3286+
FP8PerTensorMoe,
3287+
FP8BlockScaleMoe,
3288+
FP4Moe,
3289+
BF16Moe,
3290+
MxInt4BlockScaleMoe,
3291+
],
3292+
"compatible_intermediate_size": [384, 768, 1024],
3293+
"enable_autotune": True,
3294+
},
3295+
id="Sigmoid_128e_top8",
3296+
),
3297+
],
3298+
)
3299+
@pytest.mark.parametrize(
3300+
"weight_processing",
3301+
[
3302+
pytest.param(
3303+
{
3304+
"use_shuffled_weight": False,
3305+
"layout": WeightLayout.MajorK,
3306+
"compatible_moe_impls": [FP8BlockScaleMoe],
3307+
},
3308+
id="NoShuffle_MajorK",
3309+
),
3310+
pytest.param(
3311+
{
3312+
"use_shuffled_weight": True,
3313+
"layout": WeightLayout.MajorK,
3314+
"compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
3315+
},
3316+
id="Shuffled_MajorK",
3317+
),
3318+
pytest.param(
3319+
{
3320+
"use_shuffled_weight": True,
3321+
"layout": WeightLayout.BlockMajorK,
3322+
"compatible_moe_impls": [
3323+
FP8BlockScaleMoe,
3324+
BF16Moe,
3325+
MxInt4BlockScaleMoe,
3326+
],
3327+
},
3328+
id="Shuffled_BlockMajorK",
3329+
),
3330+
],
3331+
)
3332+
@pytest.mark.parametrize(
3333+
"activation_type",
3334+
[
3335+
pytest.param(ActivationType.Swiglu.value, id="Swiglu"),
3336+
pytest.param(ActivationType.Geglu.value, id="Geglu"),
3337+
],
3338+
)
3339+
def test_sigmoid_routing(
3340+
num_tokens,
3341+
hidden_size,
3342+
intermediate_size,
3343+
moe_impl,
3344+
routing_config,
3345+
weight_processing,
3346+
activation_type,
3347+
cache_permute_indices,
3348+
):
3349+
"""Test Sigmoid routing configurations (Sigmoid -> TopK, no renormalization)."""
3350+
run_moe_test(
3351+
num_tokens,
3352+
hidden_size,
3353+
intermediate_size,
3354+
moe_impl,
3355+
routing_config,
3356+
weight_processing,
3357+
activation_type,
3358+
cache_permute_indices,
3359+
)
3360+
3361+
32453362
# Test: DeepSeekV3 routing
32463363
@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
32473364
@pytest.mark.parametrize("hidden_size", [1024])

0 commit comments

Comments (0)