Description
I am interested in running this code for training and inference using customized_flash_attn. I have flash_attn installed, but setting customized_flash_attn=True in the constructor yields:
Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<
flash_attn_func is in [line 1140] [file .../miniconda3/envs/infinity/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py] <<<<<<
flash_attn_func.__code__.co_varnames=('q', 'k', 'v', 'dropout_p', 'softmax_scale', 'causal', 'window_size', 'softcap', 'alibi_slopes', 'deterministic', 'return_attn_probs') <<<<<<
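For what it's worth, the check implied by this message can be reproduced by inspecting the installed `flash_attn_func` directly. A minimal sketch (assuming the stock `flash_attn` package is importable; the expected stock parameter list is the one printed above):

```python
# Reproduce the introspection that the error message is based on:
# look at the parameters exposed by the installed flash_attn_func.
import inspect
from flash_attn import flash_attn_func

params = list(inspect.signature(flash_attn_func).parameters)
print(params)
# Stock flash-attn only exposes: q, k, v, dropout_p, softmax_scale,
# causal, window_size, softcap, alibi_slopes, deterministic,
# return_attn_probs -- i.e. no extra mask-related arguments, which is
# presumably why the customized_flash_attn check fails.
```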
Based on the code, it looks like there should be a custom version of flash_attn_func in the flash_attn library, which takes in additional Infinity-specific arguments to handle the unconventional attention mask.
However, I do not see a way to access this logic, and re-implementing the kernel myself seems non-trivial. Is there a way to obtain or build this customized kernel?
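In the meantime, is a fallback along the following lines reasonable? This is only a sketch of a generic substitute, not the repository's method: PyTorch's `scaled_dot_product_attention` with an explicit boolean mask in place of the customized kernel. The block-wise mask here is purely illustrative (the block size and mask construction are my assumptions, not taken from the code):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim).
B, H, L, D = 1, 8, 1024, 64
q = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")

# Hypothetical block-wise causal mask: each token attends to all tokens
# up to and including the end of its own block (block size is arbitrary).
block = 64
block_id = torch.arange(L, device="cuda") // block
attn_mask = block_id[:, None] >= block_id[None, :]  # (L, L), True = keep

# SDPA accepts a boolean mask directly, so arbitrary masks are possible,
# at the cost of losing the fused flash-attention kernel path.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```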