
Fix a bug in flash attention where kv_seq_len should be divisible by block_k_major. #8671


Merged: 8 commits merged into pytorch:master on Feb 10, 2025

Conversation

@zhangp365 (Contributor)

When generating images with flash attention on TPU, a bug occurs with the following error message:

2025-02-04 07:29:15,292 - execution.py:398 - ERROR - !!! Exception during processing !!! kv_seq_len=4992 should be divisible by block_k_major=512
2025-02-04 07:29:15,295 - execution.py:399 - ERROR - Traceback (most recent call last):
  File "/app/execution.py", line 329, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/execution.py", line 203, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/execution.py", line 175, in _map_node_over_list
    process_inputs(input_dict, i)
  File "/app/execution.py", line 164, in process_inputs
    results.append(getattr(obj, func)(**inputs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/custom_nodes/TLDiffnode/nodes_flux.py", line 293, in sample
    latents_or_images = pipeline(**params)[0]
                        ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 907, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 545, in forward
    encoder_hidden_states, hidden_states = block(
                                           ^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 187, in forward
    attention_outputs = self.attn(
                        ^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/attention_processor.py", line 594, in forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "/app/diffusers/src/diffusers/models/attention_processor.py", line 3508, in __call__
    hidden_states = flash_attention(query, key, value, causal=False)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_xla/experimental/custom_kernel.py", line 515, in flash_attention
    return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_xla/experimental/custom_kernel.py", line 307, in forward
    payload, _ = trace_pallas(
                 ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_xla/experimental/custom_kernel.py", line 139, in trace_pallas
    ir = jax.jit(
         ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 621, in _flash_attention_impl
    _verify_block("block_k_major", "kv_seq_len", block_k_major, kv_seq_len)
  File "/opt/conda/lib/python3.11/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 1741, in _verify_block
    raise ValueError(
ValueError: kv_seq_len=4992 should be divisible by block_k_major=512
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Cause:

This bug can happen when the image resolution leads to a sequence length that is not a multiple of the block size. Specifically, flash attention requires the sequence length (kv_seq_len) to be divisible by the block size (block_k_major, which is 512) to work correctly. In the error above, kv_seq_len=4992 is not divisible by 512, which raises the exception.
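
For the numbers in the error above, the required padding is small (a quick check, using only the values from the message):

    # Values from the error message above.
    kv_seq_len = 4992
    block_k_major = 512

    print(kv_seq_len % block_k_major)   # 384 -> not divisible, so the check fails
    pad = -kv_seq_len % block_k_major   # amount of padding needed
    print(pad, kv_seq_len + pad)        # 128, 5120 (= 10 * 512)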

Solution:

To resolve this issue, we pad the k, v, and ab tensors along the key/value sequence dimension so that their lengths become divisible by the block size.
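
A minimal sketch of the padding idea, following the _pad_to_block_size helper visible in the diff below (the PR's actual implementation may differ in details):

    import torch

    def _pad_to_block_size(tensor, block_size, dim):
      # Pad `tensor` with zeros along `dim` so its size becomes a multiple
      # of `block_size`; return the padded tensor and the amount of padding.
      size = tensor.shape[dim]
      pad_size = -size % block_size  # 0 when already aligned
      if pad_size == 0:
        return tensor, 0
      pad_shape = list(tensor.shape)
      pad_shape[dim] = pad_size
      padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
      return torch.cat([tensor, padding], dim=dim), pad_size

Because zero-padded key positions would otherwise still receive softmax weight, they also have to be neutralized, for example via segment ids or a large negative attention bias; a sketch of the bias approach appears further down.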

@qihqi requested review from tengyifei and zpcore, February 6, 2025 04:54
@qihqi (Collaborator) commented Feb 6, 2025

Thanks for this change!
Please run the yapf formatter; otherwise it's good to go. Thanks!

@qihqi self-requested a review, February 6, 2025 04:56
Comment on lines 299 to 301
k_padded if k_pad_size > 0 else k,
v_padded if k_pad_size > 0 else v,
ab_padded if k_pad_size > 0 and ab is not None else ab,
@zpcore (Member) commented Feb 6, 2025

I think we can let k, v, and ab all go through _pad_to_block_size() and use k_padded, v_padded, and ab_padded afterwards to simplify the logic.
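
Concretely, the suggested simplification could look like the following fragment (continuing the names from the diff below; padding ab along dim 3 assumes ab is laid out as [batch, num_heads, q_seq_len, kv_seq_len]):

    block_size = max(block_k_major, block_k)
    k_padded, k_pad_size = _pad_to_block_size(k, block_size, 2)
    v_padded, _ = _pad_to_block_size(v, block_size, 2)
    ab_padded = None
    if ab is not None:
      ab_padded, _ = _pad_to_block_size(ab, block_size, 3)
    # Downstream code then uses k_padded, v_padded, ab_padded unconditionally;
    # _pad_to_block_size is a no-op when the length is already aligned.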

@zpcore (Member) left a comment

I remember we saw the same issue in a Stable Diffusion run. This is great, thanks for the fix!

@@ -279,32 +279,42 @@ def fa_custom_forward(
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)

block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2])
block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
k_padded, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
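
One subtlety the excerpt does not show: zero-padded keys would still get softmax weight exp(0) = 1 unless they are masked out. One way to neutralize them is through the attention-bias term, sketched here under the assumption that the bias broadcasts over [batch, num_heads, q_seq_len, kv_seq_len] (the PR may instead rely on segment ids; this is only an illustration):

    import torch

    def _padded_key_bias(kv_seq_len, k_pad_size, dtype, device):
      # 0 for real key positions, a very large negative value for padded ones,
      # so the padded columns contribute ~0 after the softmax.
      bias = torch.zeros(kv_seq_len + k_pad_size, dtype=dtype, device=device)
      bias[kv_seq_len:] = torch.finfo(dtype).min
      return bias.view(1, 1, 1, -1)  # broadcastable over [B, H, q_len, kv_len + pad]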
Review comment (Member):

We probably need to do the same padding for the backward pass. Let's see whether the test can pass.
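
If the backward pass is padded the same way, the key/value gradients would come back with the padded length and need to be sliced back to the original kv_seq_len, roughly like this (a sketch of the idea only, not the PR's backward code):

    def _slice_grads_to_original(grad_k_padded, grad_v_padded, kv_seq_len):
      # Drop the gradient rows that correspond to the padded key/value positions.
      return (grad_k_padded[:, :, :kv_seq_len, :],
              grad_v_padded[:, :, :kv_seq_len, :])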

@zpcore (Member) commented Feb 6, 2025

Oh, I forgot to mention: we don't have a test for the case that needs padding. Can you also add unit tests, both with and without SPMD? Thanks.

@zhangp365 (Contributor, Author)

OK, I will update it accordingly.

@zhangp365 (Contributor, Author)
Hi, I've updated the code. Please review it again, thanks.
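
For reference, a minimal sketch of what a padding-path test can look like: a kv length that is not a multiple of the 512 block size, compared against a plain softmax reference (shapes, tolerances, and names here are illustrative, not the PR's exact test, which is excerpted below):

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.custom_kernel import flash_attention


    def reference_attention(q, k, v):
      # Naive unscaled attention in float32 as the ground truth
      # (assuming the wrapper's default sm_scale of 1.0 matches this reference).
      scores = torch.einsum('bhqd,bhkd->bhqk', q.float(), k.float())
      return torch.einsum('bhqk,bhkd->bhqd', torch.softmax(scores, dim=-1), v.float())


    def check_kv_padding_path():
      device = xm.xla_device()
      torch.manual_seed(0)
      # 513 key/value positions: not a multiple of block_k_major=512,
      # so the padding branch is exercised.
      q = torch.randn(2, 4, 128, 64, device=device)
      k = torch.randn(2, 4, 513, 64, device=device)
      v = torch.randn(2, 4, 513, 64, device=device)

      out = flash_attention(q, k, v, causal=False)
      expected = reference_attention(q.cpu(), k.cpu(), v.cpu())
      assert torch.allclose(out.cpu(), expected, atol=1e-2, rtol=1e-2)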

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
Review comment (Member):

Can we also add a backward-pass test to make sure the q, k, v, and ab gradients match the self._attention reference (similar to test_flash_attention_backward_aot_autograd_traceable)? I feel the backward pass needs the same update. Thanks.

@zhangp365 (Contributor, Author) replied:

I'm not very familiar with the backward pass. Can I just fix this forward bug first?

Reply (Member):

That also works, thanks!

@zhangp365 (Contributor, Author) commented Feb 8, 2025

@zpcore Hi, I found that the TPU test failed again. I've run the unit test locally, but I still have some questions.
I am currently running the latest test code on a v6e TPU:

python -m unittest test.test_pallas.PallasTest.test_flash_attention_wrapper_kv_and_ab_padding

and added a print statement:

    diff = o.cpu() - expected_o.cpu()
    print(f"diff: {diff}, max: {diff.max()}, min: {diff.min()}")

The output is:

diff: tensor([[[[ 6.8292e-04,  4.0711e-04,  1.4198e-03, -1.9633e-04],
          [ 3.7608e-04, -5.6365e-04,  5.7314e-04, -2.5114e-04],
          [-2.0739e-03, -2.0366e-04, -1.2157e-03, -2.2950e-03],
          ...,
          [-2.0213e-04, -2.7862e-04, -3.3129e-04,  1.1446e-03],
          [ 5.1174e-05,  5.1262e-04,  4.8657e-04,  2.7038e-04],
          [-2.8645e-03,  4.7613e-03,  2.5814e-03, -1.0975e-02]],

         [[ 1.5105e-03,  1.5672e-03,  1.8056e-04,  4.0236e-04],
          [ 1.2470e-03, -1.1021e-03, -9.5463e-04, -1.3065e-04],
          [-1.9513e-03,  1.0202e-03, -1.5303e-04, -4.0093e-04],
          ...,
          [ 1.0555e-04,  3.0100e-04,  5.1606e-04, -3.4345e-04],
          [-1.9871e-04,  6.6838e-04,  1.0839e-03, -3.0933e-04],
          [-7.1570e-04,  1.1260e-03, -2.5651e-04, -3.4667e-04]]],


        [[[-4.7386e-04,  7.3694e-05,  8.1086e-04, -3.3017e-04],
          [-9.0058e-04,  1.5155e-04, -7.8216e-05,  3.2675e-04],
          [ 2.4962e-04,  3.6017e-04,  4.7567e-04, -7.1727e-05],
          ...,
          [ 1.1529e-03,  9.1261e-04,  9.7885e-03, -6.2257e-03],
          [ 3.5307e-04,  1.8126e-04,  7.2570e-04, -7.2286e-04],
          [-4.8020e-04,  1.3306e-03,  7.6964e-06, -1.0741e-03]],

         [[-1.1892e-03, -2.0468e-03, -5.8037e-04, -4.7976e-04],
          [ 2.1626e-03,  1.5714e-03, -2.7031e-05, -6.9004e-04],
          [-2.2198e-03, -1.0973e-03,  1.9273e-03,  2.8504e-03],
          ...,
          [ 1.7182e-03,  1.0535e-03, -4.3399e-04,  7.2208e-04],
          [ 1.8564e-03,  1.3718e-04,  8.2254e-05,  5.1498e-05],
          [ 4.1562e-04,  1.4392e-03, -2.1695e-04, -1.3060e-03]]],


        [[[-3.8669e-04, -6.0107e-04,  8.8918e-04, -1.7675e-05],
          [-8.5890e-04,  3.6108e-03, -3.1845e-03, -7.1132e-04],
          [-5.1269e-04, -1.0678e-03,  3.5015e-04,  7.8895e-04],
          ...,
          [-1.5513e-03,  2.0282e-03, -9.3010e-04, -1.3495e-03],
          [ 3.2313e-03, -1.4389e-04,  1.3680e-03, -5.5626e-04],
          [-3.6691e-04, -4.9652e-04, -2.9108e-03,  1.2078e-03]],

         [[ 2.4585e-03,  7.4232e-04, -9.5740e-05,  1.2181e-03],
          [ 8.6099e-05, -1.9291e-04,  1.0666e-04,  9.1210e-04],
          [ 2.4083e-03,  3.0909e-03, -2.7944e-03, -1.1478e-04],
          ...,
          [ 9.6892e-04,  3.3641e-04, -1.1125e-03,  1.5945e-03],
          [ 1.2645e-03,  6.4671e-04, -1.2981e-03,  6.6146e-05],
          [ 1.8814e-04, -6.4045e-05, -2.6532e-04,  4.6800e-04]]]]),
          diff max: 0.015680253505706787, min: -0.010974973440170288

The diff is well above the tolerance of 1e-5.
However, when I run the original test method:

python -m unittest test.test_pallas.PallasTest.test_flash_attention_wrapper

the output diff is similar, and the test method also fails.

So, I'm confused about why the diff is so large. Can you give me some advice?

@zhangp365 (Contributor, Author) commented Feb 10, 2025

From test log:

 File "/home/runner/_work/xla/xla/pytorch/xla/test/scan/test_scan.py", line 410, in count_number_of_sines
    text: str = torch_xla._XLAC._get_xla_tensors_hlo(
RuntimeError: Error while lowering: [UNKNOWN_SCALAR[]] xla::device_data, xla_shape=f32[20,4,4]{0,2,1}, dynamic_dims: (), device=TPU:0
Error: ./torch_xla/csrc/runtime/pjrt_computation_client.h:194 : Check failed: HasValue() 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::runtime::PjRtComputationClient::PjRtData::GetHandle()
	torch_xla::LoweringContext::GetParameter(std::shared_ptr<torch::lazy::BackendData> const&, std::unordered_set<unsigned int, std::hash<unsigned int>, std::equal_to<unsigned int>, std::allocator<unsigned int> > const&)
	torch_xla::DeviceData::Lower(torch_xla::LoweringContext*) const
	torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
	torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
	torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
	torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
	torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)

Is this an existing bug rather than one introduced by this fix?

@zpcore (Member) commented Feb 10, 2025

I ran your code with test_pallas, and the test passes. The scan issue does seem random; let me re-trigger the test. By the way, can you leave a TODO item for the backward pass so that we can fix it later? Thanks!

@zpcore (Member) left a comment

LGTM! Thanks for fixing the shape.

@zpcore merged commit cff9f4e into pytorch:master on Feb 10, 2025 (12 checks passed)
@zhangp365 (Contributor, Author)

OK, thanks for accepting the PR.
