
[Bug] [sgl-kernel] nvidia DGX (sm121a) cannot use FA4 in sgl-kernel: Unsupported compute capability. Supported: 9.x, 10.x, 11.x #20363

@gbdjxgp

Description


Checklist

  • I searched related issues but found no solution.
  • The bug persists in the latest version.
  • Issues without environment info and a minimal reproducible demo are hard to resolve and may receive no feedback.
  • If this is not a bug report but a general question, please start a discussion at https://github.com/sgl-project/sglang/discussions. Otherwise, it will be closed.
  • Please use English. Otherwise, it will be closed.

Describe the bug

On NVIDIA GB10, FA4 in sgl-kernel fails immediately with:

AssertionError: Unsupported compute capability. Supported: 9.x, 10.x, 11.x

On this machine, torch.cuda.get_device_capability() returns (12, 1) and nvidia-smi reports compute capability 12.1. I infer the architecture target is sm_121a from the local CUDA JIT cache (-gencode=arch=compute_121a,code=sm_121a) and the open SGLang GB10 support tracking issue.

The failure comes from sgl_kernel/_fa4_interface.py, where FA4 only accepts compute capability major versions [9, 10, 11]. This makes the FA4 path unusable on GB10 / CC 12.1.
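For context, the gate behaves like the sketch below. This is a paraphrase of the check described above, not a copy of the actual source; the function name here is illustrative, and the real assertion lives in sgl_kernel/_fa4_interface.py:

```python
def check_fa4_capability(major: int, minor: int) -> None:
    """Paraphrased FA4 capability gate (illustrative name, not the real API)."""
    # Only compute capability major versions 9, 10, and 11 pass.
    # GB10 reports (12, 1), so it fails this check.
    assert major in (9, 10, 11), (
        "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
    )

check_fa4_capability(9, 0)        # Hopper (sm_90): passes
try:
    check_fa4_capability(12, 1)   # GB10 (sm_121a): raises
except AssertionError as e:
    print(e)  # Unsupported compute capability. Supported: 9.x, 10.x, 11.x
```

Since the check rejects on the major version alone, every 12.x device is excluded regardless of minor version.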

I first hit this through mini-sglang during graph capture, but the issue reproduces directly with a minimal sgl_kernel.flash_attn_with_kvcache(ver=4) script, so it does not require any model code to trigger.

Reproduction

Minimal reproduction (no model required):
import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

print("device:", torch.cuda.get_device_name(0))
print("capability:", torch.cuda.get_device_capability(0))

# q: (total_q, num_heads, head_dim); one decode token, varlen via cu_seqlens_q
q = torch.randn(1, 8, 64, device="cuda", dtype=torch.bfloat16)
# KV cache: (batch, cache_seqlen, num_kv_heads, head_dim)
k_cache = torch.randn(1, 8, 8, 64, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn(1, 8, 8, 64, device="cuda", dtype=torch.bfloat16)
cache_seqlens = torch.tensor([8], device="cuda", dtype=torch.int32)
cu_seqlens_q = torch.tensor([0, 1], device="cuda", dtype=torch.int32)

flash_attn_with_kvcache(
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    cache_seqlens=cache_seqlens,
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=1,
    causal=True,
    ver=4,  # request the FA4 path; this is what trips the capability assertion
)

Actual result:

AssertionError: Unsupported compute capability. Supported: 9.x, 10.x, 11.x
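Until 12.x is accepted by the gate, callers can at least avoid the hard failure by selecting the kernel version from the device capability. The helper below is a hypothetical workaround sketch, not an official API; the assumption that ver=3 is the usable fallback path in sgl-kernel is mine:

```python
def pick_fa_version(capability: tuple[int, int]) -> int:
    """Hypothetical helper: request ver=4 only on capability majors that
    FA4 currently accepts; otherwise fall back (assumption: ver=3 is the
    default FlashAttention-3 path in sgl-kernel)."""
    major, _ = capability
    return 4 if major in (9, 10, 11) else 3

# On GB10, torch.cuda.get_device_capability(0) == (12, 1):
print(pick_fa_version((12, 1)))  # 3 -> skips the FA4 assertion
print(pick_fa_version((9, 0)))   # 4 -> Hopper can use FA4
```

This only sidesteps the crash; actual FA4 support for sm_121a would still need the kernel side to accept 12.x.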

Environment

 python check_env.py 
.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: 
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    
  warnings.warn(
Python: 3.12.3 (main, Nov  6 2025, 13:44:16) [GCC 13.3.0]
CUDA available: True
GPU 0: NVIDIA GB10
GPU 0 Compute Capability: 12.1
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 13.0, V13.0.88
CUDA Driver Version: 580.95.05
PyTorch: 2.9.1+cu130
sglang: Module Not Found
sgl_kernel: 0.3.21
flashinfer_python: 0.6.6
flashinfer_cubin: Module Not Found
flashinfer_jit_cache: Module Not Found
triton: 3.5.1
transformers: 4.57.3
torchao: Module Not Found
numpy: 2.3.5
aiohttp: Module Not Found
fastapi: 0.135.1
hf_transfer: Module Not Found
huggingface_hub: 0.36.2
interegular: Module Not Found
modelscope: 1.34.0
orjson: Module Not Found
outlines: Module Not Found
packaging: 26.0
psutil: 7.2.2
pydantic: 2.12.5
python-multipart: Module Not Found
pyzmq: 27.1.0
uvicorn: 0.41.0
uvloop: Module Not Found
vllm: Module Not Found
xgrammar: Module Not Found
openai: 2.26.0
tiktoken: Module Not Found
anthropic: Module Not Found
litellm: Module Not Found
torchcodec: Module Not Found
NVIDIA Topology: 
        GPU0    NIC0    NIC1    NIC2    NIC3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    NODE    NODE    NODE    0-19    0               N/A
NIC0    NODE     X      PIX     NODE    NODE
NIC1    NODE    PIX      X      NODE    NODE
NIC2    NODE    NODE    NODE     X      PIX
NIC3    NODE    NODE    NODE    PIX      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: rocep1s0f0
  NIC1: rocep1s0f1
  NIC2: roceP2p1s0f0
  NIC3: roceP2p1s0f1


ulimit soft: 500000
Package                      Version     Editable project location
---------------------------- ----------- ----------------------------------
accelerate                   1.13.0
annotated-doc                0.0.4
annotated-types              0.7.0
anyio                        4.12.1
apache-tvm-ffi               0.1.9
black                        26.3.0
certifi                      2026.2.25
cfgv                         3.5.0
charset-normalizer           3.4.5
click                        8.3.1
contourpy                    1.3.3
coverage                     7.13.4
cuda-bindings                13.1.1
cuda-pathfinder              1.2.2
cuda-python                  13.1.1
cycler                       0.12.1
distlib                      0.4.0
distro                       1.9.0
einops                       0.8.2
fastapi                      0.135.1
filelock                     3.25.1
flake8                       7.3.0
flashinfer-python            0.6.6
fonttools                    4.62.0
fsspec                       2025.12.0
h11                          0.16.0
hf-xet                       1.3.2
httpcore                     1.0.9
httpx                        0.28.1
huggingface-hub              0.36.2
identify                     2.6.17
idna                         3.11
iniconfig                    2.3.0
jinja2                       3.1.6
jiter                        0.13.0
kiwisolver                   1.5.0
librt                        0.8.1
markupsafe                   3.0.2
matplotlib                   3.10.8
mccabe                       0.7.0
minisgl                      0.1.0       /home/xxx/xxx/mini-sglang
modelscope                   1.34.0
mpmath                       1.3.0
msgpack                      1.1.2
mypy                         1.19.1
mypy-extensions              1.1.0
networkx                     3.6.1
ninja                        1.13.0
nodeenv                      1.10.0
numpy                        2.3.5
nvidia-cublas                13.0.0.19
nvidia-cuda-cupti            13.0.48
nvidia-cuda-nvrtc            13.0.48
nvidia-cuda-runtime          13.0.48
nvidia-cudnn-cu13            9.13.0.50
nvidia-cudnn-frontend        1.19.1
nvidia-cufft                 12.0.0.15
nvidia-cufile                1.15.0.42
nvidia-curand                10.4.0.35
nvidia-cusolver              12.0.3.29
nvidia-cusparse              12.6.2.49
nvidia-cusparselt-cu13       0.8.0
nvidia-cutlass-dsl           4.4.1
nvidia-cutlass-dsl-libs-base 4.4.1
nvidia-ml-py                 13.590.48
nvidia-nccl-cu13             2.27.7
nvidia-nvjitlink             13.0.39
nvidia-nvshmem-cu13          3.3.24
nvidia-nvtx                  13.0.39
openai                       2.26.0
packaging                    26.0
pathspec                     1.0.4
pillow                       12.0.0
platformdirs                 4.9.4
pluggy                       1.6.0
pre-commit                   4.5.1
prompt-toolkit               3.0.52
psutil                       7.2.2
pyarrow                      23.0.1
pycodestyle                  2.14.0
pydantic                     2.12.5
pydantic-core                2.41.5
pyflakes                     3.4.0
pygments                     2.19.2
pyparsing                    3.3.2
pytest                       9.0.2
pytest-cov                   7.0.0
python-dateutil              2.9.0.post0
python-discovery             1.1.3
pytokens                     0.4.1
pyyaml                       6.0.3
pyzmq                        27.1.0
quack-kernels                0.3.2
regex                        2026.2.28
requests                     2.32.5
ruff                         0.15.5
safetensors                  0.7.0
setuptools                   70.2.0
sgl-kernel                   0.3.21
six                          1.17.0
sniffio                      1.3.1
starlette                    0.52.1
sympy                        1.14.0
tabulate                     0.10.0
tokenizers                   0.22.2
torch                        2.9.1+cu130
torch-c-dlpack-ext           0.1.5
torchaudio                   2.9.1
torchvision                  0.24.1
tqdm                         4.67.3
transformers                 4.57.3
triton                       3.5.1
typing-extensions            4.15.0
typing-inspection            0.4.2
urllib3                      2.6.3
uvicorn                      0.41.0
virtualenv                   21.2.0
wcwidth                      0.6.0
