Update Cuda Blackwell #1463

Closed
wants to merge 35 commits into from
Changes from all commits
8e9caea
update ck
rocking5566 Sep 25, 2024
31d1ec7
update ck
rocking5566 Sep 26, 2024
a8851eb
update ck again
rocking5566 Oct 1, 2024
91eb950
Merge pull request #84 from ROCm/ck_tile/fix-bwd-compiler
rocking5566 Oct 2, 2024
1b74510
Integrate FAv3bwd for MI300
rocking5566 Oct 4, 2024
d1f9160
Add benchmark script for fa3
rocking5566 Oct 4, 2024
e73cccc
Revert log level
rocking5566 Oct 4, 2024
4fee8ed
Refine FAv3 benchmark script
rocking5566 Oct 4, 2024
3e179ac
remove useless comment
rocking5566 Oct 4, 2024
0796a00
update FAv3 bwd
rocking5566 Oct 15, 2024
4526d54
Fix missing parameter
rocking5566 Oct 17, 2024
6be6826
Add missing source file
rocking5566 Nov 1, 2024
c19380e
Fix drop_seed/drop_offset argument ambiguity
poyenc Dec 25, 2024
1332553
Support FAv3 bwd hd64+bf16+atomic16 cases
poyenc Dec 25, 2024
5f4c571
Skip unsupported tests
poyenc Dec 25, 2024
530e5ee
Add hd=64,dtype=bf16 test cases
poyenc Dec 26, 2024
5f2d704
Set is_v3_atomic_fp32=false for hd=64
poyenc Dec 26, 2024
98f6859
Skip some hd=64 test cases
poyenc Dec 26, 2024
12cd298
Merge pull request #112 from ROCm/ck_tile/fa3-fix
poyenc Dec 26, 2024
f691a81
If hd=64, clear dq before launching kernel
poyenc Dec 27, 2024
e10bc4d
Merge pull request #113 from ROCm/ck_tile/fa3-bugfix
poyenc Dec 27, 2024
5a93a97
enable hd64 bf16 atomic32
poyenc Dec 28, 2024
e471acc
Update fmha_bwd_traits template arguments
poyenc Dec 28, 2024
f547b58
Add back skipped test cases
poyenc Dec 28, 2024
423b9c2
Merge pull request #114 from ROCm/ck_tile/fa3-hd64-bf16-atomic32
poyenc Dec 28, 2024
14f8476
update ck
rocking5566 Jan 6, 2025
b0e8480
remove redundant test
rocking5566 Jan 6, 2025
c20dc71
Merge pull request #117 from ROCm/ck_tile/fa3_update
rocking5566 Jan 6, 2025
ec892f9
update ck to support hdim=64 for bf16
rocking5566 Jan 15, 2025
fabfcd3
update ck to support hdim 8x for (64~128)
rocking5566 Jan 24, 2025
17deba7
update changes for revised ck api
rocking5566 Jan 24, 2025
e524b68
Add more test case for fa3
rocking5566 Jan 26, 2025
6aebd95
Update ck for different layout
rocking5566 Jan 27, 2025
596f83c
change version to 3.0.0.r1
rocking5566 Feb 7, 2025
99730ad
update CK
rocking5566 Feb 8, 2025
117 changes: 117 additions & 0 deletions benchmarks/benchmark_flash_attention_fa3.py
@@ -0,0 +1,117 @@
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_func

try:
from triton.ops.flash_attention import attention as attention_triton
except ImportError:
attention_triton = None

try:
import xformers.ops as xops
except ImportError:
xops = None


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
assert mode in ["fwd", "bwd", "fwd_bwd"]
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


def efficiency(flop, time):
return (flop / time / 10**12) if not math.isnan(time) else 0.0


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
dropout_p: float
Output:
output: (batch_size, seqlen, nheads, head_dim)
"""
batch_size, seqlen, _, nheads, d = qkv.shape
q, k, v = qkv.unbind(dim=2)
q = rearrange(q, 'b t h d -> (b h) t d')
k = rearrange(k, 'b s h d -> (b h) d s')
softmax_scale = 1.0 / math.sqrt(d)
# Preallocate attn_weights for `baddbmm`
scores = torch.empty(batch_size * nheads, seqlen, seqlen,
dtype=qkv.dtype, device=qkv.device)
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
'(b h) t s -> b h t s', h=nheads)
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full(
(seqlen, seqlen), -10000.0, device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1)
attention_drop = F.dropout(attention, dropout_p)
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output.to(dtype=qkv.dtype)


def time_fwd_bwd(func, *args, **kwargs):
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
return time_f[1].mean, time_b[1].mean


repeats = 30
device = 'cuda'
dtype = torch.bfloat16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [128]
nheads = 16
dropout_p = 0.0

methods = (["Flash"])

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
for headdim in headdim_vals:
for batch_size, seqlen in bs_seqlen_vals:
config = (causal, headdim, batch_size, seqlen)
q = torch.randn(batch_size, seqlen, nheads, headdim,
device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, nheads, headdim,
device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, nheads, headdim,
device=device, dtype=dtype, requires_grad=True)

f, b = time_fwd_bwd(
flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False
)
time_f[config, "Flash"] = f
time_b[config, "Flash"] = b

print(
f"[b, s, h, d] = [{batch_size}, {seqlen}, {nheads}, {headdim}], causal={causal}")
for method in methods:
speed_b[config, method] = efficiency(
flops(batch_size, seqlen, headdim,
nheads, causal, mode="bwd"),
time_b[config, method]
)
print(f"bwd: {speed_b[config, method]:.2f} TFLOPs/s")

2 changes: 1 addition & 1 deletion csrc/composable_kernel
19 changes: 11 additions & 8 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -23,7 +23,10 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
false, // has_dbias
has_dropout,
false, // s_randval
deterministic};
deterministic,
true, // uses_ext_asm
true, // is_v3_atomic_fp32
1}; // how_v3_bf16_cvt 0:RTNE; 1:RTNA; 2:RTZ
}

fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
@@ -99,11 +102,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_dv = dv.stride(2);

// dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
// dq_acc: (split, batch_size, nheads, seqlen_q, hdim)
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t stride_dq_acc = dq_acc.stride(3);

float p_undrop = 1.0 - p_dropout;

@@ -191,7 +194,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
{drop_seed, drop_offset}};
std::make_pair(drop_seed, drop_offset)};
}

std::vector<at::Tensor>
@@ -318,11 +321,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({1, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
@@ -399,4 +402,4 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
}

return { dq, dk, dv, softmax_d };
}
}
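
Two details in the hunks above are easy to miss: dq_accum is now allocated as (split, batch, nheads, seqlen_q, hdim) instead of (split, batch, seqlen_q, nheads, hdim), so the nhead and row strides are read from swapped positions, and in the deterministic branch the split count is the ceiling of seqlen_k over the kN0 tile width. A small PyTorch sketch of just that bookkeeping, with illustrative shapes:

import torch

batch, nheads, seqlen_q, seqlen_k, head_size = 2, 16, 1024, 1024, 128

# Deterministic branch: one fp32 accumulator slice per kN0-wide tile of seqlen_k.
kN0 = 128 if head_size <= 128 else 64
nsplits = (seqlen_k + kN0 - 1) // kN0            # ceil division, as in the C++ code

# New layout: (split, batch, nheads, seqlen_q, hdim)
dq_accum = torch.zeros(nsplits, batch, nheads, seqlen_q, head_size, dtype=torch.float32)

split_stride_dq_acc = dq_accum.stride(0)
batch_stride_dq_acc = dq_accum.stride(1)
nhead_stride_dq_acc = dq_accum.stride(2)         # was stride(3) under the old layout
stride_dq_acc       = dq_accum.stride(3)         # seqlen_q stride, was stride(2)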
2 changes: 1 addition & 1 deletion csrc/flash_attn_ck/mha_fwd.cpp
@@ -137,7 +137,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
has_dropout_randval,
{drop_seed, drop_offset}};
std::make_pair(drop_seed, drop_offset)};
}

std::vector<at::Tensor>
9 changes: 6 additions & 3 deletions csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -23,7 +23,10 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
false, // has_dbias
has_dropout,
false, // s_randval
deterministic};
deterministic,
false, // uses_ext_asm
head_size != 64, // is_v3_atomic_fp32
2}; // how_v3_bf16_cvt 0:RTNE; 1:RTNA; 2:RTZ
}

fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
@@ -197,7 +200,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
{drop_seed, drop_offset}};
std::make_pair(drop_seed, drop_offset)};
}

std::vector<at::Tensor>
@@ -426,4 +429,4 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
}

return { dq, dk, dv, softmax_d };
}
}
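
Note the asymmetry between the two backward wrappers: the batched path (mha_bwd.cpp above) passes uses_ext_asm=true, is_v3_atomic_fp32=true and RTNA bf16 conversion, while this varlen path disables the external ASM kernels, uses RTZ, and turns off fp32 atomics only for hd=64. A plain-Python summary of those literal values; the names mirror the fmha_bwd_traits fields in the diff, not any real CK Python API:

BF16_CVT = {0: "RTNE", 1: "RTNA", 2: "RTZ"}

def v3_bwd_traits(varlen: bool, head_size: int) -> dict:
    if not varlen:                        # values from csrc/flash_attn_ck/mha_bwd.cpp
        return {"uses_ext_asm": True,
                "is_v3_atomic_fp32": True,
                "how_v3_bf16_cvt": 1}     # RTNA
    return {"uses_ext_asm": False,        # values from csrc/flash_attn_ck/mha_varlen_bwd.cpp
            "is_v3_atomic_fp32": head_size != 64,
            "how_v3_bf16_cvt": 2}         # RTZ

print(v3_bwd_traits(varlen=True, head_size=64))   # fp32 atomics disabled only for hd=64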
2 changes: 1 addition & 1 deletion csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -140,7 +140,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
has_dropout_randval,
{drop_seed, drop_offset}};
std::make_pair(drop_seed, drop_offset)};
}

std::vector<at::Tensor>
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/__init__.py
@@ -1,4 +1,4 @@
__version__ = "2.6.3"
__version__ = "3.0.0.r1"

from flash_attn.flash_attn_interface import (
flash_attn_func,
4 changes: 4 additions & 0 deletions setup.py
@@ -348,6 +348,8 @@ def validate_and_update_archs(archs):
f"build/fmha_*wd*.cpp"
)

sources+=glob.glob(f"csrc/composable_kernel/example/ck_tile/01_fmha/hsaco/*.cpp")

rename_cpp_to_cu(sources)

renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
@@ -358,6 +360,8 @@
"csrc/flash_attn_ck/mha_varlen_bwd.cu",
"csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")

renamed_sources+=glob.glob(f"csrc/composable_kernel/example/ck_tile/01_fmha/hsaco/*.cu")

cc_flag += ["-O3","-std=c++17",
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
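
The setup.py change is small but load-bearing: the pre-built ASM (hsaco) kernel wrappers shipped in the composable_kernel example tree must be compiled into the extension, so they are appended both to the pre-rename source list and to the post-rename .cu list. A hedged sketch of that pattern, with the paths taken from the diff; rename_cpp_to_cu is the helper already present in setup.py and is not reproduced here.

import glob

# Pre-rename sources: generated FMHA kernels plus, new in this PR, the hsaco
# ASM kernel wrappers from the composable_kernel example tree.
sources = glob.glob("build/fmha_*wd*.cpp")
sources += glob.glob("csrc/composable_kernel/example/ck_tile/01_fmha/hsaco/*.cpp")

# setup.py renames every .cpp above to .cu before compiling, so the renamed
# list repeats both globs against the .cu extension.
renamed_sources = glob.glob("build/fmha_*wd*.cu")
renamed_sources += glob.glob("csrc/composable_kernel/example/ck_tile/01_fmha/hsaco/*.cu")

print(len(sources), len(renamed_sources))   # counts only; outside the repo both are 0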