Add SM120 (Blackwell GeForce / DGX Spark) flash attention #2268
blake-snc wants to merge 6 commits into Dao-AILab:main
Conversation
@blake-snc Are you planning to add TMA support?
@johnnynunez Yes — TMA support is already implemented in our CUTLASS PR (NVIDIA/cutlass#3030). This flash-attention PR is the CpAsync baseline adapted to Dao-AILab's CuTe DSL interface. Just rebased onto the latest main to fix the merge conflict. We could add TMA here too, but wanted to get the basic forward pass landed first. Current status: forward-only, BF16/FP16, causal/non-causal, MHA/GQA/MQA, hdim 64/96/128. Still missing: backward pass, varlen, paged KV, split-KV. Happy to coordinate on expanding it!
Updated with split-KV (FlashDecoding) and paged KV support. The latest push adds:

Split-KV / FlashDecoding — splits the K dimension across multiple thread blocks for long-context decode (few Q tokens, many KV tokens). Each split produces BF16/FP16 partial outputs plus an FP32 LSE; the partials are converted to FP32 and merged by the existing FlashAttentionForwardCombine kernel.

Paged KV cache — supports paged KV for inference engines (vLLM/SGLang). The SM80 kernel's swizzled SMEM layout is incompatible with PagedKVManager's tiled copy, so the page table is resolved at the Python level instead (details in the commit message).

All existing tests still pass (53/53 total, including regressions).

Regarding TMA: this PR covers the CpAsync baseline with full feature support (fwd+bwd, varlen, split-KV, paged KV). We have a working TMA implementation already validated on SM121a in our CUTLASS PR (NVIDIA/cutlass#3030). No other existing PR in flash-attention addresses TMA for SM120, and the current CUTLASS 4.4.1 release does not include SM120 FA examples (our CUTLASS PR #3030 is still open).
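For context, the split-KV merge described above combines per-split partial outputs using their log-sum-exp (LSE) values. A minimal numpy sketch of that combine step (names and shapes are illustrative, not the kernel's actual API):

```python
import numpy as np

def combine_splits(o_partial, lse_partial):
    """Merge per-split partial attention outputs using their LSE values.

    o_partial:   (num_splits, seqlen_q, head_dim) FP32 partial outputs
    lse_partial: (num_splits, seqlen_q) FP32 log-sum-exp per split
    """
    # Global LSE across splits (empty splits carry lse = -inf and drop out).
    lse_max = lse_partial.max(axis=0)                       # (seqlen_q,)
    lse_global = lse_max + np.log(
        np.exp(lse_partial - lse_max).sum(axis=0))          # (seqlen_q,)
    # Each split's contribution is reweighted by exp(lse_split - lse_global).
    scale = np.exp(lse_partial - lse_global)                # (num_splits, seqlen_q)
    return (o_partial * scale[..., None]).sum(axis=0)       # (seqlen_q, head_dim)
```

Because each split's output is already normalized by its own softmax denominator, the reweighting by exp(lse_split - lse_global) recovers exactly the full-softmax result.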
Hi @blake-snc, is there a way to test this with an NVIDIA RTX 6000 Pro Blackwell as well? Any instructions for me to try?
The SM80 base classes (FlashAttentionForwardSm80, FlashAttentionBackwardSm80) had 5 latent bugs where their code fell out of sync with upstream API changes to SeqlenInfoQK, AttentionMask, and _check_type.

Forward (flash_fwd.py):
- _check_type: pass None for the 4 varlen type args (signature expanded for varlen)
- SeqlenInfoQK.create: pass batch_size as the required first positional arg
- compute_one_n_block: pass seqlen= (required by the score_mod path)
- AttentionMask: pass a SeqlenInfoQK object instead of seqlen_q/seqlen_k
- mask.apply_mask: pass batch_idx and head_idx (required by the mask_mod path)

Backward (flash_bwd.py):
- AttentionMask: same fix as forward
- mask.apply_mask: same fix as forward

These are all one-line fixes that align the SM80 classes with the current API. None of these paths are exercised by SM90/SM100 dispatch (those have their own __call__ implementations), so this is a no-op for existing users.

Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add FlashAttentionForwardSm120 subclass with SM120's 99 KB SMEM constraint
- Fix 5 latent API-drift bugs in FlashAttentionForwardSm80.__call__
- Add SM120 dispatch with optimized tile sizes (D<=64: 128x128, D>64: 128x64)
- Integrate with persistent compile cache, fake tensor mode, use_2cta_instrs

Validated on NVIDIA GB10 (DGX Spark, SM121a):
- 10/10 correctness tests pass (non-causal + causal, max_diff < 0.016 BF16)
- Peak ~49 TFLOPS causal, ~33 TFLOPS non-causal

Contributed by Second Nature Computing (https://joinsecondnature.com)
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
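The head-dim-based tile selection mentioned above can be sketched as a tiny dispatch helper. This is a hypothetical illustration of the rule quoted in the commit message (the function name and signature are not the PR's actual API):

```python
def sm120_fwd_tile_shape(head_dim: int) -> tuple:
    """Hypothetical helper mirroring the SM120 forward tile rule:
    D <= 64 uses 128x128 (M x N) tiles; larger head dims shrink the
    N tile to 64 so the wider K/V tiles still fit SM120's 99 KB SMEM.
    """
    if head_dim <= 64:
        return (128, 128)
    return (128, 64)
```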
Add FlashAttentionBackwardSm120 subclass with SM120's 99 KB SMEM constraint and 128 threads (4 warps), matching the forward pass MMA layout.

SM120 backward key design decisions:
- m_block=n_block=64 with all-M atom layout (4,1,1) — same pattern as the working forward pass; avoids fragment dimension mismatches
- D<=64: 2 pipeline stages for Q and dO (~65 KB SMEM)
- D>64: 1 pipeline stage (~81 KB SMEM, fits in 99 KB)
- SdP/dKV/dQ swapAB all False (simplest layout, all warps in M)
- Postprocess uses 128 threads (matching the backward kernel's MMA atom layout for a correct dq_accum register-to-memory mapping)

Also fix latent API drift in FlashAttentionBackwardSm80.__call__:
- AttentionMask constructor: pass a SeqlenInfoQK object (not separate ints)
- apply_mask: pass batch_idx and head_idx keyword arguments

Also add missing variable definitions (num_stages_Q, num_stages_dO, SdP_swapAB, AtomLayoutMSdP) in the SM100 else block.

Validated on NVIDIA GB10 (DGX Spark, SM121a):
- 22/22 forward+backward tests pass (BF16, non-causal + causal)
- D=64: seqlen 128-1024, B=1-8, H=8-32, max gradient diff < 0.016
- D=128: seqlen 128-512, B=1-4, H=8, max gradient diff < 0.016
- All gradients verified against torch.nn.functional.scaled_dot_product_attention

Known limitation: the CUTLASS CuTe DSL JIT compiler has a resource exhaustion bug that causes a segfault after ~8 unique kernel compilations in a single process. Mixing different head_dim values (e.g. D=64 and D=128) in one process may trigger this. Each head_dim works correctly in isolation.

Contributed by Second Nature Computing (https://joinsecondnature.com)
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
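The stage counts above follow from simple SMEM arithmetic. A back-of-the-envelope sketch, assuming 64x64 tiles and 2-byte elements as stated in the commit message; the operand list here is a simplification (the real kernel allocates additional buffers such as P/dS scratch, which is why the commit quotes ~65 KB / ~81 KB rather than these raw operand totals):

```python
# Rough SMEM budget for the SM120 backward tiles described above.
SMEM_LIMIT_KB = 99
BLOCK_M = BLOCK_N = 64
BYTES = 2  # BF16 / FP16 element size

def operand_smem_kb(head_dim: int, stages: int) -> float:
    """Approximate operand SMEM in KB for one backward CTA."""
    tile_kb = BLOCK_M * head_dim * BYTES / 1024   # one 64 x D tile
    pipelined = 2 * stages * tile_kb              # Q and dO, multi-stage
    resident = 2 * tile_kb                        # K and V, single-stage
    return pipelined + resident

# Both configurations from the commit leave headroom under the 99 KB cap.
for head_dim, stages in ((64, 2), (128, 1)):
    assert operand_smem_kb(head_dim, stages) < SMEM_LIMIT_KB
```

The point of the sketch: doubling the head dim doubles every tile, so the D>64 path has to give up a pipeline stage to stay under the cap.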
Refactors the SM80 forward __call__ and kernel to support variable-length packed sequences via SingleTileVarlenScheduler + SeqlenInfoQK. The SM80 backward already supported varlen natively.

Forward changes:
- SM80 __call__: expanded signature to accept mCuSeqlensQ/K and mSeqUsedQ/K; conditional 3D/4D layout transpose; tile scheduler selection (varlen/causal-LPT/basic); merged SM120 compile/invoke into the shared path
- SM80 kernel: tile scheduler pattern (TileScheduler.create → initial_work_tile_info → tile_idx) replaces direct block_idx(); if work_tile.is_valid_tile guard for varlen padding tiles; offset_batch_Q/K for transparent fixed-length/varlen tensor indexing

Backward changes:
- Removed varlen/seqused blocking asserts in the SM120 backward dispatch
- Wired real cu_seqlens_q/k and seqused_q/k tensors to the SM120 backward compile and invoke (they were hardcoded to None)
- Added varlen state to the backward compile_key

Validation on SM121a (DGX Spark):
- Varlen forward: 13/13 pass (D=64/128, causal/non-causal, edge cases including seqlens=[1,1,1,1], [7,13,31,3], [333])
- Varlen forward+backward: 8/8 pass
- Non-varlen regression: 5/5 pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
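For readers unfamiliar with the varlen convention used here: packed sequences are concatenated along one axis and delimited by a cumulative-length array, which is what cu_seqlens_q/k carry. A minimal numpy illustration of that indexing (not the kernel code; the seqlens values match one of the test cases):

```python
import numpy as np

# Pack variable-length sequences into one (total_tokens, head_dim) tensor,
# delimited by cumulative sequence lengths — the cu_seqlens convention.
seqlens = [32, 64, 48, 16]
head_dim = 64
cu_seqlens = np.concatenate([[0], np.cumsum(seqlens)])
packed = np.zeros((cu_seqlens[-1], head_dim), dtype=np.float32)

# Batch b occupies rows cu_seqlens[b] : cu_seqlens[b + 1] of the packed tensor;
# this is the indexing the offset_batch_Q/K helpers perform inside the kernel.
def batch_slice(b):
    return packed[cu_seqlens[b]:cu_seqlens[b + 1]]
```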
Split-KV (FlashDecoding):
- Add is_split_kv parameter to FlashAttentionForwardBase and the SM80 kernel
- Extend layout transposes for O and LSE with a leading split dimension
- Wire num_splits through TileSchedulerArguments and the tile scheduler grid
- Use split_idx from the tile scheduler in the epilogue for partial O/LSE writes
- SM120 kernel writes BF16/FP16 partials (SM80-era epilogue); convert to FP32 before FlashAttentionForwardCombine
- Zero-init out_partial and -inf-init lse_partial for empty splits (causal + split-KV may produce splits with no K blocks)
- Split-KV mainloop correctly iterates n_block_min..n_block_max per split partition via BlockInfo.get_n_block_min_max()

Paged KV:
- The SM80 kernel's swizzled SMEM layout (composed with Swizzle(3,3,3)) is incompatible with PagedKVManager's tiled copy, which creates a plain 2D layout from sX.stride and loses the swizzle. Instead of kernel-level paged KV, resolve the page table at the Python level in interface.py via a GPU gather: k[page_table.reshape(-1)].reshape(B, max_seqlen_k, H, D)
- Requires seqused_k to communicate actual sequence lengths to the kernel

Validation on SM121a (NVIDIA GB10):
- Split-KV: 12/12 pass (D=64/128, causal/non-causal, 2-4 splits)
- Paged KV: 14/14 pass (D=64/128, page_size 64/128, causal/non-causal, uniform and varied sequence lengths)
- Paged KV + split-KV combined: 6/6 pass (full inference pattern)
- Varlen regression: 13/13 fwd pass, 8/8 fwd+bwd pass
- Total: 53/53 tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
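The Python-level page-table gather above can be illustrated with numpy fancy indexing. Shapes follow the commit message's k[page_table.reshape(-1)].reshape(B, max_seqlen_k, H, D); the real code operates on CUDA tensors in interface.py, and the concrete sizes here are illustrative:

```python
import numpy as np

B, H, D = 2, 4, 64
page_size, pages_per_seq = 16, 3
max_seqlen_k = page_size * pages_per_seq
num_pages = B * pages_per_seq

# Paged KV cache: (num_pages, page_size, H, D), plus a per-batch page table.
k_cache = np.random.default_rng(0).standard_normal(
    (num_pages, page_size, H, D)).astype(np.float32)
page_table = np.arange(num_pages).reshape(B, pages_per_seq)  # trivial mapping

# Resolve the page table with one gather, as the commit does on GPU:
# k_cache[page_table.reshape(-1)] selects pages in sequence order, then the
# reshape stitches them into contiguous per-batch KV tensors for the kernel.
k_contig = k_cache[page_table.reshape(-1)].reshape(B, max_seqlen_k, H, D)
```

One gather plus a reshape keeps the kernel oblivious to paging, at the cost of materializing a contiguous copy of the KV cache for the batch.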
The varlen commit added blocksparse_tensors to the SM80 __call__ signature but the import was removed during upstream refactoring. Add it back. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Hey @geraldstanje1! The RTX 6000 Pro Blackwell should work since it's SM120, the same arch family I targeted here. Here's how to test (I have not validated this myself, but it should be all you need to do). I validated on DGX Spark (SM121a) but don't have access to RTX 6000 Pros to test directly. If you run into issues, let me know — any feedback is very helpful.
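A generic way to sanity-check any build, regardless of GPU, is to compare the kernel's output against a plain softmax-attention reference. Here is a minimal numpy reference for that purpose — my own suggestion, not code from the PR (the PR's tests compare against torch.nn.functional.scaled_dot_product_attention):

```python
import numpy as np

def ref_attention(q, k, v, causal=False):
    """Plain FP32 softmax attention, for checking a kernel's output.

    q: (seqlen_q, head_dim), k/v: (seqlen_k, head_dim). Single head for
    brevity; loop over batch/heads to check a full (B, S, H, D) output.
    """
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    if causal:
        sq, sk = scores.shape
        # Mask keys beyond the (right-aligned) query position.
        mask = np.arange(sk)[None, :] > (np.arange(sq)[:, None] + (sk - sq))
        scores = np.where(mask, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v
```

Run the kernel in BF16, upcast its output to FP32, and check the max absolute difference against this reference; the PR treats < 0.016 as within BF16 precision bounds.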
Per @johnnynunez's suggestion, split this into two PRs:
#2325 should be mergeable independently. This PR depends on #2325 for the base class fixes.
@blake-snc Here are the test results — looks like there are some errors — any idea? Log files: test.sh: Also, how can I test this PR with a vLLM release and the model gpt-oss-safeguard-20b?
This is a big change so let's split it into multiple PRs for each of the features:
Btw, the bwd preprocessing is now arch-agnostic, so you won't need to change it.
Summary
Add flash attention support for SM120 (NVIDIA Blackwell GeForce / DGX Spark, compute capability 12.x). SM120 uses SM80-era MMA instructions (mma.sync.aligned.m16n8k16) with 99 KB shared memory.

Features:
- Forward + backward, BF16/FP16, causal/non-causal, MHA/GQA/MQA, hdim 64/96/128
- Varlen via cu_seqlens_q/k and seqused_q/k
- Split-KV (FlashDecoding)
- Paged KV, resolved at the Python level (the SM80 swizzled SMEM layout is incompatible with PagedKVManager's tiled copy)

Architecture:
- FlashAttentionForwardSm120 subclasses FlashAttentionForwardSm80 with arch=80 (CpAsync code paths) and an SM120 SMEM capacity check
- FlashAttentionBackwardSm120 subclasses FlashAttentionBackwardSm80 similarly
- The SM80 forward gains varlen (SingleTileVarlenScheduler, SeqlenInfoQK offset helpers) and split-KV support

Depends on #2325 (SM80 API drift fixes).
Rebased on current main (resolves prior merge conflicts).

Validation on SM121a (NVIDIA GB10, DGX Spark)
Forward (non-causal + causal):
Backward (gradient diffs):
Varlen forward (non-causal + causal, seqlens=[32,64,48,16]):
Split-KV (FlashDecoding):
All diffs are within BF16 precision bounds. Reference: torch.nn.functional.scaled_dot_product_attention.

Test plan
- SM120 dispatch is gated on arch // 10 == 12; new subclasses live in separate files; SM80 base class changes are behind const_expr varlen/split-KV branches — no SM90/SM100 hardware available for runtime verification

Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>