diff --git a/benchmarks/benchmark_flash_attention_fa3.py b/benchmarks/benchmark_flash_attention_fa3.py
new file mode 100644
index 000000000..de846a335
--- /dev/null
+++ b/benchmarks/benchmark_flash_attention_fa3.py
@@ -0,0 +1,117 @@
+# Install the newest triton version with
+# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
+import pickle
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
+from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
+
+from flash_attn import flash_attn_func
+
+try:
+    from triton.ops.flash_attention import attention as attention_triton
+except ImportError:
+    attention_triton = None
+
+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None
+
+
+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+
+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, head_dim)
+        dropout_p: float
+    Output:
+        output: (batch_size, seqlen, nheads, head_dim)
+    """
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen,
+                         dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
+    if causal:
+        # "triu_tril_cuda_template" not implemented for 'BFloat16'
+        # So we have to construct the mask in float
+        causal_mask = torch.triu(torch.full(
+            (seqlen, seqlen), -10000.0, device=scores.device), 1)
+        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
+    attention = torch.softmax(scores, dim=-1)
+    attention_drop = F.dropout(attention, dropout_p)
+    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+    return output.to(dtype=qkv.dtype)
+
+
+def time_fwd_bwd(func, *args, **kwargs):
+    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+    return time_f[1].mean, time_b[1].mean
+
+
+repeats = 30
+device = 'cuda'
+dtype = torch.bfloat16
+
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+causal_vals = [False, True]
+headdim_vals = [128]
+nheads = 16
+dropout_p = 0.0
+
+methods = (["Flash"])
+
+time_f = {}
+time_b = {}
+time_f_b = {}
+speed_f = {}
+speed_b = {}
+speed_f_b = {}
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            config = (causal, headdim, batch_size, seqlen)
+            q = torch.randn(batch_size, seqlen, nheads, headdim,
+                            device=device, dtype=dtype, requires_grad=True)
+            k = torch.randn(batch_size, seqlen, nheads, headdim,
+                            device=device, dtype=dtype, requires_grad=True)
+            v = torch.randn(batch_size, seqlen, nheads, headdim,
+                            device=device, dtype=dtype, requires_grad=True)
+
+            f, b = time_fwd_bwd(
+                flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats,
+                verbose=False
+            )
+            time_f[config, "Flash"] = f
+            time_b[config, "Flash"] = b
+
+            print(
+                f"[b, s, h, d] = [{batch_size}, {seqlen}, {nheads}, {headdim}], causal={causal}")
+            for method in methods:
+                speed_b[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim,
+                          nheads, causal, mode="bwd"),
+                    time_b[config, method]
+                )
+                print(f"bwd: {speed_b[config, method]:.2f} TFLOPs/s")
+
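For readers skimming the patch: the benchmark reports TFLOPs/s by dividing an analytic FLOP count by the measured kernel time. A minimal standalone sketch of that accounting, with an illustrative (hypothetical) timing value rather than a real measurement:

    # Sketch of the benchmark's TFLOPs/s accounting; the timing value is hypothetical.
    def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
        f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
        return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

    bwd_flop = flops(2, 8192, 128, 16, causal=False, mode="bwd")  # 2.5 * ~1.10e12 ~= 2.75e12 FLOPs
    measured_bwd_time_s = 2.1e-2                                  # hypothetical mean time in seconds
    print(f"bwd: {bwd_flop / measured_bwd_time_s / 1e12:.2f} TFLOPs/s")  # ~130.9 TFLOPs/s
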
diff --git a/csrc/composable_kernel b/csrc/composable_kernel
index a9b170b54..cf70a2efb 160000
--- a/csrc/composable_kernel
+++ b/csrc/composable_kernel
@@ -1 +1 @@
-Subproject commit a9b170b54195ab667ca814f80dd5dfbf4ad772f5
+Subproject commit cf70a2efbb334747199baadf91acb504b8ce2bf7
diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index 1859137f8..f1b1b693c 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -23,7 +23,10 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                            false, // has_dbias
                            has_dropout,
                            false, // s_randval
-                           deterministic};
+                           deterministic,
+                           true,  // uses_ext_asm
+                           true,  // is_v3_atomic_fp32
+                           1};    // how_v3_bf16_cvt 0:RTNE; 1:RTNA; 2:RTZ
 }
 
 fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
@@ -99,11 +102,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
     ck_tile::index_t stride_dv = dv.stride(1);
     ck_tile::index_t nhead_stride_dv = dv.stride(2);
 
-    // dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
+    // dq_acc: (split, batch_size, nheads, seqlen_q, hdim)
     ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
     ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
-    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
-    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
+    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(3);
 
     float p_undrop = 1.0 - p_dropout;
 
@@ -191,7 +194,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
-                        {drop_seed, drop_offset}};
+                        std::make_pair(drop_seed, drop_offset)};
 }
 
 std::vector<at::Tensor>
@@ -318,11 +321,11 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads x head_size_og
     at::Tensor dq_accum;
 
     if (!deterministic) {
-        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({1, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(at::kFloat));
     } else {
         const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
@@ -399,4 +402,4 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads x head_size_og
     }
 
     return { dq, dk, dv, softmax_d };
-}
\ No newline at end of file
+}
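The mha_bwd.cpp hunks above reorder the dq_accum scratch buffer from (split, batch_size, seqlen_q, nheads, hdim) to (split, batch_size, nheads, seqlen_q, hdim), which is why stride(2) and stride(3) swap roles between nhead_stride_dq_acc and stride_dq_acc. A small PyTorch sketch of that stride swap, with toy shapes chosen purely for illustration:

    import torch

    nsplits, batch, nheads, seqlen_q, hdim = 1, 2, 3, 5, 7
    old = torch.zeros(nsplits, batch, seqlen_q, nheads, hdim)  # previous layout
    new = torch.zeros(nsplits, batch, nheads, seqlen_q, hdim)  # layout after this patch
    # Previous layout: dim 2 carries the per-row (seqlen_q) stride, dim 3 the per-head stride.
    print(old.stride(2), old.stride(3))  # 21 7
    # New layout: dim 2 carries the per-head stride, dim 3 the per-row stride, matching the
    # swapped nhead_stride_dq_acc / stride_dq_acc assignments in get_ck_fmha_bwd_args.
    print(new.stride(2), new.stride(3))  # 35 7
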
diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp
index a6b33b4ab..37e2f4111 100644
--- a/csrc/flash_attn_ck/mha_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_fwd.cpp
@@ -137,7 +137,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
-                        {drop_seed, drop_offset}};
+                        std::make_pair(drop_seed, drop_offset)};
 }
 
 std::vector<at::Tensor>
diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
index 531d735ed..4b109f857 100644
--- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -23,7 +23,10 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
                            false, // has_dbias
                            has_dropout,
                            false, // s_randval
-                           deterministic};
+                           deterministic,
+                           false,            // uses_ext_asm
+                           head_size != 64,  // is_v3_atomic_fp32
+                           2};               // how_v3_bf16_cvt 0:RTNE; 1:RTNA; 2:RTZ
 }
 
 fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
@@ -197,7 +200,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
-                        {drop_seed, drop_offset}};
+                        std::make_pair(drop_seed, drop_offset)};
 }
 
 std::vector<at::Tensor>
@@ -426,4 +429,4 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size
     }
 
     return { dq, dk, dv, softmax_d };
-}
\ No newline at end of file
+}
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index 6e30aa74a..255705bfa 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -140,7 +140,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
-                        {drop_seed, drop_offset}};
+                        std::make_pair(drop_seed, drop_offset)};
 }
 
 std::vector<at::Tensor>
diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py
index e235d9274..9e9611173 100644
--- a/flash_attn/__init__.py
+++ b/flash_attn/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.6.3"
+__version__ = "3.0.0.r1"
 
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
diff --git a/setup.py b/setup.py
index 3184c91dd..925305aa4 100644
--- a/setup.py
+++ b/setup.py
@@ -348,6 +348,8 @@ def validate_and_update_archs(archs):
         f"build/fmha_*wd*.cpp"
     )
 
+    sources+=glob.glob(f"csrc/composable_kernel/example/ck_tile/01_fmha/hsaco/*.cpp")
+
     rename_cpp_to_cu(sources)
 
     renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
@@ -358,6 +360,8 @@ def validate_and_update_archs(archs):
                        "csrc/flash_attn_ck/mha_varlen_bwd.cu",
                        "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
 
+    renamed_sources+=glob.glob(f"csrc/composable_kernel/example/ck_tile/01_fmha/hsaco/*.cu")
+
     cc_flag += ["-O3","-std=c++17",
                 "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                 "-fgpu-flush-denormals-to-zero",
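Not part of the patch, but a quick post-build smoke check one might run after building with the added hsaco sources, to confirm the bumped version string and that the public entry point still imports and executes (shapes and dtype here are arbitrary):

    # Hypothetical post-build sanity check; not included in this patch.
    import torch
    import flash_attn
    from flash_attn import flash_attn_func

    print(flash_attn.__version__)  # expected: 3.0.0.r1
    q = torch.randn(1, 128, 4, 128, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 128, 4, 128, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 128, 4, 128, device="cuda", dtype=torch.bfloat16)
    out = flash_attn_func(q, k, v, causal=True)
    print(out.shape)  # torch.Size([1, 128, 4, 128])
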
diff --git a/tests/test_flash_attn_ck_fa3.py b/tests/test_flash_attn_ck_fa3.py
new file mode 100644
index 000000000..307041c34
--- /dev/null
+++ b/tests/test_flash_attn_ck_fa3.py
@@ -0,0 +1,270 @@
+import math
+
+import pytest
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from flash_attn import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
+)
+
+from test_flash_attn import (
+    attn_bias_from_alibi_slopes,
+    convert_flash_attn_S_to_softmax,
+    generate_qkv,
+    generate_random_padding_mask,
+    _generate_block_kvcache,
+    attention_ref,
+    attention_kvpacked_ref,
+    attention_qkvpacked_ref,
+)
+
+from flash_attn.layers.rotary import apply_rotary_emb
+
+def is_bwd_hdim_supported(d):
+    return d <= 256
+
+
+def ck_randval_to_dropout_mask(randval, p):
+    # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropped,
+    # randval in 255 * [0, 0.7] will be kept.
+    # If the returned dropout_mask is >= 0, the value is kept.
+    return math.floor(255.0 * (1 - p)) - randval.to(torch.float32)
+
+
+def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):
+    """ pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded]
+    Arguments:
+        S_dmask: (nheads, total_q, max_seqlen_k)
+        cu_seqlens_q: (b + 1)
+    Output:
+        S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded)
+    """
+    batch_size = cu_seqlens_q.numel() - 1
+    seqlens_q = torch.roll(cu_seqlens_q, shifts=-1) - cu_seqlens_q
+    seqlens_q = seqlens_q[0:batch_size].tolist()
+    S_dmask = torch.split(S_dmask, seqlens_q, dim=1)
+    # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)]
+    masks = ()
+    for mask in S_dmask:
+        # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded)
+        mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1)
+        masks = masks + (mask, )
+    S_dmask = torch.cat(masks, dim=1)
+
+    S_dmask = S_dmask.transpose(0, 1)
+    return S_dmask
+
+@pytest.mark.parametrize("kvpacked", [False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("alibi", [False])
+@pytest.mark.parametrize("local", [False])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [64, 72, 80, 88, 96, 104, 112, 120, 128])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (256, 512),
+        (1024, 1024),
+        (2048, 2048),
+    ],
+)
+@pytest.mark.parametrize("dropout_p", [0.0])
+def test_flash_attn_output(
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+):
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if kvpacked:
+        kv = torch.randn(
+            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+    else:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
+    else:
+        alibi_slopes, attn_bias = None, None
+
+    if kvpacked:
+        out, lse, S_dmask = flash_attn_kvpacked_func(
+            q,
+            kv,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
+            return_attn_probs=True,
+        )
+    else:
+        out, lse, S_dmask = flash_attn_func(
+            q,
+            k,
+            v,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
+            return_attn_probs=True,
+        )
+    if dropout_p > 0.0:
+        # TODO - move to c++ mha_varlen_fwd()
+        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
+        S_dmask_converted = convert_flash_attn_S_to_softmax(
+            S_dmask,
+            seqlen_q,
+            seqlen_k,
+            None,
+            None,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
+        )
+        dropout_mask = S_dmask_converted >= 0
+        if kvpacked:
+            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
+            k_rep, v_rep = kv_rep.unbind(dim=2)
+        else:
+            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+        # CK does not return P. Hence, we don't test the attn here.
+    else:
+        dropout_mask = None
+
+    if kvpacked:
+        out_ref, attn_ref = attention_kvpacked_ref(
+            q,
+            kv,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+        )
+        out_pt, attn_pt = attention_kvpacked_ref(
+            q,
+            kv,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+            upcast=False,
+            reorder_ops=True,
+        )
+    else:
+        out_ref, attn_ref = attention_ref(
+            q,
+            k,
+            v,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+        )
+        out_pt, attn_pt = attention_ref(
+            q,
+            k,
+            v,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+            upcast=False,
+            reorder_ops=True,
+        )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+
+    g = torch.randn_like(out)
+    if is_bwd_hdim_supported(d):
+        if kvpacked:
+            (
+                dq,
+                dkv,
+            ) = torch.autograd.grad(out, (q, kv), g)
+            dk, dv = dkv.unbind(2)
+            (
+                dq_ref,
+                dkv_ref,
+            ) = torch.autograd.grad(out_ref, (q, kv), g)
+            dk_ref, dv_ref = dkv_ref.unbind(2)
+            (
+                dq_pt,
+                dkv_pt,
+            ) = torch.autograd.grad(out_pt, (q, kv), g)
+            dk_pt, dv_pt = dkv_pt.unbind(2)
+        else:
+            (
+                dq,
+                dk,
+                dv,
+            ) = torch.autograd.grad(out, (q, k, v), g)
+            (
+                dq_ref,
+                dk_ref,
+                dv_ref,
+            ) = torch.autograd.grad(out_ref, (q, k, v), g)
+            (
+                dq_pt,
+                dk_pt,
+                dv_pt,
+            ) = torch.autograd.grad(out_pt, (q, k, v), g)
+        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+        # TODO - tolerance is temporarily relaxed to 10x; tighten once CK fixes the bwd precision issue
+        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
+        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
+        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
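The assertions above follow the usual FlashAttention testing convention: the kernel's error against an fp32 reference must stay within a small multiple of the error that a plain low-precision PyTorch implementation (out_pt / d*_pt) makes against the same reference. A condensed sketch of that check; the helper name is ours, and the 2x / 10x multipliers mirror the test:

    # Condensed form of the relative-tolerance check used in the test above.
    def assert_within_ref_error(x, x_pt, x_ref, mult):
        # Kernel-vs-reference error must not exceed `mult` times the error of the
        # low-precision PyTorch baseline against the same reference.
        assert (x - x_ref).abs().max().item() <= mult * (x_pt - x_ref).abs().max().item()

    # Forward outputs use mult=2; dq/dk/dv temporarily use mult=10 (see the TODO above).

On a ROCm build, the new file can be exercised directly with, for example, pytest tests/test_flash_attn_ck_fa3.py -x.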