
Commit 04f4c0c

Authored by: askliar (Andrii Skliar), gemini-code-assist[bot], samuellees, aleozlx
fix: MXFP4/MXFP8 failures in SM120 FAST_BUILD and expand all_tiles[] (#2994)
**Problem**

MXFP4 and MXFP8 GEMM operations were failing on SM120 because:

- The FAST_BUILD path returned a single hardcoded CtaShape128x128x64B tile regardless of GROUPED_GEMM, and that tile is not valid for all MXFP4/MXFP8 configurations.
- The full-build all_tiles[] table was missing tiles needed by those dtypes (128x128x128B, 128x128x64B, 256x128x64B), leaving the autotuner with no viable candidate in some cases.

**Fix**

- FAST_BUILD: differentiate grouped vs. non-grouped paths with tiles known to work for MXFP4/MXFP8:
  - Grouped: 128x128x128B + 128x128x64B
  - Non-grouped: 128x128x256B + 128x128x64B
- Full-build all_tiles[]: add the three missing tiles so the autotuner has a complete candidate set for MXFP4/MXFP8 workloads.

## Summary by CodeRabbit

* **Performance & Optimizations**
  * More predictable kernel candidate selection and expanded tile/configuration options for SM120-class GPUs to improve tuning and performance.
  * Broadened handling of grouped computation patterns to enable additional configuration choices.
* **Build/Compatibility**
  * Refined CUDA 12.9+ architecture suffixing for more accurate build targeting.
* **Chores**
  * Added type annotations and minor signature clarifications (no runtime behavior changes).
* **Bug Fixes**
  * MoE fusion path now forwards additional tensors/parameters to improve fused operation correctness.

Co-authored-by: samuellees <lsam@nvidia.com>

---------

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Co-authored-by: Andrii Skliar <askliar@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Sam (Kesen Li) <lsam@nvidia.com>
Co-authored-by: Alex Yang <aleyang@nvidia.com>
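The "no viable candidate" failure mode is easiest to see in a small sketch. The Python below is purely illustrative (the real logic is C++ inside the heuristic and `calcMaxWorkspaceSize`; `build_and_measure` and the exception type are hypothetical stand-ins):

```python
# Illustrative sketch of autotuner candidate selection. Tiles that cannot be
# built for the requested dtype raise and are skipped; if the candidate table
# is missing every tile a dtype needs, nothing survives and the GEMM fails.
def pick_best_tile(candidate_tiles, build_and_measure):
    viable = []
    for tile in candidate_tiles:
        try:
            viable.append((build_and_measure(tile), tile))  # (latency, tile)
        except RuntimeError:
            continue  # invalid tile for this dtype/path: skipped gracefully
    if not viable:
        raise RuntimeError("no viable tile candidate for this dtype")
    return min(viable)[1]  # keep the fastest viable tile
```

In these terms, the fix expands `candidate_tiles` so that `viable` is never empty for MXFP4/MXFP8 workloads.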
1 parent 19055a6 commit 04f4c0c

2 files changed: 34 additions & 18 deletions


csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 29 additions & 16 deletions
```diff
@@ -587,28 +587,41 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm110(
 
 std::vector<CutlassGemmConfig> get_candidate_configs_sm120(
     CutlassGemmConfig::CandidateConfigTypeParam const config) {
+#ifdef FAST_BUILD
+  if (config & CutlassGemmConfig::GROUPED_GEMM) {
+    return {
+        CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x128B, MainloopScheduleType::AUTO,
+                          EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1},
+        CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x64B, MainloopScheduleType::AUTO,
+                          EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}};
+  } else {
+    return {
+        CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x256B, MainloopScheduleType::AUTO,
+                          EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1},
+        CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x64B, MainloopScheduleType::AUTO,
+                          EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}};
+  }
+#else
   if ((config & CutlassGemmConfig::FP4_ONLY) == 0) {
     if (config & CutlassGemmConfig::GROUPED_GEMM) {
       TLLM_THROW("Not Implemented: SM120 group GEMM only supports nvfp4.");
     }
     TLLM_THROW("Not Implemented: SM120 GEMM only supports nvfp4.");
   }
-  // Only tiles that satisfy ALL of:
-  //   1. Present in the dispatch table (SHAPE_CASE in moe_gemm_template_dispatch_tma_ws.h)
-  //   2. Pass are_tile_shapes_supported_sm120() constexpr check
-  //   3. Have compiled kernel templates (generate_sm120_grouped_gemm_operations)
-  //
-  // 128x128x128B is the only tile meeting all three criteria. Its nominal SMEM
-  // (2 stages × (128+128) × 256 bytes = 128 KB) exceeds SM120's 100 KB budget,
-  // but CUTLASS StageCountAutoCarveout reduces the stage count to 1, bringing
-  // actual SMEM to ~64 KB. can_implement() accepts it at runtime.
-  //
-  // K=64 tiles (128x128x64, 128x256x64, 256x128x64) are in the dispatch table
-  // but cannot be compiled for FP4 on SM120 (TMA layout static_assert failure),
-  // so they are intentionally excluded here.
-  return {CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x128B,
-                            MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
-                            ClusterShape::ClusterShape_1x1x1}};
+  // All candidate tiles for SM120 FP4. Invalid tiles for a given path are skipped
+  // gracefully by the try-catch in calcMaxWorkspaceSize.
+  static constexpr CutlassTileConfigSM120 all_tiles[] = {
+      CutlassTileConfigSM120::CtaShape128x128x128B, CutlassTileConfigSM120::CtaShape128x128x64B,
+      CutlassTileConfigSM120::CtaShape256x128x64B,  CutlassTileConfigSM120::CtaShape128x256x64B,
+      CutlassTileConfigSM120::CtaShape128x128x256B, CutlassTileConfigSM120::CtaShape256x128x128B,
+  };
+  std::vector<CutlassGemmConfig> result;
+  for (auto tile : all_tiles) {
+    result.push_back(CutlassGemmConfig{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                                       ClusterShape::ClusterShape_1x1x1});
+  }
+  return result;
+#endif
 }
 
 std::vector<CutlassGemmConfig> get_candidate_configs(
```
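The shared-memory arithmetic quoted in the removed comment can be checked directly. This snippet only reproduces the numbers stated in that comment; it is not the actual CUTLASS SMEM model:

```python
# SMEM estimate from the removed comment for the CtaShape128x128x128B tile:
# 2 stages x (128 + 128) x 256 bytes = 128 KB, above SM120's ~100 KB budget.
# StageCountAutoCarveout then reduces the stage count to 1 (~64 KB), which
# can_implement() accepts at runtime.
stages, tile_rows, bytes_per_row = 2, 128 + 128, 256
nominal_smem = stages * tile_rows * bytes_per_row
assert nominal_smem == 128 * 1024            # 128 KB nominal, over budget
assert nominal_smem // stages == 64 * 1024   # ~64 KB after carveout to 1 stage
```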

flashinfer/compilation_context.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -36,7 +36,7 @@ def _normalize_cuda_arch(major: int, minor: int) -> tuple[int, str]:
     tuple with the correct architecture suffix for nvcc.
 
     SM 9.x -> 'a' suffix (e.g. compute_90a)
-    SM 12.x -> 'f' suffix with minor version preserved (e.g. compute_120f for SM120, compute_121f for SM121).
+    SM 12.x -> 'f' suffix with minor version preserved (e.g. compute_120f for SM120, compute_121a for SM121).
     Each SM 12.x variant gets its own cubin to avoid running SM120 code on SM121 (DGX Spark) which
     can cause cudaErrorIllegalInstruction. Requires CUDA >= 12.9.
     SM 10+ -> 'a' suffix (e.g. compute_100a)
@@ -48,7 +48,10 @@ def _normalize_cuda_arch(major: int, minor: int) -> tuple[int, str]:
         from flashinfer.jit.cpp_ext import is_cuda_version_at_least
 
         if is_cuda_version_at_least("12.9"):
-            return (major, str(minor) + "f")
+            if minor == 0:
+                return (major, "0f")
+            else:
+                return (major, str(minor) + "a")
         else:
             raise RuntimeError("SM 12.x requires CUDA >= 12.9")
     elif major >= 10:
```
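The suffixing behavior after this change can be sketched as a standalone function. This is a hypothetical reimplementation for illustration only: it hard-codes the CUDA >= 12.9 assumption instead of importing flashinfer, and covers only the cases the docstring excerpt describes:

```python
def normalize_cuda_arch_sketch(major: int, minor: int) -> tuple[int, str]:
    """Illustrative re-statement of the suffix rules; assumes CUDA >= 12.9."""
    if major == 12:
        # SM 12.0 keeps the 'f' suffix; other SM 12.x variants (e.g. SM121 on
        # DGX Spark) now get 'a' so each variant builds its own cubin and
        # SM120 code never runs on SM121 (avoiding cudaErrorIllegalInstruction).
        return (major, "0f") if minor == 0 else (major, str(minor) + "a")
    if major >= 9:
        # SM 9.x and SM 10+ use the 'a' suffix (compute_90a, compute_100a).
        return (major, str(minor) + "a")
    raise NotImplementedError("older SMs are outside this excerpt")

assert normalize_cuda_arch_sketch(12, 0) == (12, "0f")  # compute_120f
assert normalize_cuda_arch_sketch(12, 1) == (12, "1a")  # compute_121a
assert normalize_cuda_arch_sketch(9, 0) == (9, "0a")    # compute_90a
```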
