**Feature Request**

### Which component requires the feature?

CuTe DSL

### Is your feature request related to a problem? Please describe.
FlashAttention-4 (CuTeDSL) cannot run on the NVIDIA GB10 GPU — the Blackwell-architecture chip in NVIDIA DIGITS / DGX Spark — because `nvidia-cutlass-dsl==4.4.1` and `nvidia-cutlass-dsl-libs-base==4.4.1` neither recognize its compute capability (12.1 / `Arch.sm_121a`) nor ship kernel images for it.
Two specific failures occur:
- `cutlass.cute.nvgpu.tcgen05.mma.MmaF16BF16Op` rejects `Arch.sm_121a` — the allowed list is `[sm_100a, sm_100f, sm_101a, sm_101f, sm_103a, sm_103f, sm_110a, sm_110f]`.
- Even when overriding with `CUTE_DSL_ARCH=sm_100a`, the pre-compiled CUDA library in `nvidia-cutlass-dsl-libs-base` does not contain SM121-compatible SASS, resulting in `cudaErrorNoKernelImageForDevice` during `cuda_dialect_init_library_once`.
This blocks all CuTeDSL-based workloads (including FlashAttention-4) on the GB10, despite it being Blackwell-class silicon with SM100-compatible tensor cores.
### Describe the solution you'd like
- Extend the `tcgen05/mma.py` arch allowlist to include SM12x variants (`sm_120a`, `sm_121a`, etc.), mapping them to the existing SM100 kernel templates, since the GB10's tensor-core ISA is Blackwell-compatible.
- Ship SM121-targeted SASS (or SM100 PTX for forward-compatible JIT) in `nvidia-cutlass-dsl-libs-base`, so the CUDA runtime library loaded by `cuda_dialect_init_library_once` has a kernel image available for SM121 devices.
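The requested allowlist extension could look roughly like the following pure-Python sketch. `SM12X_TO_SM100` and `resolve_mma_arch` are illustrative names only, not the actual CUTLASS DSL API; the real change would live inside `tcgen05/mma.py`.

```python
# Archs that MmaF16BF16Op currently accepts (per the error message below).
SM100_TEMPLATE_ARCHS = {
    "sm_100a", "sm_100f", "sm_101a", "sm_101f",
    "sm_103a", "sm_103f", "sm_110a", "sm_110f",
}

# Proposed additions: Blackwell-class SM12x parts reuse the SM100 templates.
SM12X_TO_SM100 = {
    "sm_120a": "sm_100a",
    "sm_121a": "sm_100a",
}

def resolve_mma_arch(arch: str) -> str:
    """Return a template arch for the MMA op, clamping SM12x to SM100."""
    if arch in SM100_TEMPLATE_ARCHS:
        return arch
    if arch in SM12X_TO_SM100:
        return SM12X_TO_SM100[arch]
    raise ValueError(f"unsupported arch: {arch}")

print(resolve_mma_arch("sm_121a"))  # -> sm_100a
```

The second bullet (shipping SM121 SASS or SM100 PTX) still has to land alongside this, otherwise the clamped arch passes the Python check but hits the runtime error shown at the bottom of this report.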
### Describe alternatives you've considered
- `CUTE_DSL_ARCH=sm_100a` env var override — bypasses the MMA arch check in Python, but fails at the CUDA runtime level because the pre-compiled `.so` in `nvidia-cutlass-dsl-libs-base` has no SM121 kernel image.
- Patching `_get_device_arch()` in flash-attention's `interface.py` to map SM12x → SM100 — helps at the flash-attention level but doesn't affect CUTLASS DSL internals.
- Patching the assertion in `flash_fwd_sm100.py` (`self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f`) to clamp `Arch.sm_121a` → `Arch.sm_100` — same issue: it doesn't reach the pre-compiled library.
- Falling back to FlashAttention-2 — works (FA2 compiles from source for the actual GPU), but loses the performance benefits of FA4/CuTeDSL on Blackwell hardware.
### Additional context
Environment:
| Component | Version |
| --- | --- |
| GPU | NVIDIA GB10 (NVIDIA DIGITS / DGX Spark) |
| Compute capability | `(12, 1)` — `Arch.sm_121a` |
| NVIDIA driver | 580.126.09 |
| CUDA runtime | 12.8 |
| Container | `nvcr.io/nvidia/pytorch:25.01-py3` |
| PyTorch | 2.6.0a0+ecf3bae40a.nv25.01 |
| nvidia-cutlass-dsl | 4.4.1 |
| nvidia-cutlass-dsl-libs-base | 4.4.1 |
| flash-attn-4 | 4.0.0b4 |
Minimal reproducer:

```python
import torch
from flash_attn.cute import flash_attn_func

q = torch.randn(1, 128, 32, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 32, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 32, 64, device="cuda", dtype=torch.bfloat16)
out, lse = flash_attn_func(q, k, v)
```
Error without `CUTE_DSL_ARCH` override:

```
cutlass.cute.nvgpu.common.OpError: OpError: expects arch to be one of
[Arch.sm_100a, Arch.sm_100f, Arch.sm_101a, Arch.sm_101f,
 Arch.sm_103a, Arch.sm_103f, Arch.sm_110a, Arch.sm_110f],
but got Arch.sm_121a
```
Error with `CUTE_DSL_ARCH=sm_100a`:

```
RuntimeError: Got Cuda Runtime Error: cudaErrorNoKernelImageForDevice
no kernel image is available for execution on the device
```