Why is xattention slower than flash attention? #28

@DamonsJ

Image

I tested xattention on an A100 and an L40S and got the same result on both: xattention is slower than flash attention.

Here is the test script:

```python
import time

import torch
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
from xattn.src.Xattention import Xattention_prefill

bsz = 1
heads = 24
seq_len = 6656
dim = 128
# q, k, v are laid out as (bsz, heads, seq_len, dim)
q = torch.randn((bsz, heads, seq_len, dim), dtype=torch.bfloat16).to("cuda")
k = torch.randn((bsz, heads, seq_len, dim), dtype=torch.bfloat16).to("cuda")
v = torch.randn((bsz, heads, seq_len, dim), dtype=torch.bfloat16).to("cuda")

# Warm up xattention
for i in range(5):
    with torch.no_grad():
        attention_output = Xattention_prefill(
            query_states=q, key_states=k, value_states=v,
            stride=8, block_size=128, use_triton=True, threshold=0.95,
        )
torch.cuda.synchronize()  # make sure warm-up work is done before timing

run_times = 20
t1 = time.time()
for i in range(run_times):
    with torch.no_grad():
        attention_output = Xattention_prefill(
            query_states=q, key_states=k, value_states=v,
            stride=8, block_size=128, use_triton=True, threshold=0.95,
        )
torch.cuda.synchronize()  # wait for all queued kernels before stopping the timer
t2 = time.time()

print(" ====> time xattn is ", (t2 - t1) / run_times)
# print(attention_output)

# Warm up flash attention (q, k, v passed in the same layout as above)
for i in range(5):
    with torch.no_grad():
        flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
torch.cuda.synchronize()

run_times = 20
t1 = time.time()
for i in range(run_times):
    with torch.no_grad():
        attention_output1 = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
torch.cuda.synchronize()
t2 = time.time()

print(" ====> time flash is ", (t2 - t1) / run_times)
# print(attention_output1)
# c = torch.isclose(attention_output1, attention_output, rtol=1e-03, atol=1e-05, equal_nan=False)
# print(" ====> is close ", c)
```
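One thing that may be worth checking in the script above is the tensor layout. The script builds `q`, `k`, `v` as `(bsz, heads, seq_len, dim)`, while `flash_attn_func` documents its inputs as `(batch, seqlen, nheads, headdim)`. If the tensors are passed unchanged, flash attention would read them as 6656 heads over 24 tokens. The sketch below is only a rough estimate of how different the two workloads are; `attention_flops` is a hypothetical helper that counts the two matmuls of dense, non-causal attention and ignores softmax and memory traffic.

```python
def attention_flops(batch, seqlen, nheads, headdim):
    """Approximate matmul FLOPs of dense, non-causal attention (QK^T and PV)."""
    return 4 * batch * nheads * seqlen * seqlen * headdim


# Shape used in the test script, in the order the script builds it:
# (bsz, heads, seq_len, dim)
shape = (1, 24, 6656, 128)

# Intended workload: 24 heads attending over 6656 tokens.
intended = attention_flops(batch=shape[0], seqlen=shape[2], nheads=shape[1], headdim=shape[3])

# Workload if the same tensor is read as (batch, seqlen, nheads, headdim):
# 6656 heads attending over 24 tokens.
as_passed = attention_flops(batch=shape[0], seqlen=shape[1], nheads=shape[2], headdim=shape[3])

ratio = intended / as_passed
print(f"intended / as_passed FLOPs ratio: {ratio:.1f}")
```

If this reading applies, calling `flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), ...)` would put both kernels on the same problem before comparing times.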


I installed the packages with the following commands:

```bash
git clone https://github.com/mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
python setup.py install
cd ..

export PYTHONPATH="$PYTHONPATH:$(pwd)"

pip install flash-attn==2.6.3 --no-build-isolation

git clone https://github.com/mit-han-lab/x-attention.git
cd x-attention
python test.py
```

I did not install x-attention through pip; I just cloned the repo and ran test.py.

I also printed the approx_simple_mask result of the function xattn_estimate, and the mask looks OK.
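A mask can look well-formed and still keep most blocks, in which case little computation is skipped. A minimal sketch for putting a number on this, assuming `approx_simple_mask` is a boolean block mask where `True` marks a block that is computed (an assumption, not checked against what `xattn_estimate` returns); `block_density` is a hypothetical helper that works on nested Python lists:

```python
def block_density(mask):
    """Return the fraction of truthy entries in an arbitrarily nested list."""
    kept = 0
    total = 0
    stack = [mask]
    while stack:
        item = stack.pop()
        if isinstance(item, (list, tuple)):
            stack.extend(item)  # descend into the next nesting level
        else:
            total += 1
            kept += bool(item)
    if total == 0:
        raise ValueError("mask has no entries")
    return kept / total


# Toy 1 x 2 x 2 mask: 3 of 4 blocks are kept.
demo_mask = [[[True, False], [True, True]]]
print(block_density(demo_mask))
```

With a torch tensor, `block_density(approx_simple_mask.tolist())` would give the same fraction. A value close to 1.0 would mean the sparse kernel is doing almost the full dense work plus the estimation overhead.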

Image
