Description
I tested XAttention on an A100 and an L40S and got the same result on both: XAttention is slower than FlashAttention.
Here is the test script:
```python
import torch
from xattn.src.Xattention import Xattention_prefill
import time
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

bsz = 1
heads = 24
seq_len = 6656
dim = 128
q = torch.randn((bsz,heads,seq_len,dim),dtype=torch.bfloat16).to("cuda")
k = torch.randn((bsz,heads,seq_len,dim),dtype=torch.bfloat16).to("cuda")
v = torch.randn((bsz,heads,seq_len,dim),dtype=torch.bfloat16).to("cuda")

for i in range(5):
    with torch.no_grad():
        attention_output = Xattention_prefill(query_states=q,key_states=k,value_states=v,stride=8,block_size=128,use_triton=True,threshold=0.95)
run_times = 20
t1 = time.time()
for i in range(run_times):
    with torch.no_grad():
        attention_output = Xattention_prefill(query_states=q,key_states=k,value_states=v,stride=8,block_size=128,use_triton=True,threshold=0.95)
torch.cuda.synchronize()
t2 = time.time()
print(" ====> time xattn is ", (t2-t1)/run_times)
#print(attention_output)

for i in range(5):
    with torch.no_grad():
        flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
run_times = 20
t1 = time.time()
for i in range(run_times):
    with torch.no_grad():
        attention_output1 = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
torch.cuda.synchronize()
t2 = time.time()
print(" ====> time flash is ", (t2-t1)/run_times)
#print(attention_output1)
#c = torch.isclose(attention_output1, attention_output, rtol=1e-03, atol=1e-05, equal_nan=False)
#print(" ====> is close ", c)
```
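In case it helps anyone reproducing this, here is a minimal sketch of a timing helper that synchronizes both before the clock starts and after the last call, so queued warm-up work is not counted in the measurement. The name `benchmark` and its `sync` parameter are hypothetical and not part of x-attention or flash-attn.

```python
import time
from typing import Callable


def benchmark(fn: Callable[[], object],
              sync: Callable[[], None] = lambda: None,
              warmup: int = 5,
              iters: int = 20) -> float:
    """Return the mean wall-clock seconds per call of fn.

    sync is a hypothetical hook; on a GPU it would be torch.cuda.synchronize.
    """
    for _ in range(warmup):
        fn()
    sync()  # drain any queued warm-up work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for all timed work to finish before stopping the clock
    return (time.perf_counter() - start) / iters
```

With the tensors from the script, a call might look like `benchmark(lambda: flash_attn_func(q, k, v, causal=False), sync=torch.cuda.synchronize)` inside a `torch.no_grad()` block. One thing worth checking for a like-for-like comparison: `flash_attn_func` documents `q`, `k`, `v` as `(batch_size, seqlen, nheads, headdim)`, so tensors built as `(bsz, heads, seq_len, dim)` would be read as sequence length 24 with 6656 heads. If `Xattention_prefill` expects `(bsz, heads, seq_len, dim)`, as the script assumes, passing `q.transpose(1, 2)`, `k.transpose(1, 2)`, and `v.transpose(1, 2)` to `flash_attn_func` should make both kernels attend over the same 6656-token sequence.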
I installed the packages with the following commands:
```shell
git clone https://github.com/mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
python setup.py install
cd ..
export PYTHONPATH="$PYTHONPATH:$(pwd)"
pip install flash-attn==2.6.3 --no-build-isolation
git clone https://github.com/mit-han-lab/x-attention.git
cd x-attention
python test.py
```
I did not install x-attention through pip; I just cloned the repo and ran test.py.
I also printed the approx_simple_mask result of the xattn_estimate function, and the mask seems OK.
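Beyond eyeballing the mask, it may be useful to report how dense it is, since a mask that keeps most blocks would leave little work for a sparse kernel to skip. The sketch below is a rough illustration that assumes the mask is a boolean block mask converted to nested Python lists (for example with `.tolist()`); `kept_fraction` is a hypothetical helper, not a function from x-attention.

```python
def kept_fraction(mask) -> float:
    """Fraction of truthy entries in an arbitrarily nested list of booleans."""
    kept = 0
    total = 0
    stack = [mask]
    while stack:
        item = stack.pop()
        if isinstance(item, (list, tuple)):
            stack.extend(item)  # descend into nested dimensions
        else:
            kept += bool(item)  # count a selected block
            total += 1
    return kept / total if total else 0.0
```

For a torch tensor, `mask.float().mean().item()` should give the same number without the conversion. A value close to 1.0 would suggest the estimate is selecting nearly every block at `threshold=0.95` for this random input.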
