Skip to content
18 changes: 3 additions & 15 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,9 @@ void FusedMoeLauncher::init_common(
int major = 0, minor = 0;
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
TVM_FFI_ICHECK_EQ(major, 10) << "MoE kernel requires 10.x architecture. Current device has SM "
<< major << minor;
TVM_FFI_ICHECK(major == 10 || major == 12)
<< "MoE kernel requires SM 10.x or SM 12.x architecture. Current device has SM " << major
<< minor;
this->device_version = std::make_tuple(major, minor);

args->routing_logits = routing_logits.has_value() ? routing_logits.value().data_ptr() : nullptr;
Expand Down Expand Up @@ -1333,19 +1334,6 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
int64_t weight_layout, ActivationType activation_type, btg::Dtype dtype_act,
btg::Dtype dtype_weights) {
static const std::tuple<int, int> device_props = [this] {
int major, minor;
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor,
hidden_states.device().device_id);
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor,
hidden_states.device().device_id);
return std::make_tuple(major, minor);
}();

TVM_FFI_ICHECK_EQ(std::get<0>(device_props), 10)
<< "This kernel requires 10.x architecture. Current device has SM "
<< std::get<0>(device_props) << std::get<1>(device_props);

// Set data types
args->mDtypeElt = dtype_act;
args->mDtypeOut = btg::Dtype::Bfloat16; // Output is always BF16 for FP4
Expand Down
45 changes: 41 additions & 4 deletions flashinfer/compilation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,57 @@ class CompilationContext:
"-DFLASHINFER_ENABLE_FP4_E2M1",
]

@staticmethod
def _normalize_cuda_arch(major: int, minor: int) -> tuple[int, str]:
"""Normalize a (major, minor) capability pair into a (major, minor_str)
tuple with the correct architecture suffix for nvcc.

SM 9.x -> 'a' suffix (e.g. compute_90a)
SM 12.x -> always normalized to SM 120 with 'f' suffix (e.g. compute_120f)
when the installed CUDA toolchain supports it (CUDA >= 13.0),
otherwise 'a'. This covers both SM 12.0 and SM 12.1 (DGX Spark).
SM 10+ -> 'a' suffix (e.g. compute_100a)
SM < 9 -> no suffix
"""
if major == 9:
return (major, str(minor) + "a")
elif major == 12:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One last nit after discussing with @nv-yunzheq ; I think we also want to compile DGX spark (sm121) to sm120f, instead of 121f.

Can we change the logic to return always major + 0f (or just 120f) on line 50?

try:
from flashinfer.jit.cpp_ext import is_cuda_version_at_least

if is_cuda_version_at_least("13.0"):
return (major, "0f")
except (ImportError, RuntimeError, ValueError):
logger.debug(
"Could not determine CUDA version; "
"falling back to 'a' suffix for SM %d.%d",
major,
minor,
)
return (major, "0a")
elif major >= 10:
return (major, str(minor) + "a")
return (major, str(minor))

def __init__(self):
    """Populate ``TARGET_CUDA_ARCHS`` with normalized (major, minor_str) pairs.

    The architecture list comes from one of two places:
      1. The ``FLASHINFER_CUDA_ARCH_LIST`` environment variable — a
         space-separated list of entries such as ``"9.0"`` or ``"12.0f"``.
         Entries that already carry an alphabetic suffix are kept verbatim;
         bare ``major.minor`` entries are normalized via
         ``_normalize_cuda_arch``.
      2. Otherwise, the compute capability of every visible CUDA device,
         queried through torch and normalized the same way. Failures here
         (e.g. no driver) are logged as a warning, leaving the set empty.
    """
    self.TARGET_CUDA_ARCHS = set()
    if "FLASHINFER_CUDA_ARCH_LIST" in os.environ:
        for arch in os.environ["FLASHINFER_CUDA_ARCH_LIST"].split(" "):
            major, minor = arch.split(".")
            major = int(major)
            # If the user already provided a suffix (e.g. "12.0f"),
            # respect it as-is; otherwise normalize.
            if minor[-1].isalpha():
                self.TARGET_CUDA_ARCHS.add((major, minor))
            else:
                self.TARGET_CUDA_ARCHS.add(
                    self._normalize_cuda_arch(major, int(minor))
                )
    else:
        try:
            for device in range(torch.cuda.device_count()):
                major, minor = torch.cuda.get_device_capability(device)
                self.TARGET_CUDA_ARCHS.add(self._normalize_cuda_arch(major, minor))
        except Exception as e:
            logger.warning(f"Failed to get device capability: {e}.")

Expand Down
6 changes: 3 additions & 3 deletions flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec:
]

nvcc_flags += current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10]
supported_major_versions=[10, 12]
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

return gen_cutlass_fused_moe_module(nvcc_flags, "103", use_fast_build)
Expand All @@ -76,7 +76,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
]

nvcc_flags += current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10, 11]
supported_major_versions=[10, 11, 12]
)

return gen_cutlass_fused_moe_module(nvcc_flags, "100", use_fast_build)
Expand Down Expand Up @@ -248,7 +248,7 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:

# currently only support Blackwell
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10]
supported_major_versions=[10, 12]
)

return gen_jit_spec(
Expand Down
Loading