
Commit ede2225
Add support for Relu2 in BF16 fused MoE (#2864)
## 📌 Description

* Added support for the Relu2 non-gated activation in BF16 fused MoE by adding `activation_type` to the external API:
  * `trtllm_bf16_moe`
  * `trtllm_bf16_routed_moe`
  * `Bf16MoeLauncher::init`
* Updated the trtllm-gen batched GEMM kernels.
* Updated `tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing` to include BF16 with the Nemotron config, and fixed the Nemotron config's `intermediate_size` test parameter to match Nemotron 3 Super.
* Fixed import issues found by `pre-commit run --all-files`.
* Required change from the trtllm-gen batched GEMM update: changed `options.mNumStages == 4` to `options.mNumStagesA == 4 && options.mNumStagesB == 4` in the `prioritizePredefinedConfigs` function in `csrc/trtllm_batched_gemm_runner.cu`.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * MoE APIs now accept a validated runtime `activation_type`, enabling selectable activation functions for BF16 and FP8 inference.
* **Tests**
  * Expanded DeepSeekV3 routing tests and added BF16 to non-gated activation coverage.
  * Updated test parameters to reflect the new compatibility.
* **Bug Fixes**
  * Adjusted kernel configuration prioritization for a specific corner-case path.
* **Refactor**
  * Internal enum imports reorganized into a shared enums module.
* **Chores**
  * Updated the batched GEMM artifact path and checksum.

---

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
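As a usage illustration (not part of the diff): a minimal sketch of selecting Relu2 through the new parameter. Tensor shapes follow the updated `trtllm_bf16_moe` docstring; the keyword names, the `Renormalize` routing member, and the top-level export of `trtllm_bf16_moe` are assumptions for this example.

```python
# Illustrative sketch only; requires a CUDA device and flashinfer installed.
import torch

from flashinfer.fused_moe import ActivationType, RoutingMethodType, trtllm_bf16_moe

num_experts, top_k = 32, 4
seq_len, hidden_size, intermediate_size = 16, 1024, 2688

hidden_states = torch.randn(seq_len, hidden_size, device="cuda", dtype=torch.bfloat16)
routing_logits = torch.randn(seq_len, num_experts, device="cuda", dtype=torch.float32)

# Relu2 is non-gated, so gemm1 packs M = intermediate_size output rows;
# a gated activation such as Swiglu would pack M = 2 * intermediate_size.
M = intermediate_size
gemm1_weights = torch.randn(
    num_experts, M // 128, hidden_size // 128, 128, device="cuda", dtype=torch.bfloat16
)
gemm2_weights = torch.randn(
    num_experts, hidden_size // 128, intermediate_size, 128, device="cuda", dtype=torch.bfloat16
)

output = trtllm_bf16_moe(
    routing_logits=routing_logits,
    routing_bias=None,
    hidden_states=hidden_states,
    gemm1_weights=gemm1_weights,
    gemm2_weights=gemm2_weights,
    num_experts=num_experts,
    top_k=top_k,
    n_group=None,
    topk_group=None,
    intermediate_size=intermediate_size,
    local_expert_offset=0,
    local_num_experts=num_experts,
    routed_scaling_factor=None,
    routing_method_type=RoutingMethodType.Renormalize,  # assumed member name
    activation_type=ActivationType.Relu2.value,  # new in this PR; default is Swiglu (3)
)
```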
1 parent e64ae8b commit ede2225

7 files changed

Lines changed: 67 additions & 35 deletions


csrc/trtllm_batched_gemm_runner.cu

Lines changed: 2 additions & 2 deletions
```diff
@@ -68,8 +68,8 @@ std::vector<int64_t> prioritizePredefinedConfigs(
   if (n /* out_dim */ == 0 && k /* in_dim */ == 0) {
     auto pred = [](BatchedGemmConfig const& config) {
       BatchedGemmOptions const& options = config.mOptions;
-      return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256 &&
-             options.mTileScheduler == TileScheduler::Persistent;
+      return options.mNumStagesA == 4 && options.mNumStagesB == 4 && options.mNumStagesMma == 2 &&
+             options.mTileK == 256 && options.mTileScheduler == TileScheduler::Persistent;
     };
     prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
   }
```
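For context on what the changed predicate feeds into: `bubbleUpConfig` evidently moves configs matching the predicate to the front of `sortedIndices`. A minimal Python sketch of that stable bubble-up behavior (hypothetical illustration; the real helper operates on C++ `BatchedGemmConfig` structs):

```python
# Hypothetical sketch of predicate-based prioritization: indices whose config
# matches the predicate move to the front; relative order is otherwise kept.
def bubble_up_config(sorted_indices, configs, pred):
    matching = [i for i in sorted_indices if pred(configs[i])]
    rest = [i for i in sorted_indices if not pred(configs[i])]
    return matching + rest

# Example: prioritize configs with num_stages_a == num_stages_b == 4.
configs = [{"num_stages_a": 2, "num_stages_b": 4}, {"num_stages_a": 4, "num_stages_b": 4}]
print(bubble_up_config([0, 1], configs,
                       lambda c: c["num_stages_a"] == 4 and c["num_stages_b"] == 4))
# -> [1, 0]
```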

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 13 additions & 17 deletions
```diff
@@ -531,10 +531,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher {
 
   void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
             int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
-            int64_t weight_layout, bool norm_topk_prob = true) {
-    constexpr ActivationType activation_type =
-        ActivationType::Swiglu;  // not exposed in api for now
-
+            int64_t weight_layout, ActivationType activation_type, bool norm_topk_prob = true) {
     // Do base class init and perform common checks
     FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type,
                                   use_shuffled_weight, weight_layout, activation_type,
@@ -1728,17 +1725,15 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
   }
 };
 
-Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
-                              Optional<TensorView> const& routing_bias,
-                              TensorView const& expert_indices, TensorView const& expert_weights,
-                              TensorView const& hidden_states, TensorView const& gemm1_weights,
-                              TensorView const& gemm2_weights, TensorView output,
-                              int64_t num_experts, int64_t top_k, Optional<int64_t> n_group,
-                              Optional<int64_t> topk_group, int64_t intermediate_size,
-                              int64_t local_expert_offset, int64_t local_num_experts,
-                              Optional<double> routed_scaling_factor, int64_t routing_method_type,
-                              bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
-                              bool enable_pdl, Array<int64_t> moe_tactic, bool norm_topk_prob) {
+Array<Tensor> trtllm_bf16_moe(
+    Optional<TensorView> const& routing_logits, Optional<TensorView> const& routing_bias,
+    TensorView const& expert_indices, TensorView const& expert_weights,
+    TensorView const& hidden_states, TensorView const& gemm1_weights,
+    TensorView const& gemm2_weights, TensorView output, int64_t num_experts, int64_t top_k,
+    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
+    int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
+    int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
+    bool enable_pdl, Array<int64_t> moe_tactic, int64_t activation_type, bool norm_topk_prob) {
   // Just some basic type validation first and leave more checks to the launcher
   if (routing_logits.has_value()) {
     TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
@@ -1754,6 +1749,7 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
 
   auto const num_tokens = hidden_states.size(0);
   auto const hidden_size = hidden_states.size(1);
+  auto const activation = validateAndCastActivationType(activation_type);
 
   // Calculate supported tile sizes
   std::vector<int32_t> mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(),
@@ -1788,7 +1784,7 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
                                expert_weights, hidden_states, gemm1_weights,
                                gemm2_weights);
     launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
-                   weight_layout, norm_topk_prob);
+                   weight_layout, activation, norm_topk_prob);
 
     launchers_map[curr_tile_N] = std::move(launcher);
   }
@@ -1817,7 +1813,7 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
     bool enable_pdl, Array<int64_t> config_index, int64_t activation_type, bool norm_topk_prob) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
-  auto activation = static_cast<ActivationType>(activation_type);
+  auto activation = validateAndCastActivationType(activation_type);
 
   TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16)
       << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16.";
```
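Both entry points now route the raw integer through `validateAndCastActivationType` instead of a bare `static_cast`, matching the "validated runtime activation_type" note in the summary above. A rough Python analogue of that validation step (the C++ helper's body and exact error handling are not shown in this commit):

```python
# Rough analogue: reject integers that don't map to a known activation
# before casting, instead of static_cast-ing blindly.
from flashinfer.fused_moe import ActivationType

def validate_and_cast_activation_type(value: int) -> ActivationType:
    try:
        return ActivationType(value)  # raises ValueError for unknown values
    except ValueError as e:
        raise ValueError(f"Unsupported activation_type: {value}") from e
```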

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -137,7 +137,7 @@ class ArtifactPath:
 
     TRTLLM_GEN_FMHA: str = "55bba55929d4093682e32d817bd11ffb0441c749/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
-        "31e75d429ff3f710de1251afdd148185f53da44d/batched_gemm-4daf11e-c111d7c/"
+        "39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/"
     )
     TRTLLM_GEN_GEMM: str = (
         "31e75d429ff3f710de1251afdd148185f53da44d/gemm-4daf11e-1fddea2/"
@@ -158,7 +158,7 @@ class CheckSumHash:
         "f2c0aad1e74391c4267a2f9a20ec819358b59e04588385cffb452ed341500b99"
     )
     TRTLLM_GEN_BMM: str = (
-        "2c2361bdf1deb0a2ea0f130f2d57dd62864f4400a706ac19a625d492b03460cb"
+        "db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"
     )
     DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
     TRTLLM_GEN_GEMM: str = (
```

flashinfer/fused_moe/__init__.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -15,10 +15,6 @@
 """
 
 from .core import (
-    ActivationType,
-    Fp8QuantizationType,
-    RoutingMethodType,
-    WeightLayout,
     convert_to_block_layout,
     cutlass_fused_moe,
     gen_cutlass_fused_moe_sm120_module,
@@ -37,6 +33,13 @@
     trtllm_mxint4_block_scale_moe,
 )
 
+from ..tllm_enums import (
+    ActivationType,
+    Fp8QuantizationType,
+    WeightLayout,
+    RoutingMethodType,
+)
+
 from .fused_routing_dsv3 import (  # noqa: F401
     fused_topk_deepseek as fused_topk_deepseek,
 )
```
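The enums now live in the shared `tllm_enums` module but remain re-exported from `flashinfer.fused_moe`, so existing user imports are unaffected; for example:

```python
# Import path unchanged by the refactor; values per the updated docstrings.
from flashinfer.fused_moe import ActivationType, RoutingMethodType, WeightLayout

assert ActivationType.Swiglu.value == 3
assert ActivationType.Relu2.value == 6  # non-gated
```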

flashinfer/fused_moe/core.py

Lines changed: 34 additions & 6 deletions
```diff
@@ -54,7 +54,14 @@
     get_last_power_of_2_num_tokens_buckets,
     last_positive_power_of_2,
 )
-from ..tllm_enums import *
+from ..tllm_enums import (
+    ActivationType,
+    WeightLayout,
+    DtypeTrtllmGen,
+    Fp8QuantizationType,
+    deduce_trtllm_gen_tensor_dtype,
+    trtllm_gen_dtype_has_scale,
+)
 
 
 @functools.cache
@@ -1107,6 +1114,7 @@ def forward(
                 kwargs["do_finalize"],
                 kwargs["enable_pdl"],
                 [-1, -1] if tactic == -1 else tactic,
+                self.activation_type,
                 kwargs.get("norm_topk_prob", True),
             )
         elif (
@@ -1290,6 +1298,7 @@ def trtllm_bf16_moe_op(
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
     norm_topk_prob: bool = True,
 ) -> List[torch.Tensor]:
     assert routing_logits is not None or topk_ids is not None, (
@@ -1338,7 +1347,7 @@
         intermediate_size=intermediate_size,
         weight_layout=weight_layout,
         use_shuffled_weight=use_shuffled_weight,
-        activation_type=ActivationType.Swiglu,  # Default for BF16
+        activation_type=activation_type,
     )
 
     moe_inputs = MoEInputs(
@@ -1375,6 +1384,7 @@
         weight_layout=weight_layout,
         do_finalize=do_finalize,
         enable_pdl=enable_pdl,
+        activation_type=activation_type,
     )
 
     # Call the C++ function with the selected tactic
@@ -1401,6 +1411,7 @@
         do_finalize,
         enable_pdl,
         [-1, -1] if tactic == -1 else tactic,
+        activation_type,
         norm_topk_prob,
     )
     if do_finalize:
@@ -1435,6 +1446,7 @@ def _fake_trtllm_bf16_moe(
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
     norm_topk_prob: bool = True,
 ) -> List[torch.Tensor]:
     seq_len = hidden_states.shape[0]
@@ -2272,6 +2284,7 @@ def trtllm_bf16_moe(
     do_finalize: bool = True,
     enable_pdl: bool = True,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
     norm_topk_prob: bool = True,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
     """BF16 MoE operation with autotuning support.
@@ -2286,7 +2299,9 @@
             Must be bfloat16 if provided.
         hidden_states: [seq_len, hidden_size] tensor of input hidden states.
             Must be bfloat16.
-        gemm1_weights: [num_experts, 2*intermediate_size // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
+        gemm1_weights: [num_experts, M // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
+            M is 2*intermediate_size for gated activations and
+            intermediate_size for non-gated activations.
         gemm2_weights: [num_experts, hidden_size//128, intermediate_size, 128] tensor of second layer weights. must be bfloat16.
         num_experts: Total number of experts.
         top_k: Number of experts to route to per token.
@@ -2310,6 +2325,9 @@
         do_finalize: Whether to finalize the output (default: True).
         enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90.
         tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192).
+        activation_type (int): Type of activation function (default: 3 - Swiglu)
+            - 3: Swiglu
+            - 6: Relu2 (non-gated)
 
     Returns:
         when do_finalize=True, returns the final MoE output.
@@ -2337,6 +2355,7 @@
         do_finalize,
         enable_pdl,
         tune_max_num_tokens,
+        activation_type,
         norm_topk_prob,
     )
 
@@ -2369,6 +2388,7 @@ def trtllm_bf16_routed_moe(
     do_finalize: bool = True,
     enable_pdl: bool = True,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
 ) -> List[torch.Tensor]:
     """BF16 MoE operation with autotuning support.
 
@@ -2381,7 +2401,9 @@
             Can be created as: (topk_ids.int32 << 16) | expert_weights.bfloat16.view(int16)
         hidden_states: [seq_len, hidden_size] tensor of input hidden states.
             Must be bfloat16.
-        gemm1_weights: [num_experts, 2*intermediate_size // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
+        gemm1_weights: [num_experts, M // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
+            M is 2*intermediate_size for gated activations and
+            intermediate_size for non-gated activations.
         gemm2_weights: [num_experts, hidden_size//128, intermediate_size, 128] tensor of second layer weights. must be bfloat16.
         num_experts: Total number of experts.
         top_k: Number of experts to route to per token.
@@ -2405,6 +2427,9 @@
         do_finalize: Whether to finalize the output (default: True).
         enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90.
         tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192).
+        activation_type (int): Type of activation function (default: 3 - Swiglu)
+            - 3: Swiglu
+            - 6: Relu2 (non-gated)
 
     Returns:
         when do_finalize=True, returns the final MoE output.
@@ -2432,6 +2457,7 @@
         do_finalize,
         enable_pdl,
         tune_max_num_tokens,
+        activation_type,
         True,  # norm_topk_prob: not used for pre-computed routing
     )
 
@@ -2476,7 +2502,9 @@
         routing_logits: [seq_len, num_experts] tensor of routing logits
         routing_bias: [num_experts] tensor of routing bias
         hidden_states: [seq_len, hidden_size] tensor of input hidden states
-        gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights
+        gemm1_weights: [num_experts, M, hidden_size] tensor of first layer weights
+            M is 2*intermediate_size for gated activations and
+            intermediate_size for non-gated activations.
         output1_scales_scalar: [local_num_experts] tensor of first layer output scales
         output1_scales_gate_scalar: [local_num_experts] tensor of first layer gate scales
         gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights
@@ -2498,7 +2526,7 @@
             - 0: Gelu
             - 3: Swiglu
             - 4: Geglu
-            - 6: Relu2
+            - 6: Relu2 (non-gated)
             - 7: Identity
 
     Returns:
```
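The recurring docstring rule above, that the first GEMM's output dimension M is `2*intermediate_size` for gated activations and `intermediate_size` for non-gated ones, can be made concrete with a small sketch (the helper name and the gated set {Swiglu, Geglu} are illustrative, inferred from the value table):

```python
# Illustrative helper: expected BF16 block-layout gemm1_weights shape,
# [num_experts, M // 128, hidden_size // 128, 128], with M set by gating.
GATED_ACTIVATIONS = {3, 4}  # Swiglu, Geglu (per the docstring value table)

def expected_gemm1_shape(num_experts, hidden_size, intermediate_size, activation_type):
    gated = activation_type in GATED_ACTIVATIONS
    m = 2 * intermediate_size if gated else intermediate_size
    return (num_experts, m // 128, hidden_size // 128, 128)

assert expected_gemm1_shape(32, 1024, 2688, 6) == (32, 21, 8, 128)  # Relu2
assert expected_gemm1_shape(32, 1024, 2688, 3) == (32, 42, 8, 128)  # Swiglu
```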

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -1445,6 +1445,7 @@ def prepare_static_weights_for_kernel(
             self._cache_permute_indices,
             args.gemm1_weights[i].view(torch.uint8),
             epilogue_tile_m,
+            is_gated_act_gemm=is_gated_activation(args.activation_type),
         )
         tmp_weights1 = (
             args.gemm1_weights[i]
@@ -1508,6 +1509,7 @@ def call_moe(
         routed_scaling = kwargs["routed_scaling"]
         routing_method_type = kwargs["routing_method_type"]
         enable_autotune = kwargs.get("enable_autotune", True)
+        activation_type = kwargs["activation_type"]
         norm_topk_prob = kwargs.get("norm_topk_prob", True)
 
         # Use autotuner for optimal kernel selection
@@ -1530,6 +1532,7 @@
                 weight_layout=static_data["weight_layout"],
                 routing_method_type=routing_method_type,
                 tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
+                activation_type=activation_type,
                 norm_topk_prob=norm_topk_prob,
             )
         return output.to(torch.float)
@@ -3131,7 +3134,7 @@ def test_renormalize_routing(
 # Test: DeepSeekV3 routing
 @pytest.mark.parametrize("num_tokens", [8, 768, 3072])
 @pytest.mark.parametrize("hidden_size", [1024])
-@pytest.mark.parametrize("intermediate_size", [2944, 2048, 1024, 768, 512, 384])
+@pytest.mark.parametrize("intermediate_size", [2688, 2048, 1024, 768, 512, 384])
 @pytest.mark.parametrize(
     "moe_impl",
     [
@@ -3164,12 +3167,12 @@ def test_renormalize_routing(
             "routed_scaling": 2.5,
             "has_routing_bias": True,
             "routing_method_type": RoutingMethodType.DeepSeekV3,
-            "compatible_moe_impls": [FP8PerTensorMoe, FP4Moe],
-            "compatible_intermediate_size": [2944],
+            "compatible_moe_impls": [BF16Moe, FP8PerTensorMoe, FP4Moe],
+            "compatible_intermediate_size": [2688],
             "compatible_activation_types": [ActivationType.Relu2],
             "enable_autotune": True,
         },
-        id="nemotron_3_dummy",
+        id="nemotron_3_super",
     ),
     pytest.param(
         {
```
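The `is_gated_act_gemm` keyword above is driven by `is_gated_activation`, which the test side uses to decide whether gemm1 carries fused gate and up projections. A plausible sketch of that predicate, assuming the gated set follows the docstring value table:

```python
from flashinfer.fused_moe import ActivationType

# Plausible sketch: gated activations fuse gate and up projections in gemm1,
# doubling its output dimension; Relu2 and Identity do not.
def is_gated_activation(activation_type: ActivationType) -> bool:
    return activation_type in (ActivationType.Swiglu, ActivationType.Geglu)
```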

tests/moe/utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -39,6 +39,7 @@ class QuantMode(IntEnum):
     QuantMode.FP4_NVFP4_NVFP4,
     QuantMode.FP8_BLOCK_SCALE_MXFP8,
     QuantMode.FP8_PER_TENSOR,
+    QuantMode.BF16,
 ]
 
 
@@ -161,6 +162,7 @@ def skip_checks(
     )
 
     if logits_dtype == torch.float32 and moe_impl.quant_mode not in [
+        QuantMode.FP4_NVFP4_NVFP4,
         QuantMode.FP8_PER_TENSOR,
         QuantMode.FP8_BLOCK_SCALE_DEEPSEEK,
         QuantMode.FP8_BLOCK_SCALE_MXFP8,
```