Skip to content

Tugbars/Flash-Attention-PTX-CUDA

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

37 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

attention_heatmap

Flash Attention PTX/CUDA

Hand-written PTX flash attention kernel achieving 156 TFLOPS on RTX 5080
67% of theoretical peak · 58× faster than scalar baseline · no WGMMA, no TMA, no shortcuts

Performance · How It Works · Building · Visualization · Architecture


What is this?

A from-scratch flash attention implementation in raw CUDA/PTX targeting consumer NVIDIA GPUs (RTX 5080, Blackwell sm_120). No libraries, no CUTLASS attention wrappers, no cuDNN — just hand-written kernels optimized step by step from 2.7 TFLOPS to 156.4 TFLOPS.

The kernel uses PTX inline assembly for mma.sync.aligned.m16n8k16 tensor core operations with ldmatrix for optimal shared memory → register transfers, and performs the full softmax in registers using warp shuffle intrinsics, eliminating the largest shared memory bottleneck in standard flash attention implementations.

Consumer Blackwell (sm_120) lacks the datacenter features that make H100/B200 attention kernels fast: no WGMMA (warp group MMA), no TMA (tensor memory accelerator), no warp specialization barriers. This kernel achieves competitive utilization using only the tools available on consumer silicon.

Performance

performance

Peak: 156.4 TFLOPS at B=8, H=12, S=2048, D=64 (causal attention).

Config TFLOPS % Peak Notes
B=1, S=512 48.4 20.6% Small-tile variant (auto-dispatched)
B=1, S=2048 109.6 46.7%
B=4, S=2048 151.5 64.5%
B=8, S=2048 156.4 66.6% Sweet spot
B=1, S=4096 142.6 60.7%

Measured on RTX 5080 (84 SMs, 234.8 TFLOPS FP16 theoretical peak).

Under-saturated workloads (small batch × short sequence, e.g. B=1, S<1024) are auto-dispatched to a 32×64 / 4-warp tile variant that doubles the grid count and fills the SMs — same kernel template, smaller M-tile. See Architecture.

For context, Flash Attention 2 on the A100 (datacenter Ampere) achieves approximately 60% tensor core utilization. This consumer Blackwell kernel reaches 67% without WGMMA, TMA, or warp specialization, using only tools available on consumer silicon.

Optimization progression

Each version identified and eliminated a specific bottleneck. Every change was validated with Nsight Compute profiling.

Version TFLOPS Bottleneck Removed
v1 — Scalar FP32 2.7 Baseline, no tensor cores
v3 — WMMA fragments 26.7 Enabled tensor core MMA
v6 — Vectorized loads 38.3 uint4 coalesced global memory access
v7 — PTX MMA + ldmatrix 49.2 Known register layout, eliminated fragment opacity
v8 — In-register softmax 125.2 Eliminated 16KB smem_s round-trip
v9 — Direct rescale 135.9 exp(S−new_max) directly, fewer critical-path ops
v10 — cp.async loads 156.4 gmem→smem direct (no register staging), LSU freed for compute

Profiler metrics (B=4, S=2048)

Metric smem_s path In-register v8 Corrected v9
Tensor core utilization 17.9% 50.1% ~54%
L1/smem throughput 31.3% 32.5% 32.5%
Active warps 15.6% 29.7% ~32%

How It Works

The core idea: keep S in registers

Standard flash attention implementations write the S = Q×K^T attention scores to shared memory, synchronize, read them back for softmax, write the softmax output P to shared memory, synchronize again, then load P for the P×V multiply. This creates two full shared memory round-trips per KV tile.

Our kernel keeps S in registers after the Q×K^T MMA. Each thread holds s_acc[4][4] — 16 attention score values at known (row, column) positions determined by the m16n8k16 MMA layout. Softmax is computed directly on these register values:

