4 changes: 2 additions & 2 deletions csrc/trtllm_batched_gemm_runner.cu
@@ -68,8 +68,8 @@ std::vector<int64_t> prioritizePredefinedConfigs(
if (n /* out_dim */ == 0 && k /* in_dim */ == 0) {
auto pred = [](BatchedGemmConfig const& config) {
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256 &&
options.mTileScheduler == TileScheduler::Persistent;
return options.mNumStagesA == 4 && options.mNumStagesB == 4 && options.mNumStagesMma == 2 &&
options.mTileK == 256 && options.mTileScheduler == TileScheduler::Persistent;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
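Note on the hunk above: the change only tightens the predicate handed to bubbleUpConfig, replacing the single mNumStages field with the split mNumStagesA / mNumStagesB fields of the updated config struct. For readers unfamiliar with the helper, a minimal Python sketch of the bubble-up idea follows; the stable front-partition semantics and the names bubble_up / pred are assumptions used for illustration, not the actual C++ implementation.

def bubble_up(sorted_indices, pred):
    """Move indices whose config satisfies `pred` to the front, keeping the
    relative order inside each group (a stable partition)."""
    matched = [i for i in sorted_indices if pred(i)]
    rest = [i for i in sorted_indices if not pred(i)]
    return matched + rest

# Illustrative only: index -> tileK of a candidate config.
tile_k = {0: 128, 1: 256, 2: 64, 3: 256}
print(bubble_up([0, 1, 2, 3], lambda i: tile_k[i] == 256))  # [1, 3, 0, 2]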
30 changes: 13 additions & 17 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -531,10 +531,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher {

void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
int64_t weight_layout, bool norm_topk_prob = true) {
constexpr ActivationType activation_type =
ActivationType::Swiglu; // not exposed in api for now

int64_t weight_layout, ActivationType activation_type, bool norm_topk_prob = true) {
// Do base class init and perform common checks
FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type,
use_shuffled_weight, weight_layout, activation_type,
@@ -1728,17 +1725,15 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
}
};

Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
Optional<TensorView> const& routing_bias,
TensorView const& expert_indices, TensorView const& expert_weights,
TensorView const& hidden_states, TensorView const& gemm1_weights,
TensorView const& gemm2_weights, TensorView output,
int64_t num_experts, int64_t top_k, Optional<int64_t> n_group,
Optional<int64_t> topk_group, int64_t intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts,
Optional<double> routed_scaling_factor, int64_t routing_method_type,
bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
bool enable_pdl, Array<int64_t> moe_tactic, bool norm_topk_prob) {
Array<Tensor> trtllm_bf16_moe(
Optional<TensorView> const& routing_logits, Optional<TensorView> const& routing_bias,
TensorView const& expert_indices, TensorView const& expert_weights,
TensorView const& hidden_states, TensorView const& gemm1_weights,
TensorView const& gemm2_weights, TensorView output, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
bool enable_pdl, Array<int64_t> moe_tactic, int64_t activation_type, bool norm_topk_prob) {
// Just some basic type validation first and leave more checks to the launcher
if (routing_logits.has_value()) {
TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
@@ -1754,6 +1749,7 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,

auto const num_tokens = hidden_states.size(0);
auto const hidden_size = hidden_states.size(1);
auto const activation = validateAndCastActivationType(activation_type);

// Calculate supported tile sizes
std::vector<int32_t> mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(),
@@ -1788,7 +1784,7 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
expert_weights, hidden_states, gemm1_weights,
gemm2_weights);
launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
weight_layout, norm_topk_prob);
weight_layout, activation, norm_topk_prob);

launchers_map[curr_tile_N] = std::move(launcher);
}
@@ -1817,7 +1813,7 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
bool enable_pdl, Array<int64_t> config_index, int64_t activation_type, bool norm_topk_prob) {
// Basic type validation
auto dtype = hidden_states.dtype();
auto activation = static_cast<ActivationType>(activation_type);
auto activation = validateAndCastActivationType(activation_type);

TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16)
<< "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16.";
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -137,7 +137,7 @@ class ArtifactPath:

TRTLLM_GEN_FMHA: str = "55bba55929d4093682e32d817bd11ffb0441c749/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"31e75d429ff3f710de1251afdd148185f53da44d/batched_gemm-4daf11e-c111d7c/"
"39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/"
)
TRTLLM_GEN_GEMM: str = (
"31e75d429ff3f710de1251afdd148185f53da44d/gemm-4daf11e-1fddea2/"
@@ -158,7 +158,7 @@ class CheckSumHash:
"f2c0aad1e74391c4267a2f9a20ec819358b59e04588385cffb452ed341500b99"
)
TRTLLM_GEN_BMM: str = (
"2c2361bdf1deb0a2ea0f130f2d57dd62864f4400a706ac19a625d492b03460cb"
"db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"
)
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
11 changes: 7 additions & 4 deletions flashinfer/fused_moe/__init__.py
@@ -15,10 +15,6 @@
"""

from .core import (
ActivationType,
Fp8QuantizationType,
RoutingMethodType,
WeightLayout,
convert_to_block_layout,
cutlass_fused_moe,
gen_cutlass_fused_moe_sm120_module,
@@ -37,6 +33,13 @@
trtllm_mxint4_block_scale_moe,
)

from ..tllm_enums import (
ActivationType,
Fp8QuantizationType,
WeightLayout,
RoutingMethodType,
)

from .fused_routing_dsv3 import ( # noqa: F401
fused_topk_deepseek as fused_topk_deepseek,
)
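The four enums removed from the .core import above are now pulled in from ..tllm_enums, so flashinfer.fused_moe keeps exposing the same names. A quick sketch of the two import paths this implies (assuming the package layout shown in the diff; the Swiglu value of 3 is taken from the docstrings later in this PR):

# Both paths should resolve to the same enum object after this change:
# the second is the new home, the first keeps working via the re-import above.
from flashinfer.fused_moe import ActivationType as act_from_fused_moe
from flashinfer.tllm_enums import ActivationType as act_from_enums

assert act_from_fused_moe is act_from_enums
print(act_from_enums.Swiglu.value)  # 3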
40 changes: 34 additions & 6 deletions flashinfer/fused_moe/core.py
@@ -54,7 +54,14 @@
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
)
from ..tllm_enums import *
from ..tllm_enums import (
ActivationType,
WeightLayout,
DtypeTrtllmGen,
Fp8QuantizationType,
deduce_trtllm_gen_tensor_dtype,
trtllm_gen_dtype_has_scale,
)


@functools.cache
@@ -1107,6 +1114,7 @@ def forward(
kwargs["do_finalize"],
kwargs["enable_pdl"],
[-1, -1] if tactic == -1 else tactic,
self.activation_type,
kwargs.get("norm_topk_prob", True),
)
elif (
@@ -1290,6 +1298,7 @@ def trtllm_bf16_moe_op(
do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
activation_type: int = ActivationType.Swiglu.value,
norm_topk_prob: bool = True,
) -> List[torch.Tensor]:
assert routing_logits is not None or topk_ids is not None, (
@@ -1338,7 +1347,7 @@
intermediate_size=intermediate_size,
weight_layout=weight_layout,
use_shuffled_weight=use_shuffled_weight,
activation_type=ActivationType.Swiglu, # Default for BF16
activation_type=activation_type,
)

moe_inputs = MoEInputs(
@@ -1375,6 +1384,7 @@ def trtllm_bf16_moe_op(
weight_layout=weight_layout,
do_finalize=do_finalize,
enable_pdl=enable_pdl,
activation_type=activation_type,
)

# Call the C++ function with the selected tactic
@@ -1401,6 +1411,7 @@ def trtllm_bf16_moe_op(
do_finalize,
enable_pdl,
[-1, -1] if tactic == -1 else tactic,
activation_type,
norm_topk_prob,
)
if do_finalize:
@@ -1435,6 +1446,7 @@ def _fake_trtllm_bf16_moe(
do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
activation_type: int = ActivationType.Swiglu.value,
norm_topk_prob: bool = True,
) -> List[torch.Tensor]:
seq_len = hidden_states.shape[0]
@@ -2272,6 +2284,7 @@ def trtllm_bf16_moe(
do_finalize: bool = True,
enable_pdl: bool = True,
tune_max_num_tokens: int = 8192,
activation_type: int = ActivationType.Swiglu.value,
norm_topk_prob: bool = True,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""BF16 MoE operation with autotuning support.
@@ -2286,7 +2299,9 @@
Must be bfloat16 if provided.
hidden_states: [seq_len, hidden_size] tensor of input hidden states.
Must be bfloat16.
gemm1_weights: [num_experts, 2*intermediate_size // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
gemm1_weights: [num_experts, M // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
M is 2*intermediate_size for gated activations and
intermediate_size for non-gated activations.
gemm2_weights: [num_experts, hidden_size//128, intermediate_size, 128] tensor of second layer weights. must be bfloat16.
num_experts: Total number of experts.
top_k: Number of experts to route to per token.
Expand All @@ -2310,6 +2325,9 @@ def trtllm_bf16_moe(
do_finalize: Whether to finalize the output (default: True).
enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90.
tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192).
activation_type (int): Type of activation function (default: 3 - Swiglu)
- 3: Swiglu
- 6: Relu2 (non-gated)

Returns:
when do_finalize=True, returns the final MoE output.
@@ -2337,6 +2355,7 @@ def trtllm_bf16_moe(
do_finalize,
enable_pdl,
tune_max_num_tokens,
activation_type,
norm_topk_prob,
)

@@ -2369,6 +2388,7 @@ def trtllm_bf16_routed_moe(
do_finalize: bool = True,
enable_pdl: bool = True,
tune_max_num_tokens: int = 8192,
activation_type: int = ActivationType.Swiglu.value,
) -> List[torch.Tensor]:
"""BF16 MoE operation with autotuning support.

@@ -2381,7 +2401,9 @@ def trtllm_bf16_routed_moe(
Can be created as: (topk_ids.int32 << 16) | expert_weights.bfloat16.view(int16)
hidden_states: [seq_len, hidden_size] tensor of input hidden states.
Must be bfloat16.
gemm1_weights: [num_experts, 2*intermediate_size // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
gemm1_weights: [num_experts, M // 128, hidden_size // 128, 128] tensor of first layer weights. must be bfloat16.
M is 2*intermediate_size for gated activations and
intermediate_size for non-gated activations.
gemm2_weights: [num_experts, hidden_size//128, intermediate_size, 128] tensor of second layer weights. must be bfloat16.
num_experts: Total number of experts.
top_k: Number of experts to route to per token.
Expand All @@ -2405,6 +2427,9 @@ def trtllm_bf16_routed_moe(
do_finalize: Whether to finalize the output (default: True).
enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90.
tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192).
activation_type (int): Type of activation function (default: 3 - Swiglu)
- 3: Swiglu
- 6: Relu2 (non-gated)

Returns:
when do_finalize=True, returns the final MoE output.
@@ -2432,6 +2457,7 @@ def trtllm_bf16_routed_moe(
do_finalize,
enable_pdl,
tune_max_num_tokens,
activation_type,
True, # norm_topk_prob: not used for pre-computed routing
)

@@ -2476,7 +2502,9 @@ def trtllm_fp8_per_tensor_scale_moe(
routing_logits: [seq_len, num_experts] tensor of routing logits
routing_bias: [num_experts] tensor of routing bias
hidden_states: [seq_len, hidden_size] tensor of input hidden states
gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights
gemm1_weights: [num_experts, M, hidden_size] tensor of first layer weights
M is 2*intermediate_size for gated activations and
intermediate_size for non-gated activations.
output1_scales_scalar: [local_num_experts] tensor of first layer output scales
output1_scales_gate_scalar: [local_num_experts] tensor of first layer gate scales
gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights
@@ -2498,7 +2526,7 @@
- 0: Gelu
- 3: Swiglu
- 4: Geglu
- 6: Relu2
- 6: Relu2 (non-gated)
- 7: Identity

Returns:
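The docstrings in this file repeat one rule several times: the leading dimension M of gemm1_weights is 2*intermediate_size for gated activations (Swiglu) and intermediate_size for non-gated ones (Relu2). A short sketch of that shape bookkeeping, using only the shapes and enum values stated above (the helper name, and treating Swiglu as the gated case and Relu2 as the non-gated case, are illustrative assumptions rather than the library's API):

from flashinfer.tllm_enums import ActivationType

def gemm1_weight_shape(num_experts, hidden_size, intermediate_size, activation_type):
    # Gated activations fuse the gate and up projections, hence the factor of 2.
    gated = activation_type == ActivationType.Swiglu
    m = 2 * intermediate_size if gated else intermediate_size
    # Blocked layout used by trtllm_bf16_moe: [num_experts, M // 128, hidden_size // 128, 128]
    return (num_experts, m // 128, hidden_size // 128, 128)

print(gemm1_weight_shape(8, 1024, 2048, ActivationType.Swiglu))  # (8, 32, 8, 128)
print(gemm1_weight_shape(8, 1024, 2048, ActivationType.Relu2))   # (8, 16, 8, 128)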
11 changes: 7 additions & 4 deletions tests/moe/test_trtllm_gen_fused_moe.py
@@ -1445,6 +1445,7 @@ def prepare_static_weights_for_kernel(
self._cache_permute_indices,
args.gemm1_weights[i].view(torch.uint8),
epilogue_tile_m,
is_gated_act_gemm=is_gated_activation(args.activation_type),
)
tmp_weights1 = (
args.gemm1_weights[i]
@@ -1508,6 +1509,7 @@ def call_moe(
routed_scaling = kwargs["routed_scaling"]
routing_method_type = kwargs["routing_method_type"]
enable_autotune = kwargs.get("enable_autotune", True)
activation_type = kwargs["activation_type"]
norm_topk_prob = kwargs.get("norm_topk_prob", True)

# Use autotuner for optimal kernel selection
@@ -1530,6 +1532,7 @@
weight_layout=static_data["weight_layout"],
routing_method_type=routing_method_type,
tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
activation_type=activation_type,
norm_topk_prob=norm_topk_prob,
)
return output.to(torch.float)
@@ -3131,7 +3134,7 @@ def test_renormalize_routing(
# Test: DeepSeekV3 routing
@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2944, 2048, 1024, 768, 512, 384])
@pytest.mark.parametrize("intermediate_size", [2688, 2048, 1024, 768, 512, 384])
@pytest.mark.parametrize(
"moe_impl",
[
@@ -3164,12 +3167,12 @@ def test_renormalize_routing(
"routed_scaling": 2.5,
"has_routing_bias": True,
"routing_method_type": RoutingMethodType.DeepSeekV3,
"compatible_moe_impls": [FP8PerTensorMoe, FP4Moe],
"compatible_intermediate_size": [2944],
"compatible_moe_impls": [BF16Moe, FP8PerTensorMoe, FP4Moe],
"compatible_intermediate_size": [2688],
"compatible_activation_types": [ActivationType.Relu2],
"enable_autotune": True,
},
id="nemotron_3_dummy",
id="nemotron_3_super",
),
pytest.param(
{
2 changes: 2 additions & 0 deletions tests/moe/utils.py
@@ -39,6 +39,7 @@ class QuantMode(IntEnum):
QuantMode.FP4_NVFP4_NVFP4,
QuantMode.FP8_BLOCK_SCALE_MXFP8,
QuantMode.FP8_PER_TENSOR,
QuantMode.BF16,
]


@@ -161,6 +162,7 @@ def skip_checks(
)

if logits_dtype == torch.float32 and moe_impl.quant_mode not in [
QuantMode.FP4_NVFP4_NVFP4,
QuantMode.FP8_PER_TENSOR,
QuantMode.FP8_BLOCK_SCALE_DEEPSEEK,
QuantMode.FP8_BLOCK_SCALE_MXFP8,