Fix a bug in flash attention where kv_seq_len must be divisible by block_k_major. #8671

Open · wants to merge 5 commits into base: master

Conversation

zhangp365

When generating images with flash attention on TPU, a bug occurs with the following error message:

2025-02-04 07:29:15,292 - execution.py:398 - ERROR - !!! Exception during processing !!! kv_seq_len=4992 should be divisible by block_k_major=512
2025-02-04 07:29:15,295 - execution.py:399 - ERROR - Traceback (most recent call last):
  File "/app/execution.py", line 329, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/execution.py", line 203, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/execution.py", line 175, in _map_node_over_list
    process_inputs(input_dict, i)
  File "/app/execution.py", line 164, in process_inputs
    results.append(getattr(obj, func)(**inputs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/custom_nodes/TLDiffnode/nodes_flux.py", line 293, in sample
    latents_or_images = pipeline(**params)[0]
                        ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 907, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 545, in forward
    encoder_hidden_states, hidden_states = block(
                                           ^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 187, in forward
    attention_outputs = self.attn(
                        ^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/attention_processor.py", line 594, in forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/attention_processor.py", line 3508, in __call__
    hidden_states = flash_attention(query, key, value, causal=False)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_xla/experimental/custom_kernel.py", line 515, in flash_attention
    return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_xla/experimental/custom_kernel.py", line 307, in forward
    payload, _ = trace_pallas(
                 ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_xla/experimental/custom_kernel.py", line 139, in trace_pallas
    ir = jax.jit(
         ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 621, in _flash_attention_impl
    _verify_block("block_k_major", "kv_seq_len", block_k_major, kv_seq_len)
  File "/opt/conda/lib/python3.11/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 1741, in _verify_block
    raise ValueError(
ValueError: kv_seq_len=4992 should be divisible by block_k_major=512
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Cause:

This bug happens when the image resolution, on at least one side, is not a multiple of 512, so the resulting sequence length is not a multiple of the block size. Flash attention requires the key/value sequence length (kv_seq_len) to be divisible by the block size (block_k_major, which is 512). In the error above, kv_seq_len=4992 is not divisible by 512, which raises the exception.

Solution:

To resolve this issue, we need to pad the k, v, and ab tensors along the key/value sequence dimension so that their length becomes divisible by the block sizes.
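
For illustration, here is a minimal sketch of the padding helper. The name _pad_to_block_size comes from the PR diff; the body below is an assumption, not the merged implementation:

import torch

def _pad_to_block_size(tensor: torch.Tensor, block_size: int, dim: int):
  # Zero-pad `tensor` along `dim` so its size becomes a multiple of
  # `block_size`; return the (possibly) padded tensor and the pad amount.
  length = tensor.shape[dim]
  pad_size = (-length) % block_size  # e.g. (-4992) % 512 == 128
  if pad_size == 0:
    return tensor, 0
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
  return torch.cat([tensor, padding], dim=dim), pad_size

For the failing case above, kv_seq_len=4992 is padded by 128 positions up to 5120 = 10 * 512, which satisfies the check in _verify_block. Note that the padded key positions must additionally be masked (for example via the segment ids or the ab bias) so they do not contribute to the softmax.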

@qihqi requested review from tengyifei and zpcore on February 6, 2025 04:54
@qihqi (Collaborator) commented Feb 6, 2025

Thanks for this change!
Please run the yapf formatter; it's good to go otherwise. Thanks!

@qihqi self-requested a review on February 6, 2025 04:56
Comment on lines 299 to 301
k_padded if k_pad_size > 0 else k,
v_padded if k_pad_size > 0 else v,
ab_padded if k_pad_size > 0 and ab is not None else ab,

@zpcore (Collaborator) commented Feb 6, 2025

I think we can let k,v,ab all go through _pad_to_block_size() and use k_padded, v_padded, ab_padded afterwards to simplify the logic.
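
A rough sketch of that simplification (illustrative only; it assumes _pad_to_block_size returns the input unchanged, with pad_size == 0, when no padding is needed):

block_size = max(block_k_major, block_k)
k_padded, k_pad_size = _pad_to_block_size(k, block_size, 2)
v_padded, _ = _pad_to_block_size(v, block_size, 2)
ab_padded = ab
if ab is not None:
  # Assumed: ab's key/value dimension is the last one.
  ab_padded, _ = _pad_to_block_size(ab, block_size, -1)
# Downstream code can then use k_padded / v_padded / ab_padded unconditionally,
# with no branching on k_pad_size at the call site.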

@zpcore (Collaborator) left a comment

I remember we saw the same issue in a Stable Diffusion run. This is great, thanks for the fix!

@@ -279,32 +279,42 @@ def fa_custom_forward(
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)

block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2])
block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
k_padded, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)

Review comment (Collaborator):

We probably need to do the same padding for the backward pass. Let's see if the test can pass or not.
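
Illustratively (not the PR's actual backward code), the backward counterpart could pad the same way and then slice the key/value gradients back to the original length:

# Pad k/v exactly as in the forward pass before tracing the backward kernel.
k_padded, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
v_padded, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
# ... run the Pallas backward kernel on the padded tensors ...
# Then drop the gradient rows that correspond to the padding so callers
# never see the extra positions.
if k_pad_size > 0:
  grad_k = grad_k[:, :, :-k_pad_size, :]
  grad_v = grad_v[:, :, :-k_pad_size, :]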

@zpcore (Collaborator) commented Feb 6, 2025

Oh, I forgot to mention: we don't have a test for the case that needs padding. Can you also add unit tests with and without SPMD? Thanks.

@zhangp365 (Author)

> Oh, I forgot to mention: we don't have a test for the case that needs padding. Can you also add unit tests with and without SPMD? Thanks.

OK, I will update it accordingly.

@zhangp365 (Author)

Hi, I've updated the code. Please review it again, thanks.

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):

Review comment (Collaborator):

Can we also add a backward pass test to make sure the q, k, v, ab gradients match the self._attention output (similar to test_flash_attention_backward_aot_autograd_traceable)? I feel the backward pass needs the same update. Thanks.
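
A rough sketch of what such a backward check could look like (flash_attention and self._attention follow the existing tests; the shapes, head dim, and tolerance below are assumptions for illustration):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

# kv_seq_len = 4992 is deliberately not a multiple of block_k_major = 512.
q = torch.randn(1, 2, 4992, 64, requires_grad=True).to('xla')
k = torch.randn(1, 2, 4992, 64, requires_grad=True).to('xla')
v = torch.randn(1, 2, 4992, 64, requires_grad=True).to('xla')
q.retain_grad(); k.retain_grad(); v.retain_grad()

o = flash_attention(q, k, v)
o.sum().backward()
xm.mark_step()

# Reference path on detached copies that require grad.
q_ref, k_ref, v_ref = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
o_ref = self._attention(q_ref, k_ref, v_ref)
o_ref.sum().backward()
xm.mark_step()

# The padded key/value positions must not leak into the gradients.
for got, want in ((q.grad, q_ref.grad), (k.grad, k_ref.grad), (v.grad, v_ref.grad)):
  self.assertTrue(torch.allclose(got.cpu(), want.cpu(), atol=1e-2))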
