6 changes: 3 additions & 3 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -414,7 +414,7 @@ void FusedMoeLauncher::init_common(
int major = 0, minor = 0;
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
TVM_FFI_ICHECK_EQ(major, 10) << "MoE kernel requires 10.x architecture. Current device has SM "
TVM_FFI_ICHECK(major == 10 || major == 12) << "MoE kernel requires SM 10.x or SM 12.x architecture. Current device has SM "
Contributor comment: Maybe 11 (Thor) also works? It should be very similar to 10.x.

<< major << minor;
this->device_version = std::make_tuple(major, minor);

@@ -1342,8 +1342,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
return std::make_tuple(major, minor);
}();

TVM_FFI_ICHECK_EQ(std::get<0>(device_props), 10)
<< "This kernel requires 10.x architecture. Current device has SM "
TVM_FFI_ICHECK(std::get<0>(device_props) == 10 || std::get<0>(device_props) == 12)
<< "This kernel requires SM 10.x or SM 12.x architecture. Current device has SM "
<< std::get<0>(device_props) << std::get<1>(device_props);

// Set data types
43 changes: 39 additions & 4 deletions flashinfer/compilation_context.py
@@ -30,20 +30,55 @@ class CompilationContext:
"-DFLASHINFER_ENABLE_FP4_E2M1",
]

@staticmethod
def _normalize_cuda_arch(major: int, minor: int) -> tuple[int, str]:
"""Normalize a (major, minor) capability pair into a (major, minor_str)
tuple with the correct architecture suffix for nvcc.

SM 9.x -> 'a' suffix (e.g. compute_90a)
SM 12.x -> 'f' suffix (e.g. compute_120f) when the installed CUDA
toolchain supports it (CUDA >= 13.0), otherwise 'a'.
SM 10+ -> 'a' suffix (e.g. compute_100a)
SM < 9 -> no suffix
"""
if major == 9:
return (major, str(minor) + "a")
elif major == 12:
Member comment: One last nit after discussing with @nv-yunzheq; I think we also want to compile DGX Spark (sm121) to sm120f, instead of 121f. Can we change the logic to always return major + "0f" (or just 120f) on line 50?

try:
from flashinfer.jit.cpp_ext import is_cuda_version_at_least
if is_cuda_version_at_least("13.0"):
return (major, str(minor) + "f")
except ImportError:
logger.debug(
"Could not import is_cuda_version_at_least; "
"falling back to 'a' suffix for SM %d.%d", major, minor
)
return (major, str(minor) + "a")
elif major >= 10:
return (major, str(minor) + "a")
return (major, str(minor))
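For illustration only (not part of the diff): the expected results of the normalization above, inferred from the docstring and code, assuming a CUDA >= 13.0 toolchain for the SM 12.x case.

    # Illustrative sketch; assumes CUDA >= 13.0 so SM 12.x gets the 'f' suffix.
    from flashinfer.compilation_context import CompilationContext

    assert CompilationContext._normalize_cuda_arch(8, 9) == (8, "9")     # no suffix below SM 9
    assert CompilationContext._normalize_cuda_arch(9, 0) == (9, "0a")    # compute_90a
    assert CompilationContext._normalize_cuda_arch(10, 0) == (10, "0a")  # compute_100a
    assert CompilationContext._normalize_cuda_arch(12, 0) == (12, "0f")  # compute_120f ("0a" on older CUDA)
    # Per the review comment above, SM 12.1 (DGX Spark) currently maps to "1f";
    # the suggestion is to normalize all SM 12.x devices to "0f" (compute_120f).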

def __init__(self):
self.TARGET_CUDA_ARCHS = set()
if "FLASHINFER_CUDA_ARCH_LIST" in os.environ:
for arch in os.environ["FLASHINFER_CUDA_ARCH_LIST"].split(" "):
major, minor = arch.split(".")
major = int(major)
self.TARGET_CUDA_ARCHS.add((int(major), str(minor)))
# If the user already provided a suffix (e.g. "12.0f"),
# respect it as-is; otherwise normalise.
if minor[-1].isalpha():
self.TARGET_CUDA_ARCHS.add((major, minor))
else:
self.TARGET_CUDA_ARCHS.add(
self._normalize_cuda_arch(major, int(minor))
)
else:
try:
for device in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(device)
if major >= 9:
minor = str(minor) + "a"
self.TARGET_CUDA_ARCHS.add((int(major), str(minor)))
self.TARGET_CUDA_ARCHS.add(
self._normalize_cuda_arch(major, minor)
)
except Exception as e:
logger.warning(f"Failed to get device capability: {e}.")
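A minimal usage sketch (illustrative, not part of the diff) of how FLASHINFER_CUDA_ARCH_LIST entries are interpreted by the logic above: entries with an explicit suffix are kept as-is, bare entries are normalized via _normalize_cuda_arch.

    import os
    from flashinfer.compilation_context import CompilationContext

    # Hypothetical example: one suffixed entry and two bare entries.
    os.environ["FLASHINFER_CUDA_ARCH_LIST"] = "9.0 10.0 12.0f"
    ctx = CompilationContext()
    # Expected (assuming CUDA >= 13.0): {(9, "0a"), (10, "0a"), (12, "0f")}
    print(ctx.TARGET_CUDA_ARCHS)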

6 changes: 3 additions & 3 deletions flashinfer/jit/fused_moe.py
@@ -59,7 +59,7 @@ def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec:
]

nvcc_flags += current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10]
supported_major_versions=[10, 12]
)

return gen_cutlass_fused_moe_module(nvcc_flags, "103", use_fast_build)
@@ -76,7 +76,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
]

nvcc_flags += current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10, 11]
supported_major_versions=[10, 11, 12]
)

return gen_cutlass_fused_moe_module(nvcc_flags, "100", use_fast_build)
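get_nvcc_flags_list is not shown in this diff; the following is a rough sketch of how the supported_major_versions filter might combine with the normalized arch set to produce -gencode flags. The function name, signature, and flag format here are assumptions for illustration, not the actual implementation.

    # Hypothetical sketch only; the real get_nvcc_flags_list lives in
    # CompilationContext and may differ. It presumably keeps only detected
    # arches whose major version is in supported_major_versions and emits
    # one -gencode flag per remaining arch.
    def sketch_nvcc_flags(target_cuda_archs, supported_major_versions):
        flags = []
        for major, minor in sorted(target_cuda_archs):
            if major in supported_major_versions:
                arch = f"{major}{minor}"  # e.g. (12, "0f") -> "120f"
                flags.append(f"-gencode=arch=compute_{arch},code=sm_{arch}")
        return flags

    # With SM 12.x now accepted, an SM 12.0 device normalized to (12, "0f")
    # would contribute a compute_120f target to these MoE builds.
    print(sketch_nvcc_flags({(10, "0a"), (12, "0f")}, [10, 12]))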
@@ -248,7 +248,7 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:

# currently only support Blackwell
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10]
supported_major_versions=[10, 12]
)

return gen_jit_spec(