Implement backward pass #2
Conversation
@leloykun hello Franz! I'm having some trouble with the code and flash attention. Firstly, why does the attention-values sanity check return False when seq_len is lower than 32? This leads to a collapse at inference time, where seq_len is usually 1; I suspect the block size may be the cause. Secondly, how should I choose an appropriate block size? Looking forward to your reply!
Hi @hypertseng! I believe it was because we weren't exiting the loops after going past the seq length. The forward pass should be fixed in my repo here: https://github.com/leloykun/flash-hyperbolic-attention-minimal |
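For illustration, here is a rough pure-PyTorch sketch of the bounds handling described above: the query and key/value blocks are clamped to seq_len, so a partial block (or a length-1 sequence) never reads past the end of the tensor. This is not the CUDA kernel from the repo, just the same idea expressed at the tensor level; the shapes, the online-softmax bookkeeping, and the block_size default are assumptions.

```python
import math
import torch

def tiled_attention_forward(Q, K, V, block_size=32):
    """Block-tiled attention with online softmax (illustration only;
    the actual fix lives in the CUDA kernel of the linked repo)."""
    B, H, N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = torch.zeros_like(Q)
    for q_start in range(0, N, block_size):
        q_end = min(q_start + block_size, N)        # clamp the query block to seq_len
        q_blk = Q[:, :, q_start:q_end]
        # running softmax statistics for this query block
        row_max = torch.full((B, H, q_end - q_start, 1), float("-inf"),
                             dtype=Q.dtype, device=Q.device)
        row_sum = torch.zeros_like(row_max)
        acc = torch.zeros_like(q_blk)
        for k_start in range(0, N, block_size):
            k_end = min(k_start + block_size, N)    # clamp partial key/value blocks too
            k_blk = K[:, :, k_start:k_end]
            v_blk = V[:, :, k_start:k_end]
            s = (q_blk @ k_blk.transpose(-2, -1)) * scale
            new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - new_max)
            correction = torch.exp(row_max - new_max)
            row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ v_blk
            row_max = new_max
        O[:, :, q_start:q_end] = acc / row_sum
    return O
```

With the clamping in place this matches plain softmax attention even when seq_len is 1, which is the inference case reported above.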
@leloykun Recently, I found that the flash_attn_bwd implementation in your repo is slower than the manual implementation, and this appears to be entirely due to implicit calls to cudaDeviceSynchronize, which increase the CPU time a lot. Do you have any idea how to solve this problem?
@hypertseng Most likely, |
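One generic way to check whether the overhead is coming from host-side synchronization during benchmarking (rather than from the kernel itself) is to time with CUDA events and synchronize only once at the end, instead of wrapping every call in a device-wide sync. The helper below is a sketch, not code from the repo; `fn` stands in for whatever callable wraps the flash_attn_bwd extension call.

```python
import torch

def time_gpu(fn, iters=100, warmup=10):
    """Average GPU time of `fn` in milliseconds, measured with CUDA events.
    Only a single host synchronization happens, after all launches."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    end.synchronize()                        # single sync at the very end
    return start.elapsed_time(end) / iters   # ms per call
```

If the implicit synchronization instead originates inside the extension (for example from pageable host-to-device copies or in-kernel printf), it has to be removed there; the timing pattern above only separates kernel time from launch overhead.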


Description
This PR implements a minimal backward pass for flash attention.
I got these results on my RTX 2060: a 2x speedup.
Though my GPU can only handle size-16 blocks (vs. size-32 blocks for the T4).
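Since the PR itself is a CUDA kernel, the following is only a hedged reference for what the backward pass has to produce: the standard attention gradients, with the probabilities recomputed from the saved row-wise logsumexp instead of storing the full attention matrix. The function name and the assumption that the forward pass saves O and the logsumexp L are illustrative, not the PR's actual interface.

```python
import math
import torch

def attention_backward_reference(Q, K, V, O, dO, L):
    """Reference gradients for softmax attention, recomputing the
    probabilities from the saved row-wise logsumexp L."""
    scale = 1.0 / math.sqrt(Q.shape[-1])
    S = (Q @ K.transpose(-2, -1)) * scale
    P = torch.exp(S - L.unsqueeze(-1))           # recomputed softmax probabilities
    dV = P.transpose(-2, -1) @ dO
    dP = dO @ V.transpose(-2, -1)
    D = (dO * O).sum(dim=-1, keepdim=True)       # row-wise correction term
    dS = P * (dP - D)
    dQ = (dS @ K) * scale
    dK = (dS.transpose(-2, -1) @ Q) * scale
    return dQ, dK, dV
```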