FlashAttention computes exact scaled dot‑product attention while minimizing global memory (HBM) IO. Instead of materializing the full $n\times n$ score and probability matrices, it streams $K,V$ in tiles and maintains an online softmax per query row, keeping intermediates in registers and shared memory.
- ✅ Exact forward pass (no approximation) with online softmax
- ✅ Row‑split multi‑CTA per head for high occupancy
- ✅ Warp specialization (producer / consumer warps)
- ✅ Double buffering (2‑stage shared‑memory pipeline; optional `cp.async`)
- ✅ Vectorized 16‑B copies (`uint4`) with safe fallbacks
- ✅ BF16 / FP16 / FP32 I/O, always FP32 accumulation
- ✅ Split‑K forward (optional, two‑pass merge)
- ✅ Autotuner to pick tile sizes and block configs per $(S, D)$
- 🔧 WMMA/Tensor‑Core hooks (optional, off by default)
# Configure & build
cmake -B build -S . -DCMAKE_BUILD_TYPE=Release
cmake --build build -j
# (Optional) if you added the example target in CMake:
./build/minimal

Architectures: SM70+ runs the fallback path; SM80+/SM90 get async copy when `FA_USE_CP_ASYNC` is enabled (see flags).
Layout: All tensors are contiguous [B, H, S, D].
Header: include/fa/api.hpp
namespace fa {

struct LaunchConfig {
    int tile_m  = 224;   // queries per CTA
    int tile_n  = 64;    // keys per tile
    int block   = 256;   // threads per CTA
    int loaders = 1;     // producer warps
};

// Simple forward (row‑split multi‑CTA per head)
template<typename T>     // T = float, __half, __nv_bfloat16
void flash_attention_forward(
    T* q, T* k, T* v, T* o,
    int B, int H, int S, int D,
    bool causal,
    const LaunchConfig& lcfg,
    cudaStream_t stream = 0);

// Split‑K forward (two‑pass merge)
template<typename T>
void flash_attention_forward_splitk(
    T* q, T* k, T* v, T* o,
    int B, int H, int S, int D,
    bool causal,
    int k_splits,                 // >1 enables split‑K
    void* workspace, size_t workspace_bytes,
    const LaunchConfig& lcfg,
    cudaStream_t stream = 0);

} // namespace fa

Autotuner (optional): include/fa/autotune.hpp
fa::FlashAttnAutoTuner tuner;
auto best = tuner.get_or_tune(B,H,S,D, causal, q,k,v,o);
// then call flash_attention_forward<T>(..., best, stream);
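Putting the pieces together, a minimal end‑to‑end call could look like the sketch below (illustrative only: the device buffers are left unfilled and error checking is omitted):

```cpp
#include <cuda_runtime.h>
#include "fa/api.hpp"

int main() {
    const int B = 1, H = 8, S = 1024, D = 64;
    const size_t elems = size_t(B) * H * S * D;

    // Contiguous [B, H, S, D] device buffers; FP32 I/O for simplicity.
    float *q, *k, *v, *o;
    cudaMalloc(&q, elems * sizeof(float));
    cudaMalloc(&k, elems * sizeof(float));
    cudaMalloc(&v, elems * sizeof(float));
    cudaMalloc(&o, elems * sizeof(float));
    // ... fill q/k/v (e.g. cudaMemcpy from host) ...

    fa::LaunchConfig lcfg;  // defaults: tile_m=224, tile_n=64, block=256, loaders=1
    fa::flash_attention_forward<float>(q, k, v, o, B, H, S, D,
                                       /*causal=*/true, lcfg, /*stream=*/0);
    cudaDeviceSynchronize();

    cudaFree(q); cudaFree(k); cudaFree(v); cudaFree(o);
    return 0;
}
```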
For a single head with queries $Q\in\mathbb{R}^{n\times d}$, keys $K\in\mathbb{R}^{n\times d}$, and values $V\in\mathbb{R}^{n\times d_v}$, attention computes $O=\mathrm{softmax}(\alpha\, QK^\top)\,V$ with $\alpha = 1/\sqrt{d}$.

A naïve implementation materializes the full score matrix $S=\alpha\, QK^\top\in\mathbb{R}^{n\times n}$ and the probability matrix $P=\mathrm{softmax}(S)$ in HBM.

For a row $i$, the logits are $s_{ij}=\alpha\, q_i\cdot k_j$. Subtracting the row maximum $m_i=\max_j s_{ij}$ before exponentiating keeps the exponentials in a safe range without changing the result.

We want the softmax output:
$$
o_i = \frac{\sum_j e^{s_{ij}-m_i}\, v_j}{\sum_j e^{s_{ij}-m_i}} .
$$
Process the keys in tiles, maintaining for each query row $i$:
- running max $m$,
- running denominator $\ell = \sum_{j\in \text{seen}} e^{s_{ij}-m}$,
- running numerator vector $\mathbf{a} = \sum_{j\in \text{seen}} e^{s_{ij}-m}\, v_j$.
Derivation for merging a new tile $T$:
Let $m_T = \max_{j\in T} s_{ij}$ and $m' = \max(m, m_T)$.
Then
$$
\ell' = \ell\, e^{m-m'} + \sum_{j\in T} e^{s_{ij}-m'}, \qquad
\mathbf{a}' = \mathbf{a}\, e^{m-m'} + \sum_{j\in T} e^{s_{ij}-m'}\, v_j .
$$
This is the online/streaming softmax update. After all tiles:
$$
o_i = \frac{\mathbf{a}}{\ell} .
$$
It is exact (identical to full softmax), because we merely change the reference maximum; the rescale factor $e^{m-m'}$ multiplies both $\ell$ and $\mathbf{a}$ and cancels in the final ratio.
- Causal: set logits for $j>i$ to $-\infty$. In streaming, simply skip masked keys, i.e., do not update $\ell$ or $\mathbf{a}$ for them.
- Padding: if the valid key length is $L$, ignore $j\ge L$.
If a full tile is masked, no updates occur; the state $(m,\ell,\mathbf{a})$ is unchanged.
- Arithmetic stays $O(n^2 d)$ per head, up to constants (dot products & accumulations).
- HBM IO is the bottleneck at large $n$. FlashAttention never writes the intermediates $S$ or $P$; it streams $K,V$ and keeps partials in shared memory/registers. For a single head:
  - Read $Q$ once per row, read $K,V$ once per tile, write $O$ once.
  - IO is roughly $O(nd + nd_v)$ rather than $O(n^2)$ for intermediates.
- Row‑split multi‑CTA increases parallelism: multiple CTAs per head re‑read $K,V$, but L2 helps amortize the cost; the occupancy win typically dominates.
- Inputs $Q,K,V$ may be FP16/BF16/FP32; accumulation is FP32.
- Use $\alpha = 1/\sqrt{d}$ scaling to keep logits in a well‑behaved range.
- Online softmax keeps exponents centered via the running max $m$, curbing overflow/underflow.
- If $\ell=0$ (fully masked row), we emit zeros.
For each batch b, head h, and query row i:
Initialize: m = -inf, l = 0, a[:] = 0
for key tile T = {j = k0 .. k0+N-1}:
    mT = max_{j in T} ( α * dot(q[i,:], K[j,:]) )   // skip masked j
    m_new = max(m, mT)
    scale = exp(m - m_new)
    l *= scale; a[:] *= scale
    for j in T (skip masked):
        s = exp( α * dot(q[i,:], K[j,:]) - m_new )
        l += s
        a[:] += s * V[j,:]
    m = m_new
o[i,:] = a[:] / l
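For reference, here is a small CPU sketch of the same per‑row loop (one key at a time, i.e. a tile of size 1). It is illustrative only and not part of the library, but it makes a handy correctness oracle for small shapes:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Streaming-softmax attention reference. Layout matches the kernels: contiguous [B, H, S, D].
void attention_reference(const float* Q, const float* K, const float* V, float* O,
                         int B, int H, int S, int D, bool causal) {
    const float alpha = 1.0f / std::sqrt(float(D));
    for (int b = 0; b < B; ++b)
    for (int h = 0; h < H; ++h)
    for (int i = 0; i < S; ++i) {
        const float* q = Q + ((size_t(b) * H + h) * S + i) * D;
        float m = -INFINITY, l = 0.0f;
        std::vector<float> a(D, 0.0f);
        for (int j = 0; j < S; ++j) {
            if (causal && j > i) continue;              // masked key: no update
            const float* k = K + ((size_t(b) * H + h) * S + j) * D;
            const float* v = V + ((size_t(b) * H + h) * S + j) * D;
            float s = 0.0f;
            for (int d = 0; d < D; ++d) s += q[d] * k[d];
            s *= alpha;
            const float m_new = std::max(m, s);
            const float scale = std::exp(m - m_new);    // rescale the old state
            const float p = std::exp(s - m_new);        // new contribution
            l = l * scale + p;
            for (int d = 0; d < D; ++d) a[d] = a[d] * scale + p * v[d];
            m = m_new;
        }
        float* o = O + ((size_t(b) * H + h) * S + i) * D;
        for (int d = 0; d < D; ++d) o[d] = (l > 0.0f) ? a[d] / l : 0.0f;  // fully masked -> zeros
    }
}
```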
We map each query row to a consumer thread; the launch geometry is:
- Tensors are contiguous `[B,H,S,D]`.
- Row‑split over $S$: `grid.x = ceil(S / tile_m)`, `grid.y = H`, `grid.z = B` (see the launch sketch after this list).
- One CTA handles `tile_m` query rows. Within a CTA, `loader_warps` (default 1) are producers.
- The remaining threads are consumers; each consumer thread owns one query row.
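A sketch of how the host side could derive this geometry from a `LaunchConfig` (the real plumbing lives in `api.cu`; this only spells out the arithmetic above):

```cpp
#include <cuda_runtime.h>
#include "fa/api.hpp"

// Row-split launch geometry: one CTA per tile_m query rows, per head, per batch.
inline void make_launch_dims(const fa::LaunchConfig& lcfg, int B, int H, int S,
                             dim3& grid, dim3& block, int& consumers) {
    block = dim3(lcfg.block);                              // e.g. 256 threads per CTA
    grid  = dim3((S + lcfg.tile_m - 1) / lcfg.tile_m,      // grid.x = ceil(S / tile_m)
                 H,                                        // grid.y = heads
                 B);                                       // grid.z = batches
    consumers = lcfg.block - 32 * lcfg.loaders;            // default 256 - 32 = 224 == tile_m
}
```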
- Producer warp(s) only prefetch the next $K,V$ tile from HBM to shared memory.
- Consumer warps compute dot‑products and online softmax updates on the current tile.
- This reduces barriers and hides memory latency versus “everyone does everything” (see the role sketch below).
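Inside the kernel, the role split can be as simple as a warp‑ID check. The sketch below assumes the producer warps take the highest warp IDs and that the loader count is available as a plain `int`; both are illustrative assumptions, not necessarily the repo's exact layout:

```cpp
// Classify the calling thread's warp as producer or consumer.
__device__ inline bool is_producer_warp(int loaders) {
    const int warp_id = threadIdx.x / 32;
    const int n_warps = blockDim.x / 32;
    return warp_id >= n_warps - loaders;
}

// Inside the kernel body:
//   if (is_producer_warp(cfg.loaders)) { /* prefetch next K/V tile into the spare smem stage */ }
//   else                               { /* per-row dot products + online-softmax update   */ }
//   __syncthreads();  // tile boundary
```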
We keep two shared‑memory stages for each of $K$ and $V$:
Iteration t: compute(K/V stage A) | prefetch next K/V -> stage B
Barrier+swap
Iteration t+1: compute(K/V stage B) | prefetch next K/V -> stage A
Barriers occur only at tile boundaries. On SM80+, enable `FA_USE_CP_ASYNC` to issue `cp.async` copies for genuinely asynchronous transfers; otherwise, we still overlap via warp scheduling.
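Schematically, the tile loop looks like this; `load_kv_tile` and `consume_tile` are placeholder names for the producer and consumer work described above, not the repo's actual functions:

```cpp
// Hypothetical helpers standing in for the producer/consumer work.
__device__ void load_kv_tile(int dst_stage, int tile);   // producers: HBM -> smem stage
__device__ void consume_tile(int stage, int tile);       // consumers: dot products + softmax update

__device__ void tile_loop(int num_tiles) {
    int stage = 0;                                        // stage to compute on this iteration
    load_kv_tile(/*dst_stage=*/0, /*tile=*/0);            // prime stage 0
    __syncthreads();
    for (int t = 0; t < num_tiles; ++t) {
        if (t + 1 < num_tiles)
            load_kv_tile(stage ^ 1, t + 1);               // prefetch the next tile into the spare stage
        consume_tile(stage, t);                           // compute on the current stage
        __syncthreads();                                  // single barrier at the tile boundary
        stage ^= 1;                                       // swap stages
    }
}
```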
- If pointers are 16‑B aligned and the byte count is a multiple of 16, we use `uint4` (16‑B) copies cooperatively across lanes (sketched below).
- Otherwise, we fall back to scalar loads.
- With `FA_USE_CP_ASYNC` (SM80+), the producer warp emits `cp.async` into shared memory for further latency hiding. The code has both paths.
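A sketch of the vectorized copy with its scalar fallback (warp‑cooperative; the function name and exact checks are illustrative):

```cpp
#include <cstdint>

// 16-byte vectorized copy with a scalar fallback; one warp cooperates on the copy.
__device__ void copy_bytes(void* dst, const void* src, int bytes) {
    const int lane = threadIdx.x % 32;
    const bool vec_ok = (reinterpret_cast<uintptr_t>(dst) % 16 == 0) &&
                        (reinterpret_cast<uintptr_t>(src) % 16 == 0) &&
                        (bytes % 16 == 0);
    if (vec_ok) {
        const uint4* s = reinterpret_cast<const uint4*>(src);
        uint4*       d = reinterpret_cast<uint4*>(dst);
        for (int i = lane; i < bytes / 16; i += 32) d[i] = s[i];   // one 16-B chunk per lane per step
    } else {
        const char* s = static_cast<const char*>(src);
        char*       d = static_cast<char*>(dst);
        for (int i = lane; i < bytes; i += 32) d[i] = s[i];        // scalar fallback
    }
}
```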
SMEM footprint (row‑split kernel): we keep two stages of $K$ and two of $V$ in shared memory:
$$
\text{SMEM bytes} = 4 \cdot \text{tile\_n} \cdot D \cdot \mathrm{sizeof}(\text{Storage}),
$$
where Storage is the input type (`float`/`half`/`bfloat16`). Choose `tile_n` to fit your device's SMEM.
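For example, a tiny host‑side helper (illustrative) that evaluates the formula:

```cpp
#include <cstddef>

// Two K stages + two V stages, each tile_n x D elements of the storage type.
inline size_t smem_bytes(int tile_n, int D, size_t sizeof_storage) {
    return 4u * size_t(tile_n) * size_t(D) * sizeof_storage;
}
// FP16 storage, tile_n = 64, D = 128  ->  4 * 64 * 128 * 2 = 65,536 bytes
```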
- `grid.x` splits the query rows across CTAs for the same head.
- The forward pass needs no cross‑CTA reductions (each row's softmax is independent).
- This boosts occupancy, especially when $B\cdot H$ is small and $S$ is large.
When you also want to split along keys (e.g., keep CTAs smaller or match cache limits):
- Partial pass: launch $S_K$ (= `k_splits`) splits; each split processes a disjoint key range $[k_{\text{begin}},k_{\text{end}})$ and outputs per‑row partials $(m^{(t)}, \ell^{(t)}, \mathbf{a}^{(t)})$ into a workspace (FP32).
- Merge pass: per row, merge all splits via the same online‑softmax merge rule (sketched below):
$$
m = \max_t m^{(t)},\qquad
\ell = \sum_t \ell^{(t)} e^{m^{(t)}-m},\qquad
\mathbf{a} = \sum_t \mathbf{a}^{(t)} e^{m^{(t)}-m},\qquad
o_i = \frac{\mathbf{a}}{\ell}.
$$
- Workspace size: each split stores the per‑row FP32 partials $(m,\ell,\mathbf{a})$, i.e. on the order of $B\cdot H\cdot S\cdot k_{\text{splits}}\cdot(D+2)$ floats; use the sizing helpers in `workspace.hpp`.
- Row‑split and split‑K can be combined if needed (row‑split for parallelism; split‑K for memory/capacity).
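The merge pass is just the online‑softmax merge applied across splits. A host‑side sketch for one row follows; the real workspace layout is defined in `workspace.hpp`, and `RowPartial` here is only an illustrative view:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

struct RowPartial { float m; float l; const float* a; };   // per-split partials for one row

void merge_row(const RowPartial* parts, int k_splits, int D, float* o) {
    float m = -INFINITY;
    for (int t = 0; t < k_splits; ++t) m = std::max(m, parts[t].m);   // global max over splits
    float l = 0.0f;
    std::vector<float> a(D, 0.0f);
    for (int t = 0; t < k_splits; ++t) {
        const float w = std::exp(parts[t].m - m);                     // rescale split t to the global max
        l += parts[t].l * w;
        for (int d = 0; d < D; ++d) a[d] += parts[t].a[d] * w;
    }
    for (int d = 0; d < D; ++d) o[d] = (l > 0.0f) ? a[d] / l : 0.0f;  // fully masked row -> zeros
}
```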
The default inner products are scalar FP32 for clarity and portability. A WMMA path can tile to 16×16×16 fragments (HMMA) and still use the online softmax. Requirements:
- $D$ a multiple of 16, data in `half`/`bfloat16`.
- Use fragments for $Q$ and $K^\top$ sub‑tiles, accumulate in FP32.
- Keep the two‑pass per‑tile structure (max pass then sum pass) with the same rescaling.
This repo includes compile‑time hooks (FA_USE_WMMA) so you can drop in a WMMA consumer later.
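For orientation, one 16×16×16 HMMA step of the $QK^\top$ sub‑tile product might look like the sketch below (FP16 operands, FP32 accumulator). The shared‑memory pointers `q_smem`/`k_smem`, the row‑major layout, and the leading dimension `D` are assumptions for illustration, not the repo's actual consumer:

```cpp
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Warp-wide: all 32 lanes cooperate on one 16x16 score sub-tile S_tile = Q_sub * K_sub^T.
// q_smem/k_smem point at 16 rows of Q / K in shared memory, row-major, leading dimension D.
__device__ void qk_subtile_wmma(const half* q_smem, const half* k_smem, float* s_tile, int D) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> q_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> k_frag;  // col-major view == K^T
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> s_frag;               // FP32 accumulation

    wmma::fill_fragment(s_frag, 0.0f);
    for (int kk = 0; kk < D; kk += 16) {                 // requires D % 16 == 0
        wmma::load_matrix_sync(q_frag, q_smem + kk, D);  // 16x16 slice of Q
        wmma::load_matrix_sync(k_frag, k_smem + kk, D);  // 16x16 slice of K, read as K^T
        wmma::mma_sync(s_frag, q_frag, k_frag, s_frag);
    }
    wmma::store_matrix_sync(s_tile, s_frag, 16, wmma::mem_row_major);
    // The per-tile max and sum passes of the online softmax then run on s_tile as in the scalar path.
}
```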
- Shared memory holds two $K$ tiles + two $V$ tiles.
- Each consumer thread keeps its $q$ row, $(m,\ell)$, and the accumulator $\mathbf{a}$ in registers (FP32).
- `tile_m` is chosen so the number of consumer threads roughly equals `tile_m` (one row per consumer thread).
include/fa/
config.hpp # tile sizes, warp size, KernelConfig/LaunchConfig
traits.hpp # TypeTraits<T> with float/half/bfloat16 I/O and FP32 convert
tensor.hpp # contiguous [B,H,S,D] Tensor4D view
softmax.hpp # online softmax state (m, l) + rescale
tile_loader.hpp # producer warp: 16B vectorized copies + optional cp.async
row_compute.hpp # consumer: per-row dot products + online softmax update
forward_kernels.cuh # kernel declarations (row-split, split-K partial/merge)
workspace.hpp # workspace sizing helpers (split-K)
autotune.hpp # tiny autotuner class
api.hpp # user-facing API (forward / split-K forward)
src/
forward_kernels.cu # row-split kernel (multi-CTA per head, WS + double-buffer)
forward_splitk.cu # split-K partial + merge kernels (two-pass)
autotune.cu # runtime tuner (tries small config grid, caches best)
api.cu # API implementations + kernel launch plumbing
examples/
minimal_main.cu # (optional) tiny smoke test
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_ARCHITECTURES 80 86 89 90) # tune for your GPUs
add_library(flashattn STATIC
src/forward_kernels.cu
src/forward_splitk.cu
src/autotune.cu
src/api.cu
)
target_include_directories(flashattn PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_compile_definitions(flashattn PRIVATE
# Enable warp-level cp.async on SM80+ (optional):
# FA_USE_CP_ASYNC
# Enable WMMA/Tensor Cores path (optional, experimental hook):
# FA_USE_WMMA
)

Dynamic shared‑memory opt‑in is handled in the API before launches via `cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)` and by setting the carve‑out preference to 100%.
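For reference, the opt‑in looks roughly like this; `fa_forward_kernel` is a placeholder for the actual kernel symbol:

```cpp
#include <cuda_runtime.h>

// Sketch of the opt-in done in api.cu before a launch.
inline void enable_large_smem(const void* kernel_fn, int smem_bytes) {
    // Raise the dynamic shared-memory limit above the 48 KiB default...
    cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    // ...and prefer shared memory over L1 (carve-out 100%).
    cudaFuncSetAttribute(kernel_fn, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
}
// e.g. enable_large_smem((const void*)fa_forward_kernel, 64 * 1024);
```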
- SMEM sizing (row‑split kernel): $\text{SMEM} = 4 \times \text{tile\_n} \times D \times \mathrm{sizeof}(\text{Storage})$. For FP16/BF16 Storage and `tile_n=64`, `D=128` → $4\times 64\times 128\times 2 = 65{,}536$ bytes.
- `tile_m` ≈ number of consumer threads = `block_threads - 32*loader_warps`. Default: `256 - 32 = 224` → `tile_m=224` (one row per consumer thread).
- `loader_warps`: start at 1; try 2 if memory‑bound (see the configuration sketch after this list).
- Autotuner: `FlashAttnAutoTuner::get_or_tune(...)` probes a small set and caches per $(S, D, \text{causal})$.
- Split‑K: for very long sequences or tight SMEM, try `k_splits=2..4`. Remember the workspace size.
- WMMA: when enabling `FA_USE_WMMA`, ensure $D$ is a multiple of 16 and registers suffice; measure carefully.
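For example, a hand‑tuned configuration before reaching for the autotuner might look like this; the values are illustrative starting points, not recommendations:

```cpp
#include "fa/api.hpp"

fa::LaunchConfig make_tuned_config() {
    fa::LaunchConfig lcfg;   // defaults: tile_m=224, tile_n=64, block=256, loaders=1
    lcfg.tile_n  = 32;       // shrink the key tile if shared memory is tight (see the sizing rule above)
    lcfg.loaders = 2;        // try a second producer warp if profiling shows a memory-bound kernel
    lcfg.tile_m  = lcfg.block - 32 * lcfg.loaders;   // keep one query row per consumer thread
    return lcfg;             // pass to flash_attention_forward<T>(...) as in the usage sketch above
}
```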
Correctness (sanity)
- Compare the kernel with a naïve CPU or small GPU reference on random inputs for small $(B,H,S,D)$. Expect a max relative error on the order of FP32 rounding (or FP16/BF16 conversion error if using those types); a small error helper is sketched after the stress cases below.
Stress cases
- Very long $S$ (e.g., $S\ge 16\mathrm{k}$) to ensure online softmax stability.
- All‑masked rows (causal first row, or padding length zero).
- Mixed precision: inputs in FP16/BF16, verify against FP32 reference.
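A tiny helper for the comparison (pairs with the CPU reference sketched earlier; the tolerance comments are rough expectations, not guarantees):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

float max_rel_err(const std::vector<float>& got, const std::vector<float>& ref) {
    float worst = 0.0f;
    for (size_t i = 0; i < ref.size(); ++i) {
        const float denom = std::max(std::fabs(ref[i]), 1e-6f);
        worst = std::max(worst, std::fabs(got[i] - ref[i]) / denom);
    }
    return worst;   // expect ~1e-6 for FP32 I/O; larger (~1e-3) with FP16/BF16 inputs
}
```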
- `undefined reference` at link: ensure you compile and link all four source files (`forward_kernels.cu`, `forward_splitk.cu`, `autotune.cu`, `api.cu`), and check that your app links against `flashattn`.
- `no such file ./build/minimal`: add the example target to `CMakeLists.txt`: `add_executable(minimal examples/minimal_main.cu)` and `target_link_libraries(minimal PRIVATE flashattn)`.
- `std::function` error in `autotune.cu`: include `<functional>` or use the templated `time_ms` variant (we provide a version that avoids `std::function`).
- Too much shared memory: reduce `tile_n` or `D`, or switch the input Storage to FP16/BF16. Also make sure the kernel opts in to large dynamic SMEM (we do this in `api.cu`).
- Under‑utilization on small $B\cdot H$: increase `grid.x` by decreasing `tile_m` (more CTAs), or enable split‑K.
- FlashAttention: Fast and Memory‑Efficient Exact Attention with IO‑Awareness, Tri Dao et al., NeurIPS 2022.
- FlashAttention‑2: Faster Attention with Improved Parallelism and Work Partitioning, Tri Dao et al., 2023.
- From Online Softmax to FlashAttention, Zihao Ye, explanatory note.
Why the online merge is exact
Given two disjoint key sets $A$ and $B$ with per‑set statistics
$$
m_A = \max_{j\in A} s_{ij},\qquad \ell_A = \sum_{j\in A} e^{s_{ij}-m_A},\qquad \mathbf{a}_A = \sum_{j\in A} e^{s_{ij}-m_A}\, v_j,
$$
and similarly for $B$, let $m = \max(m_A, m_B)$. Then
$$
\ell_A e^{m_A-m} + \ell_B e^{m_B-m} = \sum_{j\in A\cup B} e^{s_{ij}-m},
$$
and
$$
\mathbf{a}_A e^{m_A-m} + \mathbf{a}_B e^{m_B-m} = \sum_{j\in A\cup B} e^{s_{ij}-m}\, v_j .
$$
Thus the merged state is exactly the state computed over $A\cup B$, and the final output
$$
o_i = \frac{\mathbf{a}_{A\cup B}}{\ell_{A\cup B}}
$$
equals the full softmax result.
HBM traffic sketch
- Naïve materialization: write/read $S\in\mathbb{R}^{n\times n}$ and $P\in\mathbb{R}^{n\times n}$ → $O(n^2)$ IO.
- FlashAttention: read $Q$ once, stream $K,V$ in tiles (read once), write $O$ once → $O(nd) + O(nd_v)$ IO. Arithmetic is unchanged; performance scales with the IO reduction.
Numerics
- With FP16/BF16 inputs, convert to FP32 for accumulation, keep the log‑sum‑exp reference $m$, and scale contributions by $\exp(s-m)$. This keeps intermediate magnitudes $\mathcal{O}(1)$ and avoids overflow (an unstabilized $e^{80}$ is already enormous in FP32).