Fix a bug in flash attention where kv_seq_len should divide block_k_major. #8671
Conversation
Thanks for this change!

k_padded if k_pad_size > 0 else k,
v_padded if k_pad_size > 0 else v,
ab_padded if k_pad_size > 0 and ab is not None else ab,
I think we can let k, v, ab all go through _pad_to_block_size() and use k_padded, v_padded, ab_padded afterwards to simplify the logic.
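A minimal sketch of the simplification being suggested, assuming _pad_to_block_size(tensor, block_size, dim) returns the (possibly unchanged) tensor plus the pad size; the dim index used for ab and the unconditional use of the padded names are assumptions, not the PR's exact code:

# Sketch only: pad k, v, and ab unconditionally so the padded names can be
# used everywhere afterwards; when the length is already aligned, the helper
# is assumed to return the input unchanged with pad_size == 0.
pad_block = max(block_k_major, block_k)
k_padded, k_pad_size = _pad_to_block_size(k, pad_block, 2)
v_padded, _ = _pad_to_block_size(v, pad_block, 2)
if ab is not None:
  ab_padded, _ = _pad_to_block_size(ab, pad_block, 3)  # dim 3 is an assumption
else:
  ab_padded = None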
I remember we saw the same issue in a stable diffusion run. This is great, thanks for the fix!
@@ -279,32 +279,42 @@ def fa_custom_forward(
  segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
      q_segment_ids, kv_segment_ids)

  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2])
  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
  k_padded, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
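For reference, a minimal sketch of what a helper with this call signature could look like; the actual _pad_to_block_size added in the PR may differ:

import torch

def _pad_to_block_size(tensor, block_size, dim):
  # Pad `tensor` with zeros along `dim` so its size becomes a multiple of
  # `block_size`; return the (possibly unchanged) tensor and the pad amount.
  pad_size = (-tensor.shape[dim]) % block_size
  if pad_size == 0:
    return tensor, 0
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
  return torch.cat([tensor, padding], dim=dim), pad_size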
We probably need to do the same padding for the backward pass. Let's see if the test can pass or not.
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
Can we also add a backward pass test to make sure the q, v, k, ab grads match those from the self._attention reference (similar to test_flash_attention_backward_aot_autograd_traceable)? I feel the backward pass needs the same update. Thanks!
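For illustration only, a rough sketch of the kind of gradient comparison being asked for, using a plain softmax-attention reference and hypothetical shapes where the kv sequence length (4992) is not a multiple of 512; the real test would follow the existing test_flash_attention_* patterns in test_pallas and compare against the flash attention kernel's grads:

import torch

def reference_attention(q, k, v):
  # Plain scaled softmax attention, used only as a gradient reference here.
  scale = q.shape[-1] ** -0.5
  attn = torch.einsum('bhqd,bhkd->bhqk', q * scale, k).softmax(dim=-1)
  return torch.einsum('bhqk,bhkd->bhqd', attn, v)

# Hypothetical shapes: [batch, num_heads, seq_len, head_dim].
q = torch.randn(1, 2, 128, 64, requires_grad=True)
k = torch.randn(1, 2, 4992, 64, requires_grad=True)
v = torch.randn(1, 2, 4992, 64, requires_grad=True)

out = reference_attention(q, k, v)
out.sum().backward()
# In the real test, these grads would be compared (e.g. via torch.allclose)
# with the q/k/v/ab grads produced by the flash attention kernel under test.
ref_grads = (q.grad, k.grad, v.grad)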
I'm not very familiar with the backward pass. Can I just fix this forward bug first?
That also works, thanks!
@zpcore Hi, I found that the tpu-test failed again. I've run the unit test locally, but I still have some questions. I added a print, and the output shows a diff well above the tolerance of 1e-5. The output diff is similar when I run the test method, which also fails, so I'm confused about why the diff is so large. Can you give me some advice?
From the test log: is this an existing bug, rather than one introduced by this fix?
I ran your code with test_pallas, and the test passes.
LGTM! Thanks for fixing the shape.
OK, thanks for accepting the PR.
When generating images with flash attention on TPU, a bug occurs with the following error message:
Cause:
This bug happens when the image resolution is not divisible by 512 on at least one side. Specifically, the sequence length (kv_seq_len) should be divisible by the block size (block_k_major, which is 512) for the flash attention mechanism to work correctly. In the error above, kv_seq_len=4992 is not divisible by 512, leading to this exception.
Solution:
To resolve this issue, we need to pad the k, v, and ab tensors so that their kv sequence lengths are divisible by the block sizes.
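As a concrete illustration of the padding (a sketch, not the PR's exact code): with kv_seq_len = 4992 and block_k_major = 512, the sequence must be padded up to 5120. The [batch, num_heads, kv_seq_len, head_dim] layout below follows the dim index used in the diff, but the concrete sizes are assumptions:

import torch
import torch.nn.functional as F

kv_seq_len = 4992
block_k_major = 512
pad_size = (-kv_seq_len) % block_k_major  # 128, so the padded length is 5120

# Pad the kv sequence dimension (dim 2) of k and v with zeros.
k = torch.randn(1, 2, kv_seq_len, 64)
v = torch.randn(1, 2, kv_seq_len, 64)
k_padded = F.pad(k, (0, 0, 0, pad_size))
v_padded = F.pad(v, (0, 0, 0, pad_size))
assert k_padded.shape[2] % block_k_major == 0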