
Missing kernel/function for customized_flash_attn #132

@CameronBraunstein

I am interested in running this code for training and inference using customized_flash_attn. I have flash_attn installed, but setting customized_flash_attn=True in the constructor yields:

```
Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<
flash_attn_func is in [line 1140] [file .../miniconda3/envs/infinity/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py] <<<<<<
flash_attn_func.__code__.co_varnames=('q', 'k', 'v', 'dropout_p', 'softmax_scale', 'causal', 'window_size', 'softcap', 'alibi_slopes', 'deterministic', 'return_attn_probs') <<<<<<
```
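From the message, the check apparently inspects the argument names of the installed `flash_attn_func` and only accepts a build that exposes extra, Infinity-specific parameters. Here is a minimal sketch of that kind of signature probe (my reconstruction, not the repo's actual check; the "stock" argument set is just the standard flash_attn 2.x signature printed above):

```python
def has_customized_flash_attn() -> bool:
    """Sketch of a signature probe like the one the error message implies.

    Returns True only if the installed flash_attn_func exposes arguments
    beyond the stock flash_attn 2.x signature, which a customized build
    with extra mask-related parameters presumably would.
    """
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        return False
    code = flash_attn_func.__code__
    # Keep only the declared arguments, not local variables.
    arg_names = set(code.co_varnames[:code.co_argcount])
    stock_args = {
        "q", "k", "v", "dropout_p", "softmax_scale", "causal",
        "window_size", "softcap", "alibi_slopes", "deterministic",
        "return_attn_probs",
    }
    return bool(arg_names - stock_args)
```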

Based on the code, it looks like there should be a custom version of flash_attn_func in the flash_attn library, which takes in additional Infinity-specific arguments to handle the unconventional attention mask.

However, I do not see a way to access this customized version, and re-implementing the kernel from scratch looks non-trivial. Is there a way to obtain it?
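In case it helps anyone else in the meantime: since the stock `flash_attn_func` only supports causal/sliding-window masking, an arbitrary mask (e.g. a block-causal mask over the scale schedule, if that is what the custom kernel fuses) can at least be expressed through `torch.nn.functional.scaled_dot_product_attention`. This is only a rough, unfused sketch of what I assume the missing kernel does; the `block_sizes` argument is hypothetical:

```python
import torch
import torch.nn.functional as F

def block_causal_sdpa(q, k, v, block_sizes):
    """Fallback attention with a block-causal mask (one block per scale).

    q, k, v: (batch, heads, seq_len, head_dim); block_sizes lists the token
    count of each block, and tokens may attend to every token in their own
    block and all earlier blocks.
    """
    seq_len = q.shape[-2]
    # Assign each position its block index, then build a boolean mask where
    # True means "may attend".
    block_ids = torch.repeat_interleave(
        torch.arange(len(block_sizes), device=q.device),
        torch.tensor(block_sizes, device=q.device),
    )
    assert block_ids.numel() == seq_len
    mask = block_ids[:, None] >= block_ids[None, :]
    # SDPA accepts an arbitrary boolean mask, unlike the stock flash_attn_func.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Example: a 1-token block followed by 4- and 9-token blocks (seq_len = 14).
# q = k = v = torch.randn(2, 8, 14, 64)
# out = block_causal_sdpa(q, k, v, block_sizes=[1, 4, 9])
```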
