Commit 7f903f7

Author: Tomer Natan (committed)
Add BF16 trtllm-gen MoE: activation_type + routing_replay_out + new BMM cubins

Cherry-pick of flashinfer-ai#2864 (squashed), plus:

- activation_type param for trtllm_bf16_moe/trtllm_bf16_routed_moe (Swiglu=3, Relu2=6)
- routing_replay_out param for the BF16 kernel (same pattern as FP8)
- Updated batched GEMM artifacts and checksums
- validateAndCastActivationType for safety
- Bf16MoeLauncher::init accepts ActivationType + routing_replay_out
1 parent f8eb66b commit 7f903f7
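In Python terms, the commit adds two keyword arguments to the BF16 MoE entry points. A minimal sketch of preparing them, assuming the shapes and dtypes enforced by the validation code added below (routing_replay_out: int16, 2D [num_tokens, top_k], on the same CUDA device as hidden_states); the import path is assumed from flashinfer/fused_moe/core.py:

```python
# Sketch: the two new kwargs added by this commit. Shapes/dtypes follow the
# TVM_FFI_ICHECK validation in trtllm_bf16_moe below; the import path is an
# assumption based on flashinfer/fused_moe/core.py.
import torch
from flashinfer.fused_moe.core import ActivationType

num_tokens, top_k = 16, 2
routing_replay_out = torch.empty(          # filled by the kernel
    (num_tokens, top_k), dtype=torch.int16, device="cuda"
)

extra_kwargs = dict(
    activation_type=ActivationType.Relu2.value,  # 6; default is Swiglu (3)
    routing_replay_out=routing_replay_out,       # or omit to skip replay capture
)
# Pass **extra_kwargs to trtllm_bf16_moe(...) / trtllm_bf16_routed_moe(...)
# alongside the usual routing and weight arguments.
```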

4 files changed (49 additions, 12 deletions)

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 2 additions & 2 deletions
```diff
@@ -68,8 +68,8 @@ std::vector<int64_t> prioritizePredefinedConfigs(
   if (n /* out_dim */ == 0 && k /* in_dim */ == 0) {
     auto pred = [](BatchedGemmConfig const& config) {
       BatchedGemmOptions const& options = config.mOptions;
-      return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256 &&
-             options.mTileScheduler == TileScheduler::Persistent;
+      return options.mNumStagesA == 4 && options.mNumStagesB == 4 && options.mNumStagesMma == 2 &&
+             options.mTileK == 256 && options.mTileScheduler == TileScheduler::Persistent;
     };
     prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
   }
```
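For readers unfamiliar with prioritizePredefinedConfigs, the pattern here is a stable partition: indices of configs matching a predicate move to the front of the candidate list without disturbing relative order. A minimal Python sketch of that pattern (Config and its field names are illustrative stand-ins mirroring the C++ BatchedGemmOptions above, not flashinfer's actual types):

```python
# Sketch of the bubbleUpConfig pattern: stable-partition candidate config
# indices so those matching the predicate are tried first. Config and its
# fields are illustrative, mirroring the C++ options in the diff above.
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class Config:
    num_stages_a: int
    num_stages_b: int
    num_stages_mma: int
    tile_k: int
    persistent: bool

def bubble_up(indices: List[int], configs: List[Config],
              pred: Callable[[Config], bool]) -> List[int]:
    # Matching indices move to the front; relative order is preserved on both sides.
    return ([i for i in indices if pred(configs[i])]
            + [i for i in indices if not pred(configs[i])])

pred = lambda c: (c.num_stages_a == 4 and c.num_stages_b == 4
                  and c.num_stages_mma == 2 and c.tile_k == 256
                  and c.persistent)
```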

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 20 additions & 7 deletions
```diff
@@ -476,10 +476,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher {
 
   void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
             int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
-            int64_t weight_layout) {
-    constexpr ActivationType activation_type =
-        ActivationType::Swiglu;  // not exposed in api for now
-
+            int64_t weight_layout, ActivationType activation_type) {
     // Do base class init and perform common checks
     FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type,
                                   use_shuffled_weight, weight_layout, activation_type);
@@ -1670,7 +1667,8 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
                               int64_t local_expert_offset, int64_t local_num_experts,
                               Optional<double> routed_scaling_factor, int64_t routing_method_type,
                               bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
-                              bool enable_pdl, Array<int64_t> moe_tactic) {
+                              bool enable_pdl, Array<int64_t> moe_tactic, int64_t activation_type,
+                              Optional<TensorView> routing_replay_out) {
   // Just some basic type validation first and leave more checks to the launcher
   if (routing_logits.has_value()) {
     TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
@@ -1686,6 +1684,20 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
 
   auto const num_tokens = hidden_states.size(0);
   auto const hidden_size = hidden_states.size(1);
+  auto const activation = validateAndCastActivationType(activation_type);
+
+  // Validate routing_replay_out if provided
+  if (routing_replay_out.has_value()) {
+    auto replay = routing_replay_out.value();
+    TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
+        << "routing_replay_out must be a CUDA tensor";
+    TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id)
+        << "routing_replay_out must be on the same device as hidden_states";
+    TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]";
+    TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k";
+    TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code)
+        << "routing_replay_out must be int16 dtype";
+  }
 
   // Calculate supported tile sizes
   std::vector<int32_t> mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(),
@@ -1719,7 +1731,8 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
                                                     expert_weights, hidden_states, gemm1_weights,
                                                     gemm2_weights);
     launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
-                   weight_layout);
+                   weight_layout, activation);
+    launcher->set_routing_replay_out(routing_replay_out);
 
     launchers_map[curr_tile_N] = std::move(launcher);
   }
@@ -1751,7 +1764,7 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
     bool enable_pdl, Array<int64_t> config_index, int64_t activation_type) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
-  auto activation = static_cast<ActivationType>(activation_type);
+  auto activation = validateAndCastActivationType(activation_type);
   if (use_routing_scales_on_input) {
     TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
   } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
```
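The body of validateAndCastActivationType is not part of this diff. A hedged Python rendering of the check it presumably performs, given the commit message's "for safety" note and the two codes this path supports (Swiglu=3, Relu2=6); the enum values come from the commit message, everything else is illustrative:

```python
# Hypothetical Python rendering of validateAndCastActivationType; the real
# helper is C++ and its body is not shown in this diff. Enum values per the
# commit message (Swiglu=3, Relu2=6); the real ActivationType may have more members.
from enum import IntEnum

class ActivationType(IntEnum):
    Swiglu = 3
    Relu2 = 6

def validate_and_cast_activation_type(activation_type: int) -> ActivationType:
    # Safer than a blind static_cast: reject unsupported codes up front.
    try:
        return ActivationType(activation_type)
    except ValueError:
        raise ValueError(
            f"Unsupported activation_type {activation_type}; "
            f"expected one of {[e.value for e in ActivationType]}"
        ) from None
```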

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -137,7 +137,7 @@ class ArtifactPath:
 
     TRTLLM_GEN_FMHA: str = "f1ed60e5666a7620683a8c34a41c850a25029b35/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
-        "b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/"
+        "39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/"
     )
     TRTLLM_GEN_GEMM: str = (
         "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3/"
@@ -158,7 +158,7 @@ class CheckSumHash:
         "10a54e8c3175099481aed2739ae30fa0f782368c40f9ad1b423ed8353315d65b"
     )
     TRTLLM_GEN_BMM: str = (
-        "0af823880730c4f0b3832d2208fab035946694b83444410b9309db5613d60195"
+        "db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"
     )
     DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
     TRTLLM_GEN_GEMM: str = (
```
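The CheckSumHash entries pin the downloaded kernel artifacts to known digests. A minimal sketch of the SHA-256 verification such a pin enables (verify_artifact is hypothetical; flashinfer's actual download/verification code is not shown in this diff):

```python
# Sketch: verifying a downloaded artifact against a pinned SHA-256 digest.
# verify_artifact is a hypothetical helper, shown only to illustrate what
# the CheckSumHash constants above are for.
import hashlib
from pathlib import Path

TRTLLM_GEN_BMM_SHA256 = (
    "db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"
)

def verify_artifact(path: Path, expected_sha256: str) -> None:
    digest = hashlib.sha256(path.read_bytes()).hexdigest()
    if digest != expected_sha256:
        raise RuntimeError(
            f"checksum mismatch for {path}: got {digest}, expected {expected_sha256}"
        )
```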

flashinfer/fused_moe/core.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -1132,6 +1132,8 @@ def forward(
                 kwargs["do_finalize"],
                 kwargs["enable_pdl"],
                 [-1, -1] if tactic == -1 else tactic,
+                self.activation_type,
+                kwargs.get("routing_replay_out"),
             )
         elif (
             self.dtype_act == DtypeTrtllmGen.E4m3
@@ -1339,6 +1341,8 @@ def trtllm_bf16_moe_op(
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
+    routing_replay_out: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     assert routing_logits is not None or topk_ids is not None, (
         "either routing_logits or topk_ids must be provided"
@@ -1387,7 +1391,7 @@ def trtllm_bf16_moe_op(
         intermediate_size=intermediate_size,
         weight_layout=weight_layout,
         use_shuffled_weight=use_shuffled_weight,
-        activation_type=ActivationType.Swiglu,  # Default for BF16
+        activation_type=activation_type,
     )
 
     inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states]
@@ -1411,6 +1415,8 @@ def trtllm_bf16_moe_op(
         weight_layout=weight_layout,
         do_finalize=do_finalize,
         enable_pdl=enable_pdl,
+        activation_type=activation_type,
+        routing_replay_out=routing_replay_out,
     )
 
     # Call the C++ function with the selected tactic
@@ -1437,6 +1443,8 @@ def trtllm_bf16_moe_op(
             do_finalize,
             enable_pdl,
             [-1, -1] if tactic == -1 else tactic,
+            activation_type,
+            routing_replay_out,
         )
     if do_finalize:
         return [output]
@@ -1469,6 +1477,8 @@ def _fake_trtllm_bf16_moe(
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
+    routing_replay_out: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -2265,6 +2275,8 @@ def trtllm_bf16_moe(
     do_finalize: bool = True,
     enable_pdl: bool = True,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
+    routing_replay_out: Optional[torch.Tensor] = None,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
     """BF16 MoE operation with autotuning support.
 
@@ -2302,6 +2314,9 @@ def trtllm_bf16_moe(
         do_finalize: Whether to finalize the output (default: True).
         enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90.
         tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192).
+        activation_type (int): Type of activation function (default: 3 - Swiglu)
+            - 3: Swiglu
+            - 6: Relu2
 
     Returns:
         when do_finalize=True, returns the final MoE output.
```
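For reference, the two activation codes correspond to the following elementwise math inside the gated FFN; this is the commonly assumed semantics, not something stated in this diff (the authoritative implementation is the fused CUDA kernel):

```python
# Assumed semantics of the two supported activation codes (reference only;
# the kernel computes these fused in CUDA, not via these functions).
import torch
import torch.nn.functional as F

def swiglu(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    return F.silu(gate) * up            # activation_type = 3 (Swiglu)

def relu2(x: torch.Tensor) -> torch.Tensor:
    return torch.square(F.relu(x))      # activation_type = 6 (Relu2)
```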
```diff
@@ -2329,6 +2344,8 @@ def trtllm_bf16_moe(
         do_finalize,
         enable_pdl,
         tune_max_num_tokens,
+        activation_type,
+        routing_replay_out,
     )
 
     if do_finalize:
@@ -2360,6 +2377,8 @@ def trtllm_bf16_routed_moe(
     do_finalize: bool = True,
     enable_pdl: bool = True,
     tune_max_num_tokens: int = 8192,
+    activation_type: int = ActivationType.Swiglu.value,
+    routing_replay_out: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     """BF16 MoE operation with autotuning support.
 
@@ -2396,6 +2415,9 @@ def trtllm_bf16_routed_moe(
         do_finalize: Whether to finalize the output (default: True).
         enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90.
         tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192).
+        activation_type (int): Type of activation function (default: 3 - Swiglu)
+            - 3: Swiglu
+            - 6: Relu2
 
     Returns:
         when do_finalize=True, returns the final MoE output.
@@ -2423,6 +2445,8 @@ def trtllm_bf16_routed_moe(
         do_finalize,
         enable_pdl,
         tune_max_num_tokens,
+        activation_type,
+        routing_replay_out,
     )
 
     if do_finalize:
```
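Putting the two features together, the apparent intent of routing_replay_out is to capture routing decisions from one call and replay them through the routed entry point. A hedged sketch of that flow; the int16-expert-id interpretation and the int32 topk_ids conversion are inferences from this diff, not documented behavior:

```python
# Sketch: reusing a captured routing decision. Assumptions: the kernel fills
# routing_replay_out with per-token expert ids (int16, [num_tokens, top_k]),
# and the routed entry point consumes int32 topk_ids; both are inferences
# from this diff rather than documented behavior.
import torch

num_tokens, top_k = 16, 2
routing_replay_out = torch.empty(
    (num_tokens, top_k), dtype=torch.int16, device="cuda"
)

# 1) First call: pass routing_replay_out=routing_replay_out to trtllm_bf16_moe
#    so the kernel records which experts each token was routed to.
# 2) Replay: hand the captured ids back as topk_ids.
topk_ids = routing_replay_out.to(torch.int32)
# trtllm_bf16_routed_moe(..., topk_ids=topk_ids, ...)
```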
