Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@
#endif
#endif

// Enable grid-dependency-control (GDC) PTX for the Blackwell family when compiling
// arch-specific SASS: for each supported __CUDA_ARCH__ value the guard accepts either
// the feature-complete build macro (__CUDA_ARCH_FEAT_SM*_ALL, i.e. sm_1xxa) or a
// family-conditional build (CUDA_ARCH_FAMILY, i.e. sm_1xxf).
// NOTE(review): the SM121 clause originally used CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)
// while every sibling clause uses CUDA_ARCH_FAMILY(...); normalized to match. For
// sm_121a builds __CUDA_ARCH_FEAT_SM121_ALL already covers the arch-conditional case,
// so this is not expected to change which builds enable GDC — confirm against
// cutlass/arch/config.h.
#ifndef CUTLASS_GDC_ENABLED
#if (CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \
     ((__CUDA_ARCH__ == 1000 && \
       (defined(__CUDA_ARCH_FEAT_SM100_ALL) || CUDA_ARCH_FAMILY(1000))) || \
      (__CUDA_ARCH__ == 1010 && \
       (defined(__CUDA_ARCH_FEAT_SM101_ALL) || CUDA_ARCH_FAMILY(1010))) || \
      (__CUDA_ARCH__ == 1100 && \
       (defined(__CUDA_ARCH_FEAT_SM110_ALL) || CUDA_ARCH_FAMILY(1100))) || \
      (__CUDA_ARCH__ == 1030 && \
       (defined(__CUDA_ARCH_FEAT_SM103_ALL) || CUDA_ARCH_FAMILY(1030))) || \
      (__CUDA_ARCH__ == 1200 && \
       (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))) || \
      (__CUDA_ARCH__ == 1210 && \
       (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_FAMILY(1210)))))
#define CUTLASS_GDC_ENABLED
#endif
#endif

namespace cutlass {
namespace arch {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct EpilogueOpDefaultReLU {};

struct EpilogueOpDefaultFtGelu {};

// Tag type selecting the Relu2 activation epilogue (dispatched via the
// Epilogue<..., EpilogueOpDefaultRelu2> specialization).
// NOTE(review): presumably squared ReLU, i.e. (max(x, 0))^2 — confirm against the
// cutlass::epilogue::thread::Relu2 functor definition.
struct EpilogueOpDefaultRelu2 {};

struct EpilogueOpDefault {};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator,
Expand Down Expand Up @@ -117,5 +119,12 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
DefaultScaleMode>;
};

// Maps the EpilogueOpDefaultRelu2 tag to a LinearCombinationGeneric epilogue using the
// Relu2 activation functor with the default scale mode, mirroring the sibling
// specializations (ReLU, FtGelu, default) in this file.
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultRelu2> {
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
cutlass::epilogue::thread::Relu2, ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
};

} // namespace cutlass_extensions
} // namespace tensorrt_llm
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
gemm_type = CutlassGemmType::Fp8;
}

// SM121 (GB10) has ~99 KB SMEM — Ampere-style tiles where both M>=128 and N>=128
// exceed the SMEM budget. Filter them out so the autotuner doesn't waste time on
// known-bad configs.
auto filter_sm121 = [sm](std::vector<CutlassTileConfig> configs) {
if (sm != 121) return configs;
std::vector<CutlassTileConfig> filtered;
for (auto const& c : configs) {
TileShape ts = get_cta_shape_for_config(c);
if (ts.m >= 128 && ts.n >= 128) continue;
filtered.push_back(c);
}
return filtered;
};

std::vector<CutlassTileConfig> base_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64};
Expand All @@ -137,42 +151,42 @@ std::vector<CutlassTileConfig> get_candidate_tiles(

switch (gemm_type) {
case CutlassGemmType::Simt:
return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
return filter_sm121({CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8});
case CutlassGemmType::WeightOnly:
if (sm >= 75) {
return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
return filter_sm121({CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64});
} else {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
}
case CutlassGemmType::Int8:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
return filter_sm121({CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64});
case CutlassGemmType::Fp8:
if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
if (sm == 89 || sm == 120 || sm == 121) {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
return filter_sm121({CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128};
CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128});
} else {
// no valid ampere style fp8 configs for sm90
return {};
}
} else {
if (sm == 89 || sm >= 120) {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
return filter_sm121({CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
Expand All @@ -183,13 +197,13 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape128x64x128_WarpShape64x32x128,
CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128};
CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128});
} else {
return {};
}
}
default:
return base_configs;
return filter_sm121(base_configs);
}
}

Expand Down Expand Up @@ -635,6 +649,37 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm120(

} // namespace kernels

// Candidate GEMM configs for SM121 (GB10). SM121 shares the SM120 ISA but has a
// much smaller shared-memory budget (~99 KB vs ~228 KB on SM120/GB200), so grouped
// GEMM is restricted to a single tile shape. Only nvfp4 paths are implemented;
// anything else throws.
std::vector<CutlassGemmConfig> get_candidate_configs_sm121(
    CutlassGemmConfig::CandidateConfigTypeParam const config) {
#ifdef FAST_BUILD
  // Fast builds compile exactly one known-good config.
  return {CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x64B,
                            MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                            ClusterShape::ClusterShape_1x1x1}};