Q × K^T (PTX MMA)
      ↓
  s_acc[4][4] in registers
      ↓
  shuffle reduce → partial max (32 cols, 4 threads)
      ↓
  smem exchange → global max across warp halves (1KB)
      ↓
  exp(S − new_max) → P values at correct scale
      ↓
  write P to smem_p → ldmatrix → P × V (PTX MMA)

Key optimizations

In-register softmax. The attention score matrix never touches shared memory. Cross-thread reduction uses __shfl_xor_sync across 4 threads sharing each row, then a tiny 1KB shared memory exchange between warp halves for the full 64-column max and sum.

Direct new_max rescaling (v9). Instead of computing exp(S − tile_max) then post-multiplying by exp(tile_max − new_max), we compute new_max = max(prev_max, tile_max) before the exponential and subtract new_max directly: exp(S − new_max). P values are written to shared memory already at the correct scale. This eliminates multiplications from the critical path and fixes subtle accuracy drift on long sequences.

Correct ldmatrix.x2.trans addressing. PTX's ldmatrix.sync.aligned.m8n8.x2.trans loads two 8×8 matrices using threads 0-15. Threads 0-7 provide addresses for matrix 0, threads 8-15 for matrix 1. The second group must offset by +8 columns (for K) or +8 rows (for V) to load the full 16-element k-dimension:

int k_row = lane_id % 8;
int mat   = (lane_id / 8) % 2;  // 0 for threads 0-7, 1 for threads 8-15
ldmatrix_x2_trans(b0, b1,
    smem_k + (ni*8 + k_row) * KV_STRIDE + ki*16 + mat*8);

Online softmax with cross-warp correction. Each warp pair (2 warps) handles a 16-row × 64-column output tile. The softmax running maximum and sum are maintained per-thread for two rows (row0 and row0+8, matching the MMA layout). When new_max > prev_max, the old O accumulator is rescaled by exp(prev_max − new_max).

What this GPU doesn't have

The RTX 5080 (sm_120, consumer Blackwell) lacks datacenter features that production attention kernels rely on:

Feature Datacenter Consumer (this kernel)
MMA width WGMMA: 128 threads, B from smem mma.sync: 32 threads, B from registers
Memory loads TMA: hardware DMA, zero thread cost Manual uint4 loads by all threads
Pipeline Warp specialization barriers Uniform warps, explicit __syncthreads
Register file TMEM (sm_100+): dedicated tensor RF Standard register file only

Building

Requirements

  • CUDA Toolkit 12.0+ (tested with 13.1)
  • CMake 3.20+
  • C++17 compiler
  • NVIDIA GPU with compute capability 8.0+ (Ampere, Ada, Blackwell)

Build

git clone https://github.com/yourusername/flash-attention-cuda.git
cd flash-attention-cuda
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
cmake --build . --target flash_demo

For fastest builds targeting only your GPU:

set(CMAKE_CUDA_ARCHITECTURES "120")  # RTX 5080

Run

./flash_demo                           # run kernel + dump data
python ../scripts/visualize.py .       # generate figures

Output:

============================================================
  Flash Attention Demo
  GPU: NVIDIA GeForce RTX 5080 (84 SMs, CC 12.0)
  Peak FP16: 234.8 TFLOPS
============================================================

[1] Generating structured Q, K, V (B=1, H=4, S=256, D=64)
[3] Running GPU flash attention kernel...
    Average: 0.0122 ms
    TFLOPS:  5.51
[4] Correctness check...
    Max absolute error: 0.007680
    Normalized RMSE:    0.023781
    Bad elements:       0 / 65536
    Result:             PASS ✓
[5] Writing visualization data...

Visualization

The demo generates binary dumps that visualize.py turns into publication-quality plots.

Attention heatmaps

Four synthetic heads demonstrate the kernel handles diverse attention patterns correctly:

attention_heatmap

  • Local: Diagonal band — nearby tokens attend to each other
  • Strided: Periodic stripes — tokens at matching phase positions attend
  • Global+Anchor: Bright left column — every query attends to early tokens
  • Block: Staircase — strong intra-block attention with sharp boundaries

