Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
46205b4
Using ActivationType instead of GatedActType, added compiled kernels,…
amitz-nv Jan 28, 2026
b8eac34
Add actType and eltwiseActType to 'no kernel found' message, move is_…
amitz-nv Jan 28, 2026
f771e0c
Update remaining GatedActType uses to ActivationType, remove GatedAct…
amitz-nv Jan 28, 2026
440c062
Use ActivationType in benchmarks, add missing activation_type argument
amitz-nv Jan 28, 2026
5725739
Minor fixes
amitz-nv Jan 28, 2026
c2c8531
Fix activation_type default value to Swiglu on trtllm_fp4_block_scale…
amitz-nv Jan 28, 2026
bb4e821
Minor improvement
amitz-nv Jan 28, 2026
c6ac4af
Support non-gated activation in NVFP4 block scale MoE
amitz-nv Jan 28, 2026
3bf918e
Rename useShuffledMatrixA to useShuffledMatrix (remove the 'A' suffix)
amitz-nv Jan 28, 2026
1193b02
Add FP4_NVFP4_NVFP4 parameterization to test_llama4_routing, update t…
amitz-nv Jan 28, 2026
b0e6d59
Increase supported topK and num experts in deepseek routing for nemotron
amitz-nv Jan 28, 2026
d4182ae
Commit more files for increase supported topK and num experts in deep…
amitz-nv Jan 28, 2026
8ee2193
Fix formatting
amitz-nv Jan 28, 2026
c899d16
Change TODO to comment
amitz-nv Jan 28, 2026
0f6f15c
Change default activation_type to Swiglu
amitz-nv Jan 28, 2026
cf6f76b
Restore intermediate size factor of 2 for gated activation in getWork…
amitz-nv Jan 28, 2026
e63e17d
Formatting fixes
amitz-nv Jan 28, 2026
8398e20
Treat SwigluBias as gated activation
amitz-nv Jan 28, 2026
ea67cef
Fix use of ActivationType enum in CLI
amitz-nv Jan 28, 2026
abefe22
Fix activation-type command line argument handling in benchmarks
amitz-nv Jan 29, 2026
da35764
Fix choices of activation-type command line argument handling in benc…
amitz-nv Jan 29, 2026
205989f
GEMM (non batched) still has mUseShuffledMatrixA member (with 'A' suf…
amitz-nv Jan 29, 2026
e467f1d
Update bench_trtllm_gen_fused_moe_autotuner.py to support more activa…
amitz-nv Jan 29, 2026
80d1b53
Revert activation_Type check in bench_trtllm_gen_fused_moe_autotuner.…
amitz-nv Jan 29, 2026
21e0e08
Include activation type in results in benchmarks/routings/moe.py
amitz-nv Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from flashinfer import (
RoutingMethodType,
GatedActType,
ActivationType,
fp4_quantize,
mxfp8_quantize,
)
Expand Down Expand Up @@ -39,6 +39,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
top_k: int,
warmups: int,
iterations: int,
activation_type: ActivationType,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
Expand Down Expand Up @@ -97,6 +98,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
)

if is_block_scale:
if activation_type != ActivationType.Swiglu:
raise ValueError(
"Only Swiglu activation is supported for FP8 block scale MoE."
)
fn = lambda: trtllm_fp8_block_scale_moe(
routing_logits,
routing_bias,
Expand Down Expand Up @@ -144,6 +149,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
RoutingMethodType.TopK.value,
enable_pdl,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
activation_type.value,
)

