Add Flash FMHA/FNA backend#278

Open
AdityaKane2001 wants to merge 23 commits into SHI-Labs:main from AdityaKane2001:aditya/flash-na

Conversation

@AdityaKane2001 (Contributor) commented Oct 30, 2025

TL;DR: Ported the Flash Attention CUTLASS 3.x kernels to NATTEN as-is, with wrappers to fit them into NATTEN. Exposed through the flash-fmha backend.

Summary of changes:

  • C++:

    • Added actual kernel files under csrc/include/natten/cuda/flash_fmha/flash_kernel.
    • Torch C++ interface at csrc/src/flash_fmha.cu, which calls into csrc/.../flash_fmha/flash_fmha_{forward/backward}.cuh
    • Added a utility file csrc/.../flash_kernel/param_utils.h for param conversion.
  • Python:

    • Exposed the C++ functions through the flash-fmha backend (see the usage sketch after this list).
    • Added an autograd function and configs for the same.
    • Wherever possible, the code is arranged so that a flash-fna backend can be implemented later.
    • Added autogen scripts and tests for flash-fmha.
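
What selecting the new backend might look like from Python, as a hedged sketch: the natten.attention entry point, the tensor layout, and the per-call backend keyword are assumptions modeled on NATTEN's existing backend-selection pattern; the PR only states that the kernels are exposed as the flash-fmha backend.

```python
# Hypothetical usage sketch -- natten.attention and the backend kwarg are
# assumptions; this PR only says the kernels are exposed as "flash-fmha".
import torch
from natten import attention  # assumed entry point for FMHA-style attention

B, L, H, D = 2, 1024, 8, 64  # batch, sequence length, heads, head dim
q = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Assumed: the backend is chosen per call, mirroring NATTEN's existing
# backend names (e.g. "cutlass-fmha").
out = attention(q, k, v, backend="flash-fmha")
```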

Present rough edges:

  1. The Python frontend might have some code style inconsistencies.
  2. Stray template parameters for the flash backward template are currently housed in flash_fmha_backward.cuh rather than generated by autogen; adding them to the autogen scripts would make those scripts overly complex.
  3. Although correctness is verified by the tests, no particular refactoring of the actual flash attention kernel code was done.
  4. Important: Flash FMHA does not support the causal argument. Please use Flash FNA 1D for causal attention instead (see the sketch after this list).
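
A minimal sketch of that workaround, assuming NATTEN's na1d signature (kernel_size, is_causal) and its heads-last tensor layout; routing this call through the flash-fna backend is the intent stated above, but the exact selection mechanism is not shown here.

```python
# Sketch: causal attention via 1-D neighborhood attention. Assumes NATTEN's
# documented na1d signature and (batch, seqlen, heads, head_dim) layout.
import torch
from natten.functional import na1d

B, L, H, D = 2, 1024, 8, 64
q = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# A causal window spanning the entire sequence reduces 1-D neighborhood
# attention to standard causal self-attention; the flash-fna backend would
# serve this call once selected.
out = na1d(q, k, v, kernel_size=L, is_causal=True)
```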

@AdityaKane2001 (Contributor, Author)

Note to self: See comment in #282 when dealing with dilation.

@AdityaKane2001 (Contributor, Author)

Update: Flash FNA forward has been added as well.

Like flash-fmha, it is exposed through the flash-fna backend. Only the forward pass is supported for now; the backward pass will be ported next.
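
A sketch of exercising the new backend for 2-D neighborhood attention. The na2d call follows NATTEN's documented API; the per-call backend keyword and its acceptance of "flash-fna" are assumptions based on this PR's description.

```python
# Hypothetical sketch -- the "flash-fna" name comes from this PR, but the
# per-call backend selection mechanism is assumed, not shown in the PR.
import torch
from natten.functional import na2d

B, X, Y, H, D = 2, 64, 64, 8, 64  # batch, height, width, heads, head dim
q = torch.randn(B, X, Y, H, D, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

with torch.no_grad():  # backward for flash-fna is not ported yet
    out = na2d(q, k, v, kernel_size=(7, 7), backend="flash-fna")
```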

@AdityaKane2001 (Contributor, Author)

Checklist of features still required in Flash (see the shape sketch after this list):

  • Different head dims for QK and V
  • Arbitrary head dims
  • GQA/MQA
  • Varlen (for FMHA)
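
A plain-PyTorch illustration of what the first three items mean shape-wise; this is not NATTEN code, just a reference using torch's scaled_dot_product_attention.

```python
# Illustrates GQA/MQA and differing QK/V head dims with plain PyTorch.
import torch
import torch.nn.functional as F

B, L = 2, 1024
H_q, H_kv = 8, 2       # GQA: more query heads than KV heads (MQA: H_kv == 1)
D_qk, D_v = 192, 128   # different (and arbitrary) head dims for QK vs. V

q = torch.randn(B, H_q, L, D_qk)
k = torch.randn(B, H_kv, L, D_qk)
v = torch.randn(B, H_kv, L, D_v)

# GQA: each group of query heads shares one KV head; expanding the KV heads
# recovers plain multi-head attention shapes.
k = k.repeat_interleave(H_q // H_kv, dim=1)
v = v.repeat_interleave(H_q // H_kv, dim=1)

# SDPA allows V's head dim to differ from QK's; the output takes V's dim.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 1024, 128])
```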

@AdityaKane2001 changed the title from "Add Flash FMHA backend" to "Add Flash FMHA/FNA backend" on Jan 16, 2026
@AdityaKane2001 (Contributor, Author)

@alihassanijr

Rebased onto main. All tests pass.