Error analysis

error_analysis

Median absolute error: 0.00025 (FP16 precision). Error is highest at early sequence positions where softmax has fewer tokens to average over, then drops to near-zero. No systematic patterns — pure FP16 quantization noise.

Per-head behavior

head_comparison

Architecture

flash-attention-cuda/
├── CMakeLists.txt
├── include/
│   └── flash_attention.h           # FlashAttentionParams struct + launch declaration
├── kernels/
│   └── flash_attention.cu          # v10 kernel + tile-size dispatcher — the main event
├── src/
│   └── demo.cu                     # standalone demo + correctness check
├── scripts/
│   └── visualize.py                # generates all figures from binary dumps
├── figures/                        # pre-generated for README
│   ├── attention_heatmap.png
│   ├── performance.png
│   ├── error_analysis.png
│   └── head_comparison.png
└── docs/
    └── flash_attention_story.html  # interactive deep-dive with animations

Kernel parameters

Parameter Big tile Small tile
BLOCK_M 64 32
BLOCK_N 64 64
D_HEAD 64 64
NUM_WARPS 8 (4 warp pairs) 4 (2 warp pairs)
Shared memory ~37 KB ~28 KB
MMA instruction mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 (same)
Precision FP16 compute, FP32 accumulation (same)

Tile-size dispatcher

Both variants are the same kernel template — only BLOCK_M and NUM_WARPS differ. The kernel's warp partitioning (warp_pair = warp_id/2 = mi, warp_half = warp_id%2) generalizes to either NUM_WARPS=4 (2 warp pairs, 2 m-tiles) or NUM_WARPS=8 (4 warp pairs, 4 m-tiles).

The host launch picks at runtime based on whether the workload saturates the GPU:

num_blocks_big = B * H * ceil(S / 64);
if (num_blocks_big < 2 * sm_count)  → small tile (32×64, 4 warps)
else                                 → big tile (64×64, 8 warps)

The small variant wins on under-saturated workloads (B=1, S<1024 on the RTX 5080) by halving BLOCK_M and doubling the grid count, which fills the SMs that the big-tile grid would leave idle. On B=1, S=512 this is +32% over the big tile (36.7 → 48.4 TFLOPS). On saturated workloads the big tile wins by amortizing per-block overhead, so it stays the production path for everything except the smallest configs.

Shared memory layout

smem_q:            64 × 72 × 2B  =  9.0 KB   Q tile
smem_k:            64 × 72 × 2B  =  9.0 KB   K tile
smem_v:            64 × 72 × 2B  =  9.0 KB   V tile
smem_p:            64 × 72 × 2B  =  9.0 KB   P = softmax(S)
smem_partial_max:  2 × 64 × 4B   =  0.5 KB   cross-warp max exchange
smem_partial_sum:  2 × 64 × 4B   =  0.5 KB   cross-warp sum exchange
                                    --------
Total:                              ~37 KB

Correctness

Verified against a CPU reference implementation (naive O(S²) attention with FP32 arithmetic):

Sequence Length Max Error NRMSE Bad Elements
64 0.0077 2.37% 0
128 0.0077 2.38% 0
256 0.0077 2.38% 0
2048 0

All errors are within FP16 precision bounds. The v9 direct rescaling fix ensures no accuracy drift on longer sequences.

Interactive deep-dive

Open docs/flash_attention_story.html in a browser for a scroll-animated walkthrough of the full optimization journey with architecture diagrams, profiler data, and code comparisons.

License

MIT

Acknowledgments

Built for the RTX 5080 — proving that consumer GPUs can run serious attention kernels when you're willing to write the PTX yourself.

About

Hand-written PTX flash attention kernel achieving 67% tensor core utilization on RTX 5080, matching A100's Flash Attention 2 without WGMMA, TMA, or datacenter hardware. 157 TFLOPS FP16.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors