
SM120 attention kernels exist but are blocked by wiring issues (fmha_v2, backend selector, MLA) #2555


Description

@blake-snc

Summary

We've done a deep investigation of FlashInfer's SM120 support on an NVIDIA DGX Spark (GB10, SM121a, CUDA 13.0) and found that SM120 attention kernels already exist in the codebase but are not wired into the runtime pipeline. Three specific wiring issues prevent SM120 from using its optimized paths:

Issue 1: FMHA_V2 SM120 kernels gated behind ENABLE_SM120

Location: flashinfer/jit/attention/fmha_v2/generator_utils.py:6772-6786

SM120-specific FMHA_V2 kernels are only generated if the ENABLE_SM120 environment variable is set:

if "ENABLE_SM120" in os.environ:
    enumerate_hmma_flash_kernels(specs, sm=120, dtype="fp16")
    enumerate_hmma_flash_kernels(specs, sm=120, dtype="bf16")
    # ...

Three SM120 kernel files are referenced in jit/attention/modules.py:1918-1920 (BF16 and FP8 variants) but are never generated in standard builds, so gen_trtllm_fmha_v2_module() can't find them at runtime.

Expected: SM120 kernels should be generated by default when building for SM120 targets.
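
Until that happens, the gate can be forced on today. The short workaround sketch below just exports the quoted ENABLE_SM120 variable when the local GPU is an SM12x part; it uses only os and PyTorch, and assumes it runs before FlashInfer's fmha_v2 generator reads the environment in the same process (exporting the variable in the shell works equally well).

import os
import torch

# Workaround sketch: opt into SM120 FMHA_V2 kernel generation on SM12x GPUs.
# Must run before the fmha_v2 generator checks os.environ (or export it in the shell).
major, minor = torch.cuda.get_device_capability(0)
if major == 12:
    os.environ.setdefault("ENABLE_SM120", "1")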

Issue 2: determine_attention_backend() routes SM120 to generic FA2

Location: flashinfer/utils.py:455-497

def determine_attention_backend(device, ...):
    if is_sm90a_supported(device) and is_fa3_backend_supported(...):
        return "fa3"
    else:
        return "fa2"

This only checks for SM90. SM120 falls through to generic FA2 (Ampere-level), even though SM120-specific FMHA_V2 kernels exist. The function should have an SM120 path that routes to the fmha_v2 SM120 backend.
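
A rough sketch of the kind of branch that's missing is below. It is not FlashInfer's actual selector: the real function takes more arguments and applies the FA3 support checks quoted above, the SM12x test via torch.cuda.get_device_capability() is just one way to detect the architecture, and the "trtllm_fmha_v2" return value is a placeholder name for whatever string the SM120 fmha_v2 path ends up using.

import torch

def determine_attention_backend_sketch(device: torch.device) -> str:
    # Sketch only: backend names other than "fa3"/"fa2" are hypothetical placeholders.
    major, _ = torch.cuda.get_device_capability(device)
    if major == 12:
        # SM120 / SM121 (e.g. GB10): prefer the SM120 fmha_v2 prefill kernels
        return "trtllm_fmha_v2"
    if major == 9:
        # existing behaviour: Hopper goes to FA3 (subject to the fa3 support checks)
        return "fa3"
    return "fa2"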

Issue 3: determine_mla_backend() routes SM120 to FA2 instead of XQA

Location: flashinfer/utils.py:561-562

def determine_mla_backend(device):
    return "fa3" if is_sm90a_supported(device) else "fa2"

SM120 gets fa2 for MLA, but FlashInfer already has a functional XQA MLA backend specifically for SM120 (data/csrc/xqa/mla_sm120.cu, compiled with supported_major_versions=[12] in jit/xqa.py:137). This function should return "xqa" for SM120.
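
A matching sketch for the MLA selector, with the same caveat that the capability check is a stand-in for whatever helper the codebase prefers and that this is not the actual flashinfer/utils.py function:

import torch

def determine_mla_backend_sketch(device: torch.device) -> str:
    # Sketch only, not FlashInfer's actual selector.
    major, _ = torch.cuda.get_device_capability(device)
    if major == 12:
        return "xqa"  # mla_sm120.cu path; jit/xqa.py builds it for major version 12
    if major == 9:
        return "fa3"  # existing sm90a behaviour
    return "fa2"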

What Currently Works on SM120

Component             | Status
----------------------|-----------------------------------------------------------------
XQA decode (MHA)      | ✅ Works (xqa.py checks [9, 10, 12])
XQA decode (MLA, FP8) | ✅ Code exists (mla_sm120.cu), but Issue #2166 reports failures
FMHA_V2 SM120 prefill | ❌ Kernels exist but are never generated or routed to
SM120 GEMM (FP8, FP4) | ✅ Works via dedicated SM120 CUTLASS kernels
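
For anyone reproducing the routing half of this on a Spark, the current behaviour can be checked directly. The snippet below assumes determine_mla_backend() is importable from flashinfer.utils (the file quoted above), which may vary by version.

import torch
from flashinfer.utils import determine_mla_backend  # import path assumed from flashinfer/utils.py

dev = torch.device("cuda:0")
print(torch.cuda.get_device_capability(dev))  # (12, 1) on GB10 / SM121a
print(determine_mla_backend(dev))             # prints "fa2" today; "xqa" is what Issue 3 argues for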

Proposed Fixes

  1. Remove ENABLE_SM120 gate in generator_utils.py (or enable by default for SM120 targets)
  2. Add SM120 branch in determine_attention_backend() to route to fmha_v2 SM120 kernels
  3. Add SM120 branch in determine_mla_backend() to route to XQA

Environment

  • NVIDIA DGX Spark (GB10, SM121a)
  • CUDA 13.0, aarch64
  • flashinfer-python 0.6.2+cu130
  • PyTorch 2.9.1+cu130

Related issues: #1147, #2166, #2294

From Second Nature Computing — testing on DGX Spark.