def bench(do_autotune):
Expand Down Expand Up @@ -288,7 +294,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
RoutingMethodType.Renormalize.value,
True,
enable_pdl,
GatedActType.SwiGlu.value, # gated_act_type
ActivationType.Swiglu.value, # act_type
None,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
)
Expand Down Expand Up @@ -348,6 +354,14 @@ def bench(do_autotune):
parser.add_argument(
"--iterations", type=int, default=100, help="Number of benchmark iterations"
)
parser.add_argument(
"--activation-type",
type=ActivationType,
choices=list(ActivationType),
required=False,
default=ActivationType.Swiglu,
help=f"Type of gated activation function: {list(ActivationType)}",
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
args = parser.parse_args()
if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]:
bench_trtllm_gen_fused_moe_autotuner_fp8(
Expand All @@ -360,6 +374,7 @@ def bench(do_autotune):
args.top_k,
args.warmups,
args.iterations,
args.activation_type,
)
else:
bench_trtllm_gen_fused_moe_autotuner_fp4(
Expand Down
25 changes: 11 additions & 14 deletions benchmarks/routines/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

import flashinfer
from flashinfer import ActivationType
from flashinfer.autotuner import autotune
from flashinfer.fused_moe import (
WeightLayout,
Expand Down Expand Up @@ -175,12 +176,12 @@ def parse_moe_args(line, parser):
help="Data type of the weights (before quantization).",
)
parser.add_argument(
"--gated_act",
type=str,
"--activation-type",
type=ActivationType,
choices=list(ActivationType),
required=False,
default="swiglu",
choices=["swiglu", "geglu"],
help="Type of gated activation function: swiglu | geglu.",
default=ActivationType.Swiglu,
help=f"Type of gated activation function: {list(ActivationType)}",
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
parser.add_argument(
"--autotune",
Expand Down Expand Up @@ -247,13 +248,6 @@ def parse_moe_args(line, parser):
}
args.routing_method_type = routing_method_name_to_type[args.routing_method]

# Normalize gated act type (map string to internal int expected by kernels)
gated_act_name_to_type = {
"swiglu": 0,
"geglu": 1,
}
args.gated_act_type = gated_act_name_to_type[args.gated_act]

if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
Expand Down Expand Up @@ -630,7 +624,7 @@ def testTrtllmFp4BlockScaleMoe(args):
use_shuffled_weight = args.use_shuffled_weight
weight_layout = args.weight_layout
is_cuda_graph_compatible = not args.no_cuda_graph
gated_act_type = args.gated_act_type
activation_type = args.activation_type
Comment thread
coderabbitai[bot] marked this conversation as resolved.
res = []

backends = ["trtllm"]
Expand Down Expand Up @@ -795,7 +789,7 @@ def run_fp4_moe(
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type,
gated_act_type=gated_act_type,
activation_type=activation_type.value,
do_finalize=True,
)

Expand Down Expand Up @@ -1671,6 +1665,7 @@ def run_fp8_per_tensor_moe(
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
activation_type,
):
# Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t]
# So we convert None to 0 to indicate "no groups" mode
Expand All @@ -1693,6 +1688,7 @@ def run_fp8_per_tensor_moe(
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
activation_type=activation_type.value,
)

# Benchmark timing
Expand All @@ -1713,6 +1709,7 @@ def run_fp8_per_tensor_moe(
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
args.activation_type,
),
)

Expand Down
10 changes: 8 additions & 2 deletions csrc/trtllm_batched_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,16 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
options.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
(!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
tileSize == mOptions.tileSize &&
options.mUseShuffledMatrix == mOptions.useShuffledMatrixA &&
tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
options.mLayoutA == mOptions.weightLayout) {
if (options.mFusedAct) {
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
continue;
}
}
if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) {
continue;
}

if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
mPassingConfigIndices.push_back(i);
Expand All @@ -122,6 +124,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
<< ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB)
<< ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC)
<< ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8
<< ", mActType: " << (int64_t)mOptions.actType
<< ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType
<< ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput
<< ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct
<< ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize;
Expand Down Expand Up @@ -219,6 +223,8 @@ void TrtllmGenBatchedGemmRunner::run(
gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB;
gemmData.mInputBuffers.mPtrScaleC = scaleC;
gemmData.mInputBuffers.mPtrScaleGate = scaleGateC;
// For simplicity, set scaleAct to the same buffer as scaleGateC
gemmData.mInputBuffers.mPtrScaleAct = scaleGateC;
Copy link
Copy Markdown
Contributor Author

@amitz-nv amitz-nv Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decide whether it's OK or fix in the future?

gemmData.mInputBuffers.mPtrPerTokenSfA =
mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA;
gemmData.mInputBuffers.mPtrPerTokenSfB =
Expand Down
Loading
Loading