Skip to content

Commit 34ff609

Browse files
zianglihameynaik-hub
authored and committed
feat: Implement cutlass_fused_moe mxfp8 (flashinfer-ai#2581)
<!-- .github/pull_request_template.md --> ## 📌 Description @HumansAnd <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> flashinfer-ai#2505 implements mxfp8 for trtllm backend. However, in SGLang, `--moe-runner-backend flashinfer_trtllm` bypasses SGLang topk implementation and does not work with expert routing replay in MoE RL. We want to implement `mxfp8 x mxfp8` for `cutlass_fused_moe` which works with MoE RL training. This PR mainly reuses existing code path for `WMxfp4AMxfp8Quant`: https://github.com/flashinfer-ai/flashinfer/blob/952b6ab2838d676b4257fcc23bb00f67fdd38efc/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu#L1191 ## 🔍 Related Issues <!-- Link any related issues here --> miles MXFP8/NVFP4 RL roadmap: radixark/miles#615 SGLang FlashInfer MXFP8 integration: sgl-project/sglang#18945 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Toggleable MXFPX/MXFP8 activation-scaling across MOE inference, updating workspace sizing, kernel selection, block-scaling and dispatch to enable MXFP8-aware execution and validation. 
* Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling flag. * **Tests** * Added unit tests and helpers for MXFP8 quantization, packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference validation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
1 parent b16b725 commit 34ff609

11 files changed

Lines changed: 713 additions & 332 deletions

File tree

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@ template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
3232
#ifdef ENABLE_FP8
3333
// template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3>;
3434
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
35+
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half, __nv_fp8_e4m3, half, true>;
3536
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
3637
#ifdef ENABLE_BF16
3738
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
39+
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_fp8_e4m3,
40+
__nv_bfloat16, true>;
3841
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
3942
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
4043
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>;

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 136 additions & 92 deletions
Large diffs are not rendered by default.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu

Lines changed: 92 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class DtypeUtils {
8282

8383
class FusedMoeRunner : public tvm::ffi::ModuleObj {
8484
public:
85-
template <typename TypeAct, typename TypeWeight, bool NeedQuant = false>
85+
template <typename TypeAct, typename TypeWeight, bool NeedQuant = false, bool IsMXFPX = false>
8686
std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> switch_output_type(DLDataType output_type) {
8787
switch (encode_dlpack_dtype(output_type)) {
8888
case int64_code: // INT64 == FP4
@@ -94,19 +94,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
9494
// return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type>>();
9595
case float16_code:
9696
if constexpr (NeedQuant) {
97-
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half>>();
97+
return std::make_unique<
98+
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half, half, IsMXFPX>>();
9899
} else {
99100
return std::make_unique<
100-
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct>>();
101+
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct, half, IsMXFPX>>();
101102
}
102103
#ifdef ENABLE_BF16
103104
case bfloat16_code:
104105
if constexpr (NeedQuant) {
105-
return std::make_unique<
106-
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16>>();
106+
return std::make_unique<kernels::CutlassMoeFCRunner<
107+
TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, IsMXFPX>>();
107108
} else {
108-
return std::make_unique<
109-
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, TypeAct>>();
109+
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16,
110+
TypeAct, __nv_bfloat16, IsMXFPX>>();
110111
}
111112
#endif
112113
default:
@@ -145,7 +146,9 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
145146
#endif
146147

147148
#ifdef ENABLE_FP8
148-
if (isFp8Quant()) {
149+
if (isWMxfp8AMxfp8Quant()) {
150+
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3, false, true>(mOutputDtype);
151+
} else if (isFp8Quant()) {
149152
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype);
150153
}
151154
#endif
@@ -397,8 +400,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
397400
static_cast<int>(experts_per_token),
398401
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
399402
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
400-
use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
401-
enable_pdl, stream);
403+
use_lora, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode,
404+
min_latency_params, enable_pdl, stream);
402405
#else
403406
mKernelRunner->runMoe(
404407
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
@@ -414,7 +417,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
414417
static_cast<int>(experts_per_token),
415418
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
416419
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
417-
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
420+
mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, min_latency_params,
421+
enable_pdl, stream);
418422
#endif
419423
}
420424

@@ -490,8 +494,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
490494
<< "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";
491495
}
492496

493-
TVM_FFI_ICHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant())
497+
TVM_FFI_ICHECK(!input_sf.has_value() || isMxfp8ActScalingQuant() || isNvfp4Quant())
494498
<< "Block-scaling factors provided for non block-scaling quantization";
499+
TVM_FFI_ICHECK(!isMxfp8ActScalingQuant() || input_sf.has_value())
500+
<< "input_sf must be provided when use_mxfp8_act_scaling=True";
495501

496502
int experts_per_token = token_selected_experts.size(1);
497503
int64_t num_rows = input.size(0);
@@ -581,8 +587,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
581587
static_cast<int>(experts_per_token),
582588
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
583589
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
584-
use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
585-
enable_pdl, stream);
590+
use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling,
591+
min_latency_mode, min_latency_params, enable_pdl, stream);
586592
#else
587593
mKernelRunner->runMoe(
588594
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
@@ -598,8 +604,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
598604
static_cast<int>(experts_per_token),
599605
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
600606
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml,
601-
lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl,
602-
stream);
607+
lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode,
608+
min_latency_params, enable_pdl, stream);
603609
#endif
604610
}
605611

@@ -838,8 +844,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
838844
bool min_latency_mode) {
839845
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(
840846
num_rows, hidden_size, inter_size, num_experts, experts_per_token, activation_type,
841-
parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, min_latency_mode,
842-
mUseW4GroupScaling);
847+
parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling,
848+
min_latency_mode, mUseW4GroupScaling);
843849
size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int);
844850

845851
std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
@@ -862,7 +868,66 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
862868
int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size,
863869
Optional<Array<Tensor>> quant_scales,
864870
ActivationType base_activation_type = ActivationType::Swiglu) const {
865-
if (isFp8Quant()) {
871+
if (isWMxfp8AMxfp8Quant()) {
872+
#ifdef USING_OSS_CUTLASS_MOE_GEMM
873+
TVM_FFI_ICHECK(quant_scales.has_value())
874+
<< "Expecting quant scales for MXFP8xMXFP8 quantization";
875+
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4)
876+
<< "Expecting 4 quant scales for MXFP8xMXFP8 quantization";
877+
878+
TensorView fc1_weight_block = quant_scales.value()[0];
879+
TensorView fc1_global = quant_scales.value()[1];
880+
TensorView fc2_weight_block = quant_scales.value()[2];
881+
TensorView fc2_global = quant_scales.value()[3];
882+
883+
// The input for scale fc1_weight_block / fc2_weight_block is packed into INT32
884+
constexpr int FP8_PER_INT32 = 4;
885+
CHECK_INPUT_TYPE(fc1_weight_block, dl_int32);
886+
CHECK_INPUT_TYPE(fc1_global, dl_float32);
887+
CHECK_INPUT_TYPE(fc2_weight_block, dl_int32);
888+
CHECK_INPUT_TYPE(fc2_global, dl_float32);
889+
CHECK_DIM(3, fc1_weight_block);
890+
CHECK_DIM(1, fc1_global);
891+
CHECK_DIM(3, fc2_weight_block);
892+
CHECK_DIM(1, fc2_global);
893+
TVM_FFI_ICHECK(
894+
fc1_weight_block.size(0) == num_experts_on_rank &&
895+
fc1_weight_block.size(1) ==
896+
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
897+
inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
898+
2 &&
899+
fc1_weight_block.size(2) * FP8_PER_INT32 *
900+
TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize ==
901+
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
902+
hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX))
903+
<< "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
904+
"// block_scale_vector_size)";
905+
TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank)
906+
<< "fc1 global size must be (num_experts_on_rank,)";
907+
TVM_FFI_ICHECK(
908+
fc2_weight_block.size(0) == num_experts_on_rank &&
909+
fc2_weight_block.size(1) ==
910+
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
911+
hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) &&
912+
fc2_weight_block.size(2) * FP8_PER_INT32 *
913+
TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize ==
914+
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
915+
inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX))
916+
<< "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // "
917+
"block_scale_vector_size)";
918+
TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank)
919+
<< "fc2 global size must be (num_experts_on_rank,)";
920+
921+
return kernels::QuantParams::MXFP8MXFP8(
922+
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc1_weight_block.data_ptr()),
923+
static_cast<float const*>(fc1_global.data_ptr()),
924+
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc2_weight_block.data_ptr()),
925+
static_cast<float const*>(fc2_global.data_ptr()));
926+
#else
927+
TVM_FFI_ICHECK(false)
928+
<< "MXFP8 x MXFP8 quantization is not supported in OSS Cutlass Moe Gemm";
929+
#endif
930+
} else if (isFp8Quant()) {
866931
TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization";
867932
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4)
868933
<< "Expecting 4 quant scales for fp8 quantization";
@@ -1168,9 +1233,16 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
11681233

11691234
bool isFp8Quant() const {
11701235
return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn &&
1171-
mWeightDtype == dl_float8_e4m3fn;
1236+
mWeightDtype == dl_float8_e4m3fn && !mUseMxfp8ActScaling;
11721237
}
11731238

1239+
bool isWMxfp8AMxfp8Quant() const {
1240+
return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn &&
1241+
mWeightDtype == dl_float8_e4m3fn && mUseMxfp8ActScaling;
1242+
}
1243+
1244+
bool isMxfp8ActScalingQuant() const { return isWMxfp8AMxfp8Quant() || isWMxfp4AMxfp8Quant(); }
1245+
11741246
bool isNvfp4Quant() const {
11751247
return mWeightDtype == dl_int64 &&
11761248
mActivationDtype != dl_float8_e4m3fn; // FP8 activation does not use FP4

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,11 @@ constexpr bool isGatedActivation(ActivationType activation_type) {
239239
activation_type == ActivationType::SwigluBias;
240240
}
241241

242-
template <typename T, /*The type used for activations/scales/compute*/
243-
typename WeightType, /* The type for the MoE weights */
244-
typename OutputType, /* The output type for the GEMM */
245-
typename ScaleBiasType = OutputType /* The type for the scales/bias */
246-
>
242+
template <typename T, /*The type used for activations/scales/compute*/
243+
typename WeightType, /* The type for the MoE weights */
244+
typename OutputType, /* The output type for the GEMM */
245+
typename ScaleBiasType = OutputType, /* The type for the scales/bias */
246+
bool IsMXFPX = false>
247247
class MoeGemmRunner {
248248
public:
249249
MoeGemmRunner();
@@ -273,6 +273,8 @@ class MoeGemmRunner {
273273
static constexpr bool use_fp8 = false;
274274
static constexpr bool use_w4afp8 = false;
275275
#endif
276+
static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX;
277+
276278
static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
277279

278280
#if defined(ENABLE_FP4)

0 commit comments

Comments (0)