
Add SM120 (Blackwell GeForce / DGX Spark) flash attention #2268

Open

blake-snc wants to merge 6 commits into Dao-AILab:main from blake-snc:feat/sm120-support

Conversation


blake-snc commented Feb 20, 2026

Summary

Add flash attention support for SM120 (NVIDIA Blackwell GeForce / DGX Spark, compute capability 12.x). SM120 uses SM80-era MMA instructions (mma.sync.aligned.m16n8k16) with 99 KB shared memory.

Features:

  • Forward pass: non-causal + causal, D=64 and D=128, tile sizes tuned for SM120 SMEM (D≤64: 128×128, D>64: 128×64)
  • Backward pass: D=64 and D=128, 4-warp layout matching forward
  • Variable-length (varlen): forward + backward with cu_seqlens_q/k and seqused_q/k
  • Split-KV (FlashDecoding): multi-split inference decoding
  • Paged KV cache: Python-level page table resolution (SM80 swizzled SMEM layout is incompatible with PagedKVManager's tiled copy)

Architecture:

  • FlashAttentionForwardSm120 subclasses FlashAttentionForwardSm80 with arch=80 (CpAsync code paths) and an SM120 SMEM capacity check (see the illustrative sketch after this list)
  • FlashAttentionBackwardSm120 subclasses FlashAttentionBackwardSm80 similarly
  • SM80 base class enhanced with varlen support (SingleTileVarlenScheduler, SeqlenInfoQK offset helpers) and split-KV support
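
A minimal illustrative sketch of the subclassing pattern described above (the attribute and method names below are assumptions for illustration, not the PR's actual code):

  # Illustrative sketch only -- names are assumptions, not the PR's implementation.
  class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
      """Reuse the SM80 CpAsync code paths under SM120's 99 KB SMEM budget."""

      SMEM_CAPACITY = 99 * 1024  # SM120 per-block shared-memory limit, in bytes

      def __init__(self, *args, **kwargs):
          # arch=80 keeps dispatch on the SM80-era (CpAsync) code paths
          super().__init__(*args, arch=80, **kwargs)

      def check_smem(self, smem_bytes: int) -> None:
          # Reject tile configurations that exceed the SM120 SMEM capacity
          if smem_bytes > self.SMEM_CAPACITY:
              raise ValueError(
                  f"Tile config needs {smem_bytes} B of SMEM, "
                  f"but SM120 provides only {self.SMEM_CAPACITY} B"
              )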

Depends on #2325 (SM80 API drift fixes).

Rebased on current main (resolves prior merge conflicts).

Validation on SM121a (NVIDIA GB10, DGX Spark)

Forward (non-causal + causal):

Config                            max_diff
B=2 S=128 H=8 D=128 non-causal    0.003906
B=2 S=128 H=8 D=128 causal        0.015625
B=1 S=256 H=4 D=64 causal         0.007812
B=4 S=512 H=16 D=128 causal       0.015625
B=1 S=1024 H=8 D=128 causal       0.015625

Backward (gradient diffs):

Config                            dq        dk        dv
B=2 S=128 H=8 D=64 non-causal     0.003906  0.007812  0.007812
B=2 S=128 H=8 D=128 causal        0.012329  0.015625  0.031250
B=1 S=256 H=4 D=64 causal         0.015625  0.015625  0.031250

Varlen forward (non-causal + causal, seqlens=[32,64,48,16]):

Batch (seqlen)    non-causal max_diff    causal max_diff
0 (32)            0.007812               0.015625
1 (64)            0.003906               0.015625
2 (48)            0.007812               0.007812
3 (16)            0.007812               0.015625
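
For reference, the packed layout above is driven by cu_seqlens built from the per-sequence lengths. A minimal sketch (the flash_attn_varlen_func name/signature is an assumption mirroring the standard flash-attn varlen API; only the cu_seqlens construction is concrete):

  import torch
  import torch.nn.functional as F

  # Pack seqlens=[32, 64, 48, 16] into a single varlen batch
  seqlens = torch.tensor([32, 64, 48, 16], dtype=torch.int32, device="cuda")
  cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).to(torch.int32)
  # cu_seqlens == tensor([0, 32, 96, 144, 160], dtype=torch.int32)

  total, H, D = int(seqlens.sum()), 4, 64
  q = torch.randn(total, H, D, dtype=torch.bfloat16, device="cuda")
  k = torch.randn_like(q)
  v = torch.randn_like(q)

  # Assumed entry point -- not confirmed by this PR's interface:
  # out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
  #                              max_seqlen_q=int(seqlens.max()),
  #                              max_seqlen_k=int(seqlens.max()),
  #                              causal=True)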

Split-KV (FlashDecoding):

Config                               max_diff
B=1 S=256 D=64 splits=2 non-causal   0.003906
B=1 S=256 D=128 splits=2 causal      0.007812
B=2 S=512 D=128 splits=4 non-causal  0.001953
B=1 S=1024 D=128 splits=4 causal     0.007812

All diffs are within BF16 precision bounds. Reference: torch.nn.functional.scaled_dot_product_attention.
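
A minimal sketch of how a max_diff entry above can be reproduced (it assumes flash_attn_func takes (B, S, H, D) tensors and returns a tuple whose first element is the output, as in the smoke test later in this thread):

  import torch
  import torch.nn.functional as F
  from flash_attn.cute import flash_attn_func

  # Reproduce one table entry: B=2 S=128 H=8 D=128 causal
  B, S, H, D = 2, 128, 8, 128
  q = torch.randn(B, S, H, D, dtype=torch.bfloat16, device="cuda")
  k = torch.randn_like(q)
  v = torch.randn_like(q)

  out = flash_attn_func(q, k, v, causal=True)[0]

  # SDPA expects (B, H, S, D), so transpose in and out of that layout
  ref = F.scaled_dot_product_attention(
      q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
  ).transpose(1, 2)

  print("max_diff:", (out - ref).abs().max().item())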

Test plan

  • Forward non-causal + causal: 5 configs pass
  • Backward non-causal + causal: 3 configs pass
  • Varlen forward non-causal + causal: 8 configs pass
  • Split-KV: 4 configs pass
  • SM90/SM100 regressions: structurally verified (SM120 dispatch is fully guarded by arch // 10 == 12, new subclasses are in separate files, SM80 base class changes are behind const_expr varlen/split-KV branches) — no SM90/SM100 hardware available for runtime verification

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@johnnynunez
Contributor

johnnynunez commented Feb 21, 2026

@blake-snc
Author

@johnnynunez Yes — TMA support is already implemented in our CUTLASS PR (NVIDIA/cutlass#3030, the FlashAttentionForwardSm120Tma class).

This flash-attention PR is the CpAsync baseline adapted to Dao-AILab's CuTe DSL interface. Just rebased onto latest main to fix the merge conflict. We could add TMA here too, but wanted to get the basic forward pass landed first.

Current status: forward-only, BF16/FP16, causal/non-causal, MHA/GQA/MQA, hdim 64/96/128. Still missing: backward pass, varlen, paged KV, split-KV. Happy to coordinate on expanding it!

@johnnynunez
Contributor

cc @drisspg @tridao

I think this should be good to close all related issues with sm12x, which includes RTX 50 / RTX Pro and DGX Spark:
#2019
#1810
#1683
#1671
#1665
#1563

blake-snc force-pushed the feat/sm120-support branch from 59ddcc9 to 695437f on March 6, 2026 03:55
blake-snc changed the title from "Add SM120 (Blackwell GeForce / DGX Spark) forward pass support" to "Add SM120 (Blackwell GeForce / DGX Spark) forward + backward pass support" on Mar 6, 2026
@blake-snc
Author

Updated with split-KV (FlashDecoding) and paged KV support. Latest push adds:

Split-KV / FlashDecoding — splits the KV sequence across multiple thread blocks for long-context decode (few Q tokens, many KV tokens). Each split produces BF16/FP16 partial outputs + an FP32 LSE; the partials are converted to FP32 and merged by the existing FlashAttentionForwardCombine kernel (see the sketch below). Validated 12/12 configs (D=64/128, causal/non-causal, 2-4 splits).
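
For intuition, here is the merge math that combine step performs, written as standalone PyTorch (the shapes are assumptions for illustration; the real work is done by the FlashAttentionForwardCombine kernel):

  import torch

  def combine_splits(out_partial: torch.Tensor, lse_partial: torch.Tensor) -> torch.Tensor:
      # Assumed shapes (illustrative): out_partial (num_splits, B, S_q, H, D) FP32,
      #                                lse_partial (num_splits, B, H, S_q) FP32
      lse_max = lse_partial.max(dim=0).values        # (B, H, S_q)
      w = torch.exp(lse_partial - lse_max)           # empty splits carry lse=-inf -> weight 0
      w = w / w.sum(dim=0)                           # normalize across splits
      # (num_splits, B, H, S_q) -> (num_splits, B, S_q, H, 1) to broadcast over head_dim
      w = w.permute(0, 1, 3, 2).unsqueeze(-1)
      return (w * out_partial).sum(dim=0)            # (B, S_q, H, D)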

Paged KV cache — supports paged KV for inference engines (vLLM/SGLang). The SM80 kernel's swizzled SMEM layout (composed with Swizzle(3,3,3)) is incompatible with PagedKVManager's tiled copy, which builds a plain 2D layout from sX.stride and loses the swizzle. Instead of kernel-level integration, we resolve the page table at the Python level via a single GPU gather operation. Works in combination with split-KV for the full inference pattern. Validated 14/14 configs (D=64/128, page_size 64/128, uniform + varied seqlens) plus 6/6 combined paged+split-KV configs.
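
A minimal sketch of that Python-level page-table resolution (the cache layout and names below are assumptions for illustration):

  import torch

  def unpage(kv_cache: torch.Tensor, page_table: torch.Tensor) -> torch.Tensor:
      # Assumed layouts: kv_cache (num_pages_total, page_size, H, D),
      #                  page_table (B, pages_per_seq) of int page indices
      B, pages_per_seq = page_table.shape
      _, page_size, H, D = kv_cache.shape
      # Single GPU gather, then flatten the pages back into one sequence dimension
      gathered = kv_cache[page_table.reshape(-1)]    # (B * pages_per_seq, page_size, H, D)
      return gathered.reshape(B, pages_per_seq * page_size, H, D)

  # seqused_k must still carry the true per-sequence lengths so the kernel ignores
  # padding slots in the last, partially filled page.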

All existing tests still pass (53/53 total including regressions).


Regarding TMA: This PR covers the CpAsync baseline with full feature support (fwd+bwd, varlen, split-KV, paged KV). TMA optimization (cp.async.bulk.tensor replacing per-thread CpAsync) is planned as a separate follow-up PR. The TMA kernel architecture is fundamentally different — it uses warp specialization (1 producer warp for TMA loads + 3 consumer warps for MMA compute), PipelineTmaAsync with mbarrier-based coordination, and SM90-compatible swizzled SMEM layouts for TMA descriptors.

We have a working TMA implementation already validated on SM121a in our CUTLASS PR (NVIDIA/cutlass#3030, the FlashAttentionForwardSm120Tma class). The flash-attention follow-up will port that into the CuTe DSL interface.py dispatch as a new FlashAttentionForwardSm120Tma class.

No other existing PR in flash-attention addresses TMA for SM120, and the current CUTLASS 4.4.1 release does not include SM120 FA examples (our CUTLASS PR #3030 is still open on the sm120-flash-attention-v2 branch).

@geraldstanje1

Hi @blake-snc, is there a way to test this with the NVIDIA RTX 6000 Pro Blackwell as well? Any instructions for me to try?

blake-snc and others added 6 commits March 10, 2026 16:01
The SM80 base classes (FlashAttentionForwardSm80, FlashAttentionBackwardSm80)
had 5 latent bugs where their code fell out of sync with upstream API changes
to SeqlenInfoQK, AttentionMask, and _check_type.

Forward (flash_fwd.py):
- _check_type: pass None for 4 varlen type args (signature expanded for varlen)
- SeqlenInfoQK.create: pass batch_size as required first positional arg
- compute_one_n_block: pass seqlen= (required by score_mod path)
- AttentionMask: pass SeqlenInfoQK object instead of seqlen_q/seqlen_k
- mask.apply_mask: pass batch_idx and head_idx (required by mask_mod path)

Backward (flash_bwd.py):
- AttentionMask: same fix as forward
- mask.apply_mask: same fix as forward

These are all one-line fixes that align the SM80 classes with the current API.
None of these paths are exercised by SM90/SM100 dispatch (they have their own
__call__ implementations), so this is a no-op for existing users.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add FlashAttentionForwardSm120 subclass with SM120's 99 KB SMEM constraint
- Fix 5 latent API-drift bugs in FlashAttentionForwardSm80.__call__
- Add SM120 dispatch with optimized tile sizes (D<=64: 128x128, D>64: 128x64)
- Integrate with persistent compile cache, fake tensor mode, use_2cta_instrs

Validated on NVIDIA GB10 (DGX Spark, SM121a):
- 10/10 correctness tests pass (non-causal + causal, max_diff < 0.016 BF16)
- Peak ~49 TFLOPS causal, ~33 TFLOPS non-causal

Contributed by Second Nature Computing (https://joinsecondnature.com)

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add FlashAttentionBackwardSm120 subclass with SM120's 99 KB SMEM constraint
and 128 threads (4 warps) matching the forward pass MMA layout.

SM120 backward key design decisions:
- m_block=n_block=64 with all-M atom layout (4,1,1) — same pattern as
  the working forward pass, avoids fragment dimension mismatches
- D<=64: 2 pipeline stages for Q and dO (~65 KB SMEM)
- D>64: 1 pipeline stage (~81 KB SMEM, fits in 99 KB)
- SdP/dKV/dQ swapAB all False (simplest layout, all warps in M)
- Postprocess uses 128 threads (matching backward kernel's MMA atom
  layout for correct dq_accum register-to-memory mapping)

Also fix latent API drift in FlashAttentionBackwardSm80.__call__:
- AttentionMask constructor: pass SeqlenInfoQK object (not separate ints)
- apply_mask: pass batch_idx and head_idx keyword arguments

Also add missing variable definitions (num_stages_Q, num_stages_dO,
SdP_swapAB, AtomLayoutMSdP) in the SM100 else block.

Validated on NVIDIA GB10 (DGX Spark, SM121a):
- 22/22 forward+backward tests pass (BF16, non-causal + causal)
- D=64: seqlen 128-1024, B=1-8, H=8-32, max gradient diff < 0.016
- D=128: seqlen 128-512, B=1-4, H=8, max gradient diff < 0.016
- All gradients verified against torch.nn.functional.scaled_dot_product_attention

Known limitation: CUTLASS CuTe DSL JIT compiler has a resource exhaustion
bug that causes segfault after ~8 unique kernel compilations in a single
process. Mixing different head_dim values (e.g. D=64 and D=128) in one
process may trigger this. Each head_dim works correctly in isolation.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Refactors SM80 forward __call__ and kernel to support variable-length
packed sequences via SingleTileVarlenScheduler + SeqlenInfoQK. The SM80
backward already supported varlen natively.

Forward changes:
- SM80 __call__: expanded signature to accept mCuSeqlensQ/K, mSeqUsedQ/K;
  conditional 3D/4D layout transpose; tile scheduler selection
  (varlen/causal-LPT/basic); merged SM120 compile/invoke into shared path
- SM80 kernel: tile scheduler pattern (TileScheduler.create →
  initial_work_tile_info → tile_idx) replaces direct block_idx();
  if work_tile.is_valid_tile guard for varlen padding tiles;
  offset_batch_Q/K for transparent fixed-length/varlen tensor indexing

Backward changes:
- Removed varlen/seqused blocking asserts in SM120 backward dispatch
- Wired real cu_seqlens_q/k and seqused_q/k tensors to SM120 backward
  compile and invoke (were hardcoded to None)
- Added varlen state to backward compile_key

Validation on SM121a (DGX Spark):
- Varlen forward: 13/13 pass (D=64/128, causal/non-causal, edge cases
  including seqlens=[1,1,1,1], [7,13,31,3], [333])
- Varlen forward+backward: 8/8 pass
- Non-varlen regression: 5/5 pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Split-KV (FlashDecoding):
- Add is_split_kv parameter to FlashAttentionForwardBase and SM80 kernel
- Extend layout transposes for O and LSE with leading split dimension
- Wire num_splits through TileSchedulerArguments and tile scheduler grid
- Use split_idx from tile scheduler in epilogue for partial O/LSE writes
- SM120 kernel writes BF16/FP16 partials (SM80-era epilogue); convert to
  FP32 before FlashAttentionForwardCombine
- Zero-init out_partial and -inf init lse_partial for empty splits (causal +
  split-KV may produce splits with no K blocks)
- Split-KV mainloop correctly iterates n_block_min..n_block_max per split
  partition via BlockInfo.get_n_block_min_max()

Paged KV:
- SM80 kernel's swizzled SMEM layout (composed with Swizzle(3,3,3)) is
  incompatible with PagedKVManager's tiled copy, which creates a plain 2D
  layout from sX.stride and loses the swizzle. Instead of kernel-level paged
  KV, resolve the page table at Python level in interface.py via GPU gather:
  k[page_table.reshape(-1)].reshape(B, max_seqlen_k, H, D)
- Requires seqused_k to communicate actual sequence lengths to the kernel

Validation on SM121a (NVIDIA GB10):
- Split-KV: 12/12 pass (D=64/128, causal/non-causal, 2-4 splits)
- Paged KV: 14/14 pass (D=64/128, page_size 64/128, causal/non-causal,
  uniform and varied sequence lengths)
- Paged KV + split-KV combined: 6/6 pass (full inference pattern)
- Varlen regression: 13/13 fwd pass, 8/8 fwd+bwd pass
- Total: 53/53 tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The varlen commit added blocksparse_tensors to the SM80 __call__ signature
but the import was removed during upstream refactoring. Add it back.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
blake-snc force-pushed the feat/sm120-support branch from 45bc87c to 8fd173b on March 10, 2026 23:20
blake-snc changed the title from "Add SM120 (Blackwell GeForce / DGX Spark) forward + backward pass support" to "Add SM120 (Blackwell GeForce / DGX Spark) flash attention" on Mar 10, 2026
@blake-snc
Author

Hey @geraldstanje1! The RTX 6000 Pro Blackwell should work since it's SM120, same arch family I targeted here. Here's how to test (I have not validated this myself, but this should be all you need to do):

  # Clone the PR branch
  git clone -b feat/sm120-support https://github.com/blake-snc/flash-attention.git
  cd flash-attention

  # Install (needs CUDA 12.8+ and nvidia-cutlass-dsl >= 4.4.1)
  pip install -e "flash_attn/cute[dev]"

  # Quick smoke test — forward pass
  python -c "
  import sys
  sys.modules['flash_attn_2_cuda'] = type(sys)('flash_attn_2_cuda')
  import torch
  from flash_attn.cute import flash_attn_func

  q = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device='cuda')
  k = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device='cuda')
  v = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device='cuda')
  out = flash_attn_func(q, k, v, causal=True)
  print(f'Output shape: {out[0].shape}, max: {out[0].max():.4f}')
  print('Forward pass OK!')
  "

  # Full test suite (first compile pass takes a while)
  FLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 64 -x tests/cute/test_flash_attn.py
  FLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -x tests/cute/test_flash_attn.py

I validated on DGX Spark (SM121a) but I don't have access to test on RTX 6000 Pros directly. If you run into issues, let me know as any feedback is very helpful.

@blake-snc
Author

Per @johnnynunez's suggestion, split this into two PRs:

#2325 should be mergeable independently. This PR depends on #2325 for the base class fixes.

@geraldstanje1

geraldstanje1 commented Mar 11, 2026

> Hey @geraldstanje1! The RTX 6000 Pro Blackwell should work since it's SM120, same arch family I targeted here. Here's how to test [...]
>
> I validated on DGX Spark (SM121a) but I don't have access to test on RTX 6000 Pros directly. If you run into issues, let me know as any feedback is very helpful.

@blake-snc here are the test results - looks like there are some errors - any idea?

./test.sh
===== GPU INFO =====
Wed Mar 11 01:38:29 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.126.09             Driver Version: 580.126.09     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX PRO 6000 Blac...    Off |   00000000:2F:00.0 Off |                    0 |
| N/A   22C    P8             27W /  600W |       0MiB /  97887MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
====================
Cloning into 'flash-attention'...
remote: Enumerating objects: 13429, done.        
remote: Counting objects: 100% (787/787), done.        
remote: Compressing objects: 100% (280/280), done.        
remote: Total 13429 (delta 667), reused 507 (delta 507), pack-reused 12642 (from 3)        
Receiving objects: 100% (13429/13429), 19.45 MiB | 57.57 MiB/s, done.
Resolving deltas: 100% (10306/10306), done.

Obtaining file:///teamspace/studios/this_studio/tmp/flash-attention/flash_attn/cute
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Requirement already satisfied: nvidia-cutlass-dsl>=4.4.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (4.4.1)
Requirement already satisfied: torch in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (2.9.1)
Requirement already satisfied: einops in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (0.8.2)
Requirement already satisfied: typing_extensions in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (4.15.0)
Requirement already satisfied: apache-tvm-ffi<0.2,>=0.1.5 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (0.1.9)
Requirement already satisfied: torch-c-dlpack-ext in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (0.1.5)
Requirement already satisfied: quack-kernels>=0.2.10 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (0.2.10)
Requirement already satisfied: setuptools in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (80.10.2)
Requirement already satisfied: pytest in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (9.0.2)
Requirement already satisfied: ruff in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from flash-attn-4==0.0.1.dev1360+g8fd173bb3) (0.15.5)
Requirement already satisfied: nvidia-cutlass-dsl-libs-base==4.4.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from nvidia-cutlass-dsl>=4.4.1->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (4.4.1)
Requirement already satisfied: numpy in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from nvidia-cutlass-dsl-libs-base==4.4.1->nvidia-cutlass-dsl>=4.4.1->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (2.2.6)
Requirement already satisfied: cuda-python>=12.8 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from nvidia-cutlass-dsl-libs-base==4.4.1->nvidia-cutlass-dsl>=4.4.1->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.9.4)
Requirement already satisfied: cuda-bindings~=12.9.4 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from cuda-python>=12.8->nvidia-cutlass-dsl-libs-base==4.4.1->nvidia-cutlass-dsl>=4.4.1->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.9.4)
Requirement already satisfied: cuda-pathfinder~=1.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from cuda-bindings~=12.9.4->cuda-python>=12.8->nvidia-cutlass-dsl-libs-base==4.4.1->nvidia-cutlass-dsl>=4.4.1->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (1.4.1)
Requirement already satisfied: iniconfig>=1.0.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from pytest->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (2.3.0)
Requirement already satisfied: packaging>=22 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from pytest->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (25.0)
Requirement already satisfied: pluggy<2,>=1.5 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from pytest->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (1.6.0)
Requirement already satisfied: pygments>=2.7.2 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from pytest->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (2.19.2)
Requirement already satisfied: filelock in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (3.25.0)
Requirement already satisfied: sympy>=1.13.3 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (1.14.0)
Requirement already satisfied: networkx>=2.5.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (3.6.1)
Requirement already satisfied: jinja2 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (3.1.6)
Requirement already satisfied: fsspec>=0.8.5 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (2026.2.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.8.93)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.8.90)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.8.90)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.8.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (11.3.3.83)
Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (10.3.9.90)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (11.7.3.90)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.5.8.93)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (2.27.5)
Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (3.3.20)
Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.8.90)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (12.8.93)
Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (1.13.1.3)
Requirement already satisfied: triton==3.5.1 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (3.5.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from sympy>=1.13.3->torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (from jinja2->torch->flash-attn-4==0.0.1.dev1360+g8fd173bb3) (3.0.3)
Building wheels for collected packages: flash-attn-4
  Building editable for flash-attn-4 (pyproject.toml) ... done
  Created wheel for flash-attn-4: filename=flash_attn_4-0.0.1.dev1360+g8fd173bb3-0.editable-py3-none-any.whl size=5072 sha256=5cc81cc0970f26541531fd6341e0e8613f66eb62d9d70c7fe5c4d28d885d9cf6
  Stored in directory: /tmp/pip-ephem-wheel-cache-f7ja2fui/wheels/d2/93/09/524c3cc12c88cc57ec55ebc1f784928518a9f1e3e8a1ba12de
Successfully built flash-attn-4
Installing collected packages: flash-attn-4
  Attempting uninstall: flash-attn-4
    Found existing installation: flash-attn-4 0.0.1.dev1360+g8fd173bb3
    Uninstalling flash-attn-4-0.0.1.dev1360+g8fd173bb3:
      Successfully uninstalled flash-attn-4-0.0.1.dev1360+g8fd173bb3
Successfully installed flash-attn-4-0.0.1.dev1360+g8fd173bb3
/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
Output shape: torch.Size([2, 256, 8, 128]), max: 3.2344
Forward pass OK!

log files:
test1.log
test2.log

test.sh:

#!/usr/bin/env bash
set -euo pipefail

echo "===== GPU INFO ====="
nvidia-smi
echo "===================="

# Clone the PR branch
rm -rf flash-attention
git clone -b feat/sm120-support https://github.com/blake-snc/flash-attention.git
cd flash-attention

# Install (requires CUDA 12.8+ and nvidia-cutlass-dsl >= 4.4.1)
pip install -e "flash_attn/cute[dev]"

# Quick smoke test — forward pass
python <<'PY'
import sys
sys.modules['flash_attn_2_cuda'] = type(sys)('flash_attn_2_cuda')

import torch
from flash_attn.cute import flash_attn_func

q = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device='cuda')
k = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device='cuda')
v = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device='cuda')

out = flash_attn_func(q, k, v, causal=True)

print(f"Output shape: {out[0].shape}, max: {out[0].max():.4f}")
print("Forward pass OK!")
PY

# Run full test suite (first compile pass may take a while)
FLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 64 -x tests/cute/test_flash_attn.py > test1.log 2>&1 || true
FLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -x tests/cute/test_flash_attn.py > test2.log 2>&1 || true

Also, how can I test this PR with a vLLM release and the model gpt-oss-safeguard-20b?

@johnnynunez
Contributor

cc @tridao @drisspg, could you take a look to close the gap with Blackwell GeForce?

@tridao
Member

tridao commented Mar 11, 2026

This is a big change, so let's split it into multiple PRs, one for each of the features below. Then we can merge them one by one while testing over the next couple of days.

  • Forward pass: non-causal + causal, D=64 and D=128, tile sizes tuned for SM120 SMEM (D≤64: 128×128, D>64: 128×64)
  • Backward pass: D=64 and D=128, 4-warp layout matching forward
  • Variable-length (varlen): forward + backward with cu_seqlens_q/k and seqused_q/k
  • Split-KV (FlashDecoding): multi-split inference decoding
  • Paged KV cache: Python-level page table resolution (SM80 swizzled SMEM layout is incompatible with PagedKVManager's tiled copy)

Btw, the bwd preprocessing is now arch-agnostic, so you won't need to change it.