#else
  bool const is_fp4 = (config & CutlassGemmConfig::FP4_ONLY) != 0;
  if (config & CutlassGemmConfig::GROUPED_GEMM) {
    if (!is_fp4) {
      TLLM_THROW("Not Implemented: SM121 group GEMM only supports nvfp4.");
    }
    // SM121 (GB10) has ~99 KB SMEM per block (vs ~228 KB on SM120/GB200).
    // FP4 is stored unpacked in SMEM (1 byte per 4-bit element), so the per-stage
    // footprint is doubled compared to packed storage.
    //   CtaShape128x128x64B:  ~32 KB/stage x 2 stages + ~9 KB epilogue = 73 KB (fits)
    //   CtaShape128x128x128B: ~64 KB/stage -> one stage violates the Stages>=2 constraint
    //   CtaShape256x128x64B / CtaShape128x256x64B: ~48 KB/stage x 2 = 105 KB > 99 KB
    return {CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x64B,
                              MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                              ClusterShape::ClusterShape_1x1x1}};
  }
  if (!is_fp4) {
    TLLM_THROW("Not Implemented: SM121 GEMM only supports nvfp4.");
  }
  // Dense nvfp4 GEMM reuses the SM120 candidate list (same ISA).
  return get_candidate_configs_sm120(config);
#endif
}  // get_candidate_configs_sm121

std::vector<CutlassGemmConfig> get_candidate_configs(
int sm, int const max_split_k,
CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) {
Expand All @@ -653,9 +698,13 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
return get_candidate_configs_sm100(config_type_param, sm);
}
if (sm >= 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
if (sm == 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
return get_candidate_configs_sm120(config_type_param);
}
if (sm == 121 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
// SM121 = GB10: same ISA as SM120 but ~99 KB SMEM; only 128x128 CTA tile fits.
return get_candidate_configs_sm121(config_type_param);
}

std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,9 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType, IsMXFPX>::moeGemmBi
case ActivationType::Geglu:
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(inputs, hopper_inputs);
break;
case ActivationType::Relu2:
runGemm<cutlass_extensions::EpilogueOpDefaultRelu2>(inputs, hopper_inputs);
break;
case ActivationType::InvalidType:
TLLM_THROW("Activation type for fpA_intB must be valid.");
break;
Expand Down
11 changes: 9 additions & 2 deletions flashinfer/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,15 +772,22 @@ def choose_one(
raise
except Exception as e:
shapes = self._get_input_sizes(tensors)
logger.warning(
# Log stacktrace as debug to not spam log
logger.debug(
f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling: {e}"
)

# Log stacktrace as debug to not spam log
logger.debug(
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
)

# Clear any pending async CUDA errors (e.g.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of nitpicking but consider splitting this into a different PR

# cudaErrorIllegalInstruction from a failed
# kernel warmup run) so they don't surface
# later during CUDA graph capture.
with contextlib.suppress(Exception):
torch.cuda.synchronize()

# Record the failed profiling combinations
if (
custom_op
Expand Down
11 changes: 8 additions & 3 deletions flashinfer/compilation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def _normalize_cuda_arch(major: int, minor: int) -> tuple[int, str]:
tuple with the correct architecture suffix for nvcc.

SM 9.x -> 'a' suffix (e.g. compute_90a)
SM 12.x -> always normalized to SM 120 with 'f' suffix (e.g. compute_120f).
This covers both SM 12.0 and SM 12.1 (DGX Spark) when the installed CUDA toolchain supports it (CUDA >= 12.9).
SM 12.0 -> 'f' suffix (compute_120f), SM 12.1 -> 'f' suffix (compute_121f).
SM 12.1 gets a distinct key so its reduced SMEM budget is respected in the JIT cache.
SM 10+ -> 'a' suffix (e.g. compute_100a)
SM < 9 -> no suffix
"""
Expand All @@ -47,7 +47,12 @@ def _normalize_cuda_arch(major: int, minor: int) -> tuple[int, str]:
from flashinfer.jit.cpp_ext import is_cuda_version_at_least

if is_cuda_version_at_least("12.9"):
return (major, "0f")
if minor == 0:
return (major, "0f")
else:
# SM12.1 (GB10) — keep minor to distinguish from SM12.0 (GB200)
# in the JIT cache; GB10 has only ~99 KB SMEM vs ~228 KB on GB200.
return (major, f"{minor}f")
else:
raise RuntimeError("SM 12.x requires CUDA >= 12.9")
elif major >= 10:
Expand Down
5 changes: 5 additions & 0 deletions flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def gen_cutlass_fused_moe_sm120_module(use_fast_build: bool = False) -> JitSpec:
"-DENABLE_FP8",
"-DENABLE_FP4",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
]

nvcc_flags += current_compilation_context.get_nvcc_flags_list(
Expand All @@ -56,6 +57,7 @@ def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec:
"-DENABLE_FP4",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
"-DCOMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
]

nvcc_flags += current_compilation_context.get_nvcc_flags_list(
Expand All @@ -73,6 +75,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
"-DENABLE_FP8",
"-DENABLE_FP4",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@askliar What does this flag do?
Is there a way to add it in some common flags area or something?

Copy link
Copy Markdown
Contributor

@johnnynunez johnnynunez Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needed to activate PDL (see #2708)

]

nvcc_flags += current_compilation_context.get_nvcc_flags_list(
Expand All @@ -91,6 +94,7 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec:
"-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
"-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
]
return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build)

Expand Down Expand Up @@ -304,6 +308,7 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
"-DENABLE_BF16",
"-DENABLE_FP8",
"-DENABLE_FP4",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"',
]
+ nvcc_flags,
Expand Down
1 change: 1 addition & 0 deletions flashinfer/jit/gemm/fp8_blockscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def gen_fp8_blockscale_gemm_sm90_module(use_fast_build: bool = False) -> JitSpec
"-DENABLE_BF16",
"-DENABLE_FP8",
*(("-DENABLE_FP8_BLOCK_SCALE",) if is_cuda_version_at_least("12.8") else ()),
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
]

return gen_jit_spec(
Expand Down
Loading