Fix a bug in flash attention where kv_seq_len should divide block_k_major. #8671
base: master
Conversation
Thanks for this change!
k_padded if k_pad_size > 0 else k,
v_padded if k_pad_size > 0 else v,
ab_padded if k_pad_size > 0 and ab is not None else ab,
I think we can let k, v, and ab all go through _pad_to_block_size() and use k_padded, v_padded, and ab_padded afterwards to simplify the logic.
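A rough, self-contained sketch of what that suggestion could look like (the _pad_to_block_size helper below is a stand-in written for illustration; only its name and the dimension argument are taken from the diff in this PR):

```python
# Stand-in sketch: pad unconditionally, so the call site no longer needs the
# `k_pad_size > 0` ternaries. When no padding is required, the helper simply
# returns the input unchanged with a pad size of 0.
import torch
import torch.nn.functional as F

def _pad_to_block_size(x: torch.Tensor, block_size: int, dim: int):
  """Zero-pad `dim` of `x` on the right so its length is a multiple of block_size."""
  pad = (-x.shape[dim]) % block_size
  if pad == 0:
    return x, 0
  # F.pad takes (left, right) pairs starting from the last dimension.
  pad_spec = [0, 0] * (x.ndim - dim - 1) + [0, pad]
  return F.pad(x, pad_spec), pad

# Example: (batch, heads, kv_seq_len, head_dim) with kv_seq_len = 4992.
k = torch.randn(1, 8, 4992, 64)
v = torch.randn(1, 8, 4992, 64)
k_padded, k_pad_size = _pad_to_block_size(k, 512, dim=2)
v_padded, _ = _pad_to_block_size(v, 512, dim=2)
print(k_padded.shape, k_pad_size)  # torch.Size([1, 8, 5120, 64]) 128
```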
I remember we saw the same issue in a Stable Diffusion run. This is great, thanks for the fix!
@@ -279,32 +279,42 @@ def fa_custom_forward(
  segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
      q_segment_ids, kv_segment_ids)

  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2])
  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
  k_padded, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
We probably need to do the same padding for the backward pass. Let's see whether the test can pass or not.
  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
  def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
Can we also add a backward pass test to make sure the q, v, k, and ab grads are the same as those from the self._attention output (similar to test_flash_attention_backward_aot_autograd_traceable)? I feel the backward pass needs the same update. Thanks!
When generating images with flash attention on TPU, a bug occurs with the following error message:
Cause:
This bug happens when the image resolution is not divisible by 512 on at least one side. Specifically, the sequence length (kv_seq_len) must be divisible by the block size (block_k_major, which is 512) for the flash attention kernel to work correctly. In the error above, kv_seq_len=4992 is not divisible by 512, which triggers the exception.
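As a concrete illustration of the numbers involved (a minimal sketch, not code from the PR):

```python
kv_seq_len = 4992
block_k_major = 512

print(kv_seq_len % block_k_major)    # 384, so 4992 is not a multiple of 512
pad = (-kv_seq_len) % block_k_major  # amount of right-padding needed
print(pad, kv_seq_len + pad)         # 128 5120, and 5120 == 512 * 10
```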
Solution:
To resolve this issue, we pad the k, v, and ab tensors along the kv-sequence dimension so that their length is divisible by the block sizes.
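The sketch below shows why this works, using a plain softmax-attention reference instead of the Pallas kernel. The naive_attention helper and the explicit bias masking of the padded key positions are assumptions made for illustration only; the actual kernel presumably handles masking through its own segment-id / bias machinery.

```python
# A self-contained illustration (not the PR's code): pad k and v along the
# kv-sequence dimension to a multiple of block_k_major, and add a large
# negative bias at the padded key positions so they contribute nothing to the
# softmax. The output over the original positions is then unchanged.
import torch
import torch.nn.functional as F

def naive_attention(q, k, v, bias=None):
  scores = q @ k.transpose(-1, -2) / k.shape[-1]**0.5
  if bias is not None:
    scores = scores + bias
  return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
block_k_major = 512
q = torch.randn(1, 1, 16, 64)    # (batch, heads, q_seq_len, head_dim)
k = torch.randn(1, 1, 4992, 64)  # kv_seq_len = 4992, not a multiple of 512
v = torch.randn(1, 1, 4992, 64)

pad = (-k.shape[2]) % block_k_major  # 128 -> padded length 5120
k_padded = F.pad(k, (0, 0, 0, pad))  # pad dim 2 on the right
v_padded = F.pad(v, (0, 0, 0, pad))

# Bias (playing the role of ab) that masks out the padded key positions.
bias = torch.zeros(1, 1, q.shape[2], k_padded.shape[2])
bias[..., k.shape[2]:] = torch.finfo(torch.float32).min

out_reference = naive_attention(q, k, v)
out_padded = naive_attention(q, k_padded, v_padded, bias)
print(torch.allclose(out_reference, out_padded, atol=1e-5))  # True
```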