From 1c253ff626521ed145a8494bf274cce42ee36b20 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:21:59 +0000 Subject: [PATCH 1/7] Add support for Relu2 in BF16 fused MoE, update test_accordingly Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 10 ++++----- flashinfer/fused_moe/core.py | 26 ++++++++++++++++++++++-- tests/moe/test_trtllm_gen_fused_moe.py | 11 ++++++---- tests/moe/utils.py | 1 + 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index d252198e2b..edf709cf60 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -531,10 +531,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, bool norm_topk_prob = true) { - constexpr ActivationType activation_type = - ActivationType::Swiglu; // not exposed in api for now - + int64_t weight_layout, ActivationType activation_type, bool norm_topk_prob = true) { // Do base class init and perform common checks FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight, weight_layout, activation_type, @@ -1738,7 +1735,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, - bool enable_pdl, Array moe_tactic, bool norm_topk_prob) { + bool enable_pdl, Array moe_tactic, int64_t activation_type, bool norm_topk_prob) { // Just some basic type validation first and leave more checks to the launcher if (routing_logits.has_value()) { TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || @@ -1754,6 +1751,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); + auto const activation = static_cast(activation_type); // Calculate supported tile sizes std::vector mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(), @@ -1788,7 +1786,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, expert_weights, hidden_states, gemm1_weights, gemm2_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, - weight_layout, norm_topk_prob); + weight_layout, activation, norm_topk_prob); launchers_map[curr_tile_N] = std::move(launcher); } diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index f7a1522333..d8923fe54f 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -54,7 +54,14 @@ get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, ) -from ..tllm_enums import * +from ..tllm_enums import ( + ActivationType, + WeightLayout, + DtypeTrtllmGen, + Fp8QuantizationType, + deduce_trtllm_gen_tensor_dtype, + trtllm_gen_dtype_has_scale, +) @functools.cache @@ -1107,6 +1114,7 @@ def forward( kwargs["do_finalize"], kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, + self.activation_type, kwargs.get("norm_topk_prob", True), ) elif ( @@ -1290,6 +1298,7 @@ def trtllm_bf16_moe_op( do_finalize: bool = True, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + activation_type: int = ActivationType.Swiglu.value, 
norm_topk_prob: bool = True, ) -> List[torch.Tensor]: assert routing_logits is not None or topk_ids is not None, ( @@ -1338,7 +1347,7 @@ def trtllm_bf16_moe_op( intermediate_size=intermediate_size, weight_layout=weight_layout, use_shuffled_weight=use_shuffled_weight, - activation_type=ActivationType.Swiglu, # Default for BF16 + activation_type=activation_type, ) moe_inputs = MoEInputs( @@ -1375,6 +1384,7 @@ def trtllm_bf16_moe_op( weight_layout=weight_layout, do_finalize=do_finalize, enable_pdl=enable_pdl, + activation_type=activation_type, ) # Call the C++ function with the selected tactic @@ -1401,6 +1411,7 @@ def trtllm_bf16_moe_op( do_finalize, enable_pdl, [-1, -1] if tactic == -1 else tactic, + activation_type, norm_topk_prob, ) if do_finalize: @@ -1435,6 +1446,7 @@ def _fake_trtllm_bf16_moe( do_finalize: bool = True, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, ) -> List[torch.Tensor]: seq_len = hidden_states.shape[0] @@ -2272,6 +2284,7 @@ def trtllm_bf16_moe( do_finalize: bool = True, enable_pdl: bool = True, tune_max_num_tokens: int = 8192, + activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, ) -> Union[List[torch.Tensor], torch.Tensor]: """BF16 MoE operation with autotuning support. @@ -2310,6 +2323,9 @@ def trtllm_bf16_moe( do_finalize: Whether to finalize the output (default: True). enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90. tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192). + activation_type (int): Type of activation function (default: 3 - Swiglu) + - 3: Swiglu + - 6: Relu2 Returns: when do_finalize=True, returns the final MoE output. @@ -2337,6 +2353,7 @@ def trtllm_bf16_moe( do_finalize, enable_pdl, tune_max_num_tokens, + activation_type, norm_topk_prob, ) @@ -2369,6 +2386,7 @@ def trtllm_bf16_routed_moe( do_finalize: bool = True, enable_pdl: bool = True, tune_max_num_tokens: int = 8192, + activation_type: int = ActivationType.Swiglu.value, ) -> List[torch.Tensor]: """BF16 MoE operation with autotuning support. @@ -2405,6 +2423,9 @@ def trtllm_bf16_routed_moe( do_finalize: Whether to finalize the output (default: True). enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90. tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192). + activation_type (int): Type of activation function (default: 3 - Swiglu) + - 3: Swiglu + - 6: Relu2 Returns: when do_finalize=True, returns the final MoE output. 
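
[Editorial note, not part of the patch] For context on the new `activation_type` parameter added above: Relu2 is a non-gated activation, so the first GEMM only needs `intermediate_size` output rows ("M") instead of the `2*intermediate_size` used by gated activations such as Swiglu. The sketch below is a minimal illustration of that shape rule only; it uses a stand-in enum with the integer values quoted in the docstrings (3 = Swiglu, 6 = Relu2) and a hypothetical helper name, not FlashInfer's actual weight-preparation code (the tests use a real `is_gated_activation()` helper for this check).

```python
from enum import IntEnum

# Stand-in for FlashInfer's ActivationType, using the integer values quoted in the
# updated docstrings (3 = Swiglu, 6 = Relu2); the real enum lives in flashinfer.tllm_enums.
class ActivationType(IntEnum):
    Swiglu = 3
    Relu2 = 6

def gemm1_out_rows(intermediate_size: int, activation_type: int) -> int:
    """Rows ("M") of the first-GEMM weights.

    Gated activations (e.g. Swiglu) interleave gate and up projections, so M is
    2 * intermediate_size; non-gated activations (e.g. Relu2) only need the up
    projection, so M is intermediate_size.
    """
    # Hypothetical gating check for illustration only.
    gated = activation_type == ActivationType.Swiglu
    return 2 * intermediate_size if gated else intermediate_size

# Example with intermediate_size = 2688 (the value used in the updated test config):
assert gemm1_out_rows(2688, ActivationType.Swiglu) == 5376  # gated
assert gemm1_out_rows(2688, ActivationType.Relu2) == 2688   # non-gated
```

Design-wise, a non-gated activation halves the first-GEMM weight storage and FLOPs for the same `intermediate_size`, which is why the gemm1_weights docstrings later in this series are updated to describe both shapes.
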
@@ -2432,6 +2453,7 @@ def trtllm_bf16_routed_moe( do_finalize, enable_pdl, tune_max_num_tokens, + activation_type, True, # norm_topk_prob: not used for pre-computed routing ) diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 568721b18b..3812ddfbe1 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -1445,6 +1445,7 @@ def prepare_static_weights_for_kernel( self._cache_permute_indices, args.gemm1_weights[i].view(torch.uint8), epilogue_tile_m, + is_gated_act_gemm=is_gated_activation(args.activation_type), ) tmp_weights1 = ( args.gemm1_weights[i] @@ -1508,6 +1509,7 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) + activation_type = kwargs["activation_type"] norm_topk_prob = kwargs.get("norm_topk_prob", True) # Use autotuner for optimal kernel selection @@ -1530,6 +1532,7 @@ def call_moe( weight_layout=static_data["weight_layout"], routing_method_type=routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + activation_type=activation_type, norm_topk_prob=norm_topk_prob, ) return output.to(torch.float) @@ -3131,7 +3134,7 @@ def test_renormalize_routing( # Test: DeepSeekV3 routing @pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) -@pytest.mark.parametrize("intermediate_size", [2944, 2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize("intermediate_size", [2688, 2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ @@ -3164,12 +3167,12 @@ def test_renormalize_routing( "routed_scaling": 2.5, "has_routing_bias": True, "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP8PerTensorMoe, FP4Moe], - "compatible_intermediate_size": [2944], + "compatible_moe_impls": [BF16Moe, FP8PerTensorMoe, FP4Moe], + "compatible_intermediate_size": [2688], "compatible_activation_types": [ActivationType.Relu2], "enable_autotune": True, }, - id="nemotron_3_dummy", + id="nemotron_3_super", ), pytest.param( { diff --git a/tests/moe/utils.py b/tests/moe/utils.py index d4f6fdbac5..42c4c45b71 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -39,6 +39,7 @@ class QuantMode(IntEnum): QuantMode.FP4_NVFP4_NVFP4, QuantMode.FP8_BLOCK_SCALE_MXFP8, QuantMode.FP8_PER_TENSOR, + QuantMode.BF16, ] From d70a0cf9f8fde5a123621d21e63e9c6cec1bd285 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:21:59 +0000 Subject: [PATCH 2/7] Fix flashinfer/fused_moe.__init__.py import issues found by 'pre-commit run --all-files' Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- flashinfer/fused_moe/__init__.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index e2b4cab3d6..95d536bdf3 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -15,10 +15,6 @@ """ from .core import ( - ActivationType, - Fp8QuantizationType, - RoutingMethodType, - WeightLayout, convert_to_block_layout, cutlass_fused_moe, gen_cutlass_fused_moe_sm120_module, @@ -37,6 +33,13 @@ trtllm_mxint4_block_scale_moe, ) +from ..tllm_enums import ( + ActivationType, + Fp8QuantizationType, + WeightLayout, + RoutingMethodType, +) + from .fused_routing_dsv3 import ( # noqa: F401 fused_topk_deepseek as fused_topk_deepseek, ) From 
b523d74c946166e73256f437d83fe4ae2b76242d Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:21:59 +0000 Subject: [PATCH 3/7] Use validateAndCastActivationType instead of unchecked static cast Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index edf709cf60..00d88cc1e0 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1751,7 +1751,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); - auto const activation = static_cast(activation_type); + auto const activation = validateAndCastActivationType(activation_type); // Calculate supported tile sizes std::vector mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(), @@ -1815,7 +1815,7 @@ Array trtllm_fp8_per_tensor_scale_moe( bool enable_pdl, Array config_index, int64_t activation_type, bool norm_topk_prob) { // Basic type validation auto dtype = hidden_states.dtype(); - auto activation = static_cast(activation_type); + auto activation = validateAndCastActivationType(activation_type); TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16) << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; From e710491c85c1a2acd22c09c0d78de14a5422e86a Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:37:54 +0000 Subject: [PATCH 4/7] Update trtllm-gen batched GEMM artifact path & checksum, update csrc/trtllm_batched_gemm_runner.cu access to BatchedGemmOptions.mNumStages as it was split to A and B Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_batched_gemm_runner.cu | 4 ++-- flashinfer/artifacts.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f4cb825d36..c3f1ef7fd5 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -68,8 +68,8 @@ std::vector prioritizePredefinedConfigs( if (n /* out_dim */ == 0 && k /* in_dim */ == 0) { auto pred = [](BatchedGemmConfig const& config) { BatchedGemmOptions const& options = config.mOptions; - return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256 && - options.mTileScheduler == TileScheduler::Persistent; + return options.mNumStagesA == 4 && options.mNumStagesB == 4 && options.mNumStagesMma == 2 && + options.mTileK == 256 && options.mTileScheduler == TileScheduler::Persistent; }; prioritizedIndices = bubbleUpConfig(sortedIndices, pred); } diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 25b568c512..6794b91b4e 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -137,7 +137,7 @@ class ArtifactPath: TRTLLM_GEN_FMHA: str = "55bba55929d4093682e32d817bd11ffb0441c749/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "31e75d429ff3f710de1251afdd148185f53da44d/batched_gemm-4daf11e-c111d7c/" + "39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/" ) TRTLLM_GEN_GEMM: str = ( "31e75d429ff3f710de1251afdd148185f53da44d/gemm-4daf11e-1fddea2/" @@ -158,7 +158,7 @@ class CheckSumHash: 
"f2c0aad1e74391c4267a2f9a20ec819358b59e04588385cffb452ed341500b99" ) TRTLLM_GEN_BMM: str = ( - "2c2361bdf1deb0a2ea0f130f2d57dd62864f4400a706ac19a625d492b03460cb" + "db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6" ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( From bb55a6118cb11b7016e1558fae0ba43e1b674f01 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:37:54 +0000 Subject: [PATCH 5/7] MoE test - Add FP4_NVFP4_NVFP4 to allowed quant_mode with logits_dtype of fp32 Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- tests/moe/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 42c4c45b71..39abf18717 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -162,6 +162,7 @@ def skip_checks( ) if logits_dtype == torch.float32 and moe_impl.quant_mode not in [ + QuantMode.FP4_NVFP4_NVFP4, QuantMode.FP8_PER_TENSOR, QuantMode.FP8_BLOCK_SCALE_DEEPSEEK, QuantMode.FP8_BLOCK_SCALE_MXFP8, From a84a9c9ac96f3638ab94bca1327c8b0bb5f311f8 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:37:55 +0000 Subject: [PATCH 6/7] Update docstrings with gemm1_weights to include non-gated activation shape Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- flashinfer/fused_moe/core.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index d8923fe54f..b164ce0ae7 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -2299,7 +2299,9 @@ def trtllm_bf16_moe( Must be bfloat16 if provided. hidden_states: [seq_len, hidden_size] tensor of input hidden states. Must be bfloat16. - gemm1_weights: [num_experts, 2*intermediate_size // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16. + gemm1_weights: [num_experts, M // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16. + M is 2*intermediate_size for gated activations and + intermediate_size for non-gated activations. gemm2_weights: [num_experts, hidden_size//128, intermediate_size, 128] tensor of second layer weights. must be bfloat16. num_experts: Total number of experts. top_k: Number of experts to route to per token. @@ -2325,7 +2327,7 @@ def trtllm_bf16_moe( tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192). activation_type (int): Type of activation function (default: 3 - Swiglu) - 3: Swiglu - - 6: Relu2 + - 6: Relu2 (non-gated) Returns: when do_finalize=True, returns the final MoE output. @@ -2399,7 +2401,9 @@ def trtllm_bf16_routed_moe( Can be created as: (topk_ids.int32 << 16) | expert_weights.bfloat16.view(int16) hidden_states: [seq_len, hidden_size] tensor of input hidden states. Must be bfloat16. - gemm1_weights: [num_experts, 2*intermediate_size // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16. + gemm1_weights: [num_experts, M // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16. + M is 2*intermediate_size for gated activations and + intermediate_size for non-gated activations. gemm2_weights: [num_experts, hidden_size//128, intermediate_size, 128] tensor of second layer weights. must be bfloat16. num_experts: Total number of experts. top_k: Number of experts to route to per token. 
@@ -2425,7 +2429,7 @@ def trtllm_bf16_routed_moe( tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192). activation_type (int): Type of activation function (default: 3 - Swiglu) - 3: Swiglu - - 6: Relu2 + - 6: Relu2 (non-gated) Returns: when do_finalize=True, returns the final MoE output. @@ -2498,7 +2502,9 @@ def trtllm_fp8_per_tensor_scale_moe( routing_logits: [seq_len, num_experts] tensor of routing logits routing_bias: [num_experts] tensor of routing bias hidden_states: [seq_len, hidden_size] tensor of input hidden states - gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights + gemm1_weights: [num_experts, M, hidden_size] tensor of first layer weights + M is 2*intermediate_size for gated activations and + intermediate_size for non-gated activations. output1_scales_scalar: [local_num_experts] tensor of first layer output scales output1_scales_gate_scalar: [local_num_experts] tensor of first layer gate scales gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights @@ -2520,7 +2526,7 @@ def trtllm_fp8_per_tensor_scale_moe( - 0: Gelu - 3: Swiglu - 4: Geglu - - 6: Relu2 + - 6: Relu2 (non-gated) - 7: Identity Returns: From d0552585ad9fc3770d523911f13d9c262a759aae Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:37:55 +0000 Subject: [PATCH 7/7] clang format fix Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 00d88cc1e0..c7b46ead24 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1725,17 +1725,15 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { } }; -Array trtllm_bf16_moe(Optional const& routing_logits, - Optional const& routing_bias, - TensorView const& expert_indices, TensorView const& expert_weights, - TensorView const& hidden_states, TensorView const& gemm1_weights, - TensorView const& gemm2_weights, TensorView output, - int64_t num_experts, int64_t top_k, Optional n_group, - Optional topk_group, int64_t intermediate_size, - int64_t local_expert_offset, int64_t local_num_experts, - Optional routed_scaling_factor, int64_t routing_method_type, - bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, - bool enable_pdl, Array moe_tactic, int64_t activation_type, bool norm_topk_prob) { +Array trtllm_bf16_moe( + Optional const& routing_logits, Optional const& routing_bias, + TensorView const& expert_indices, TensorView const& expert_weights, + TensorView const& hidden_states, TensorView const& gemm1_weights, + TensorView const& gemm2_weights, TensorView output, int64_t num_experts, int64_t top_k, + Optional n_group, Optional topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, + int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, + bool enable_pdl, Array moe_tactic, int64_t activation_type, bool norm_topk_prob) { // Just some basic type validation first and leave more checks to the launcher if (routing_logits.has_value()) { TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||