Incorrect version check for flash_attn leads to API incompatibility in v2.6.3 #158

@SkyR0ver

Description

The following error occurs when flash_attn == 2.6.3:

```
[rank0]:   File "/home/xxx/miniconda3/envs/xxx/lib/python3.10/site-packages/yunchang/kernels/attention.py", line 132, in flash_attn_forward
[rank0]:     block_out, block_lse, _, _ = _flash_attn_forward(
[rank0]: TypeError: _flash_attn_forward() got an unexpected keyword argument 'window_size_left'
```

Looking at the source code, the error comes from an incorrect branch condition:

```python
if flash_attn.__version__ < '2.6.3':  # <-- WRONG!
    block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
        # ...
        window_size=window_size,
        # ...
    )
else:
    block_out, block_lse, _, _ = _flash_attn_forward(
        # ...
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        # ...
    )
```
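A related detail: `flash_attn.__version__` is a `str`, so both the original `<` and the `<=` variant suggested under Solution compare character by character, not numerically. The minimal sketch below evaluates both conditions. The helper `takes_old_api_branch` is hypothetical, and `"2.6.3.post1"` and `"2.10.0"` are hypothetical version strings chosen only to illustrate lexicographic ordering; they are not claimed to be flash_attn releases.

```python
def takes_old_api_branch(version: str, use_le: bool) -> bool:
    """Return True if the string-based condition selects the pre-2.7.0 call.

    use_le=False mimics the original `version < '2.6.3'`;
    use_le=True mimics the suggested `version <= '2.6.3'`.
    Both compare str objects, i.e. character by character.
    """
    return version <= "2.6.3" if use_le else version < "2.6.3"


# "2.6.3.post1" and "2.10.0" are hypothetical strings used only to show
# how lexicographic comparison orders version-like text.
for v in ["2.6.3", "2.6.3.post1", "2.7.0", "2.10.0"]:
    print(v, takes_old_api_branch(v, use_le=False), takes_old_api_branch(v, use_le=True))
```

For `"2.6.3"`, the original condition is `False` (the reported bug) and `<=` is `True`. However, `"2.10.0"` compares less than `"2.6.3"` (because `'1' < '6'`) and would take the old-API branch under either condition, while `"2.6.3.post1"` compares greater than `"2.6.3"` and would take the new-API branch even with `<=`.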

To be precise, the `window_size_left` parameter was first introduced in flash_attn 2.7.0; in 2.6.3, `_flash_attn_forward` still has the following signature:

```python
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    ...
```

Solution

Change the branch condition to `flash_attn.__version__ <= '2.6.3'`, or use an equivalent check that sends 2.6.3 to the old `window_size=` call.
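One equivalent option is to gate on the release that, according to this issue, introduced the split arguments (2.7.0), and to compare numbers instead of strings. This is a minimal stdlib-only sketch; `release_tuple`, `uses_split_window_args`, and `SPLIT_WINDOW_API_SINCE` are hypothetical names, not part of yunchang or flash_attn. If the `packaging` package is available, `packaging.version.Version(flash_attn.__version__) < packaging.version.Version("2.7.0")` is a more complete alternative that follows PEP 440.

```python
import re

# Per this issue, window_size_left/window_size_right first appear in
# flash_attn 2.7.0, so gate on that release rather than on 2.6.3.
SPLIT_WINDOW_API_SINCE = (2, 7, 0)


def release_tuple(version: str) -> tuple[int, int, int]:
    """Parse the leading MAJOR.MINOR.PATCH of a version string into ints.

    Suffixes such as '.post1' or '+cu121' are ignored; this is a minimal
    sketch, not a full PEP 440 parser.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def uses_split_window_args(version: str) -> bool:
    """True if _flash_attn_forward is expected to take window_size_left/right."""
    return release_tuple(version) >= SPLIT_WINDOW_API_SINCE


for v in ["2.6.3", "2.6.3.post1", "2.7.0", "2.10.0"]:
    print(v, uses_split_window_args(v))
```

Because tuples of ints compare element by element numerically, `(2, 10, 0)` correctly sorts after `(2, 7, 0)`. In the branch quoted above, the old-API call would then be guarded by `if not uses_split_window_args(flash_attn.__version__):`.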
