Enabling flash_attention_2 causes RuntimeError: cu_seqlens_q must have shape (batch_size + 1) #52

@hyc-hw

Thank you for your excellent open source work. I have a question about using flash_attention_2.

Environment:
Hardware: NVIDIA A100
Python version: 3.12
Package versions:
torch==2.8.0
torchvision==0.23.0
transformers==4.57.1
flash-attn==2.8.3
Model: Official Alpamayo model from Hugging Face
Description:
When loading the official Alpamayo model, I found that the attention implementation actually used is sdpa, even though FlashAttention-2 is specified in config.json.
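One way to see where the two settings diverge is to list every attention-related key in the config, including nested sub-configs. The following is a minimal sketch, not the project's own tooling: the helper `find_attn_settings` and the example config contents (including the `vlm_config` / `text_config` nesting) are hypothetical and only illustrate the idea. A real check would load the model's actual config.json instead.

```python
import json


def find_attn_settings(node, path=""):
    """Yield (dotted_path, value) for every key that mentions attn_implementation."""
    if isinstance(node, dict):
        for key, value in node.items():
            child = f"{path}.{key}" if path else key
            if "attn_implementation" in key:
                yield child, value
            # Recurse so that nested sub-configs are covered too.
            yield from find_attn_settings(value, child)
    elif isinstance(node, list):
        for index, item in enumerate(node):
            yield from find_attn_settings(item, f"{path}[{index}]")


# Hypothetical config contents, used here only so the sketch runs on its own.
example_config = json.loads("""
{
  "attn_implementation": "flash_attention_2",
  "vlm_config": {"text_config": {"_attn_implementation": "sdpa"}}
}
""")

settings = list(find_attn_settings(example_config))
for dotted_path, value in settings:
    print(f"{dotted_path} = {value}")
```

If the top-level value and a nested value disagree, that would be one possible explanation for the model ending up on sdpa.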

To enable faster inference, I manually forced FlashAttention-2 by adding one line (the last line of the snippet below) to alpamayo/src/alpamayo_r1/models/base_model.py:

class ReasoningVLA(PreTrainedModel, TrajectoryFusionMixin):
    """Reasoning Vision-Language-Action model."""

    config_class: type[ReasoningVLAConfig] = ReasoningVLAConfig
    base_model_prefix: str = "vlm"

    def __init__(
        self,
        config: ReasoningVLAConfig,
        pretrained_modules: dict[str, torch.nn.Module] | None = None,
        original_vocab_size: int | None = None,
        print_param_count: bool = True,
    ) -> None:
        super().__init__(config)

        if pretrained_modules is not None:
            for module in pretrained_modules.values():
                if not isinstance(module, torch.nn.Module):
                    continue
                _recursive_setattr(module, "_is_hf_initialized", True)
        else:
            pretrained_modules = {}
        # Added line: force FlashAttention-2 regardless of what was resolved at load time.
        config.attn_implementation = 'flash_attention_2'

However, running the inference script (alpamayo/src/alpamayo_r1/test_inference.py) results in the following error:

Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.
Traceback (most recent call last):
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/test_inference_ori.py", line 52, in <module>
    pred_xyz, pred_rot, extra = model.sample_trajectories_from_data_with_vlm_rollout(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/models/alpamayo_r1.py", line 291, in sample_trajectories_from_data_with_vlm_rollout
    sampled_action = self.diffusion.sample(
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/diffusion/flow_matching.py", line 79, in sample
    return self._euler(
           ^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/diffusion/flow_matching.py", line 131, in _euler
    v = step_fn(x=x, t=t_start)
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/models/alpamayo_r1.py", line 269, in step_fn
    expert_out_base = self.expert(
                      ^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 850, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 502, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 444, in forward
    attn_output, attn_weights = attention_interface(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/integrations/flash_attention.py", line 66, in flash_attention_forward
    attn_output = _flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 616, in _flash_attention_forward
    out_unpad = flash_varlen_fn(
                ^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_ops.py", line 836, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cu_seqlens_q must have shape (batch_size + 1)

This suggests a mismatch between the expected input format for FlashAttention-2 (which, in varlen mode, requires packed sequences and valid cu_seqlens_q) and the actual padded input used during inference.
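To make the shape requirement concrete: in flash-attn's varlen interface, the tokens of all sequences are packed into one dimension and `cu_seqlens_q` holds the cumulative sequence lengths, so it has `batch_size + 1` entries (a leading 0 plus one offset per sequence). The sketch below illustrates this with plain Python lists rather than tensors. The helper names `cu_seqlens_from_mask` and `check_cu_seqlens` are hypothetical and are not part of flash-attn or transformers.

```python
from itertools import accumulate


def cu_seqlens_from_mask(attention_mask):
    """Return cumulative sequence lengths for a padded 0/1 attention mask.

    Each row is one sequence; 1 marks a real token and 0 marks padding.
    """
    seqlens = [sum(row) for row in attention_mask]
    # The leading 0 is the start offset of the first sequence.
    return [0] + list(accumulate(seqlens))


def check_cu_seqlens(cu_seqlens, batch_size):
    """Mirror the shape condition named in the error message."""
    if len(cu_seqlens) != batch_size + 1:
        raise ValueError(
            f"cu_seqlens has {len(cu_seqlens)} entries, expected {batch_size + 1}"
        )


# Three padded sequences with 3, 2 and 4 real tokens.
mask = [
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 1],
]
cu = cu_seqlens_from_mask(mask)
print(cu)  # [0, 3, 5, 9]
check_cu_seqlens(cu, batch_size=len(mask))
```

A mismatch of this kind would occur if the offsets were derived from a mask or position tensor whose leading dimension does not equal the batch size of the query tensor.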
Current Workaround:
Keep attn_implementation as the default (sdpa) to avoid the crash.
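The workaround could also be applied automatically by trying FlashAttention-2 first and falling back to sdpa when the kernel raises. This is only a sketch under stated assumptions: `run_with_attn_fallback` and `fake_inference` are hypothetical names, and `fake_inference` stands in for loading the model with a given attention implementation and running the inference script.

```python
def run_with_attn_fallback(run, implementations=("flash_attention_2", "sdpa")):
    """Call run(impl) for each implementation in order until one succeeds.

    Returns a tuple (implementation_used, result).
    """
    last_error = None
    for impl in implementations:
        try:
            return impl, run(impl)
        except RuntimeError as err:
            print(f"{impl} failed: {err}")
            last_error = err
    raise RuntimeError("all attention implementations failed") from last_error


def fake_inference(impl):
    # Stand-in for real model loading and inference (hypothetical).
    if impl == "flash_attention_2":
        raise RuntimeError("cu_seqlens_q must have shape (batch_size + 1)")
    return "ok"


used, result = run_with_attn_fallback(fake_inference)
print(used, result)
```

Catching `RuntimeError` broadly can hide unrelated failures, so a real version would likely log the full traceback before falling back.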
Questions:
Does the official Alpamayo model support flash_attention_2?
If so, does it require special input preprocessing (e.g., sequence packing)?
Could this be a compatibility issue between Qwen3-VL (used as the VLM backbone) and FlashAttention-2 in transformers==4.57.1?
Any guidance on enabling FA2 safely would be greatly appreciated!
