Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f5ee386
Move devsetup to shi-01
AdityaKane2001 Oct 13, 2025
68a1faa
Flash FMHA fwd bwd works
AdityaKane2001 Oct 27, 2025
df9f5f4
Rebased
AdityaKane2001 Oct 27, 2025
f788ca4
Added flash FMHA (not FNA) frontend and tests, minor LSE bugfix
AdityaKane2001 Oct 29, 2025
1948ddf
Cleaning up diff
AdityaKane2001 Oct 29, 2025
577e61f
Cleaning up diff again
AdityaKane2001 Oct 29, 2025
7567faf
Cleaning up diff one last time
AdityaKane2001 Oct 29, 2025
cfe267a
One teeny tiny cleaning up of diff
AdityaKane2001 Oct 29, 2025
fe6f4ad
Forgot one directory
AdityaKane2001 Oct 29, 2025
eda3130
Unnecessary diff delete
AdityaKane2001 Oct 30, 2025
9201fe8
Leftover from debugging
AdityaKane2001 Nov 2, 2025
09fd2d0
Flash FNA fwd works
AdityaKane2001 Nov 22, 2025
57670c7
Flash FNA fwd C++ scaffolding done
AdityaKane2001 Nov 24, 2025
af2c6d8
Flash FNA fwd Python backend done, allcloses
AdityaKane2001 Nov 29, 2025
d49efbf
Flash FNA fwd Python backend done, allcloses and runs fast
AdityaKane2001 Dec 5, 2025
2ccc313
oopsie daisy
AdityaKane2001 Dec 5, 2025
0fc9560
Flash FNA bwd allclosed, added Python frontend, all tests pass
AdityaKane2001 Jan 15, 2026
e5ca6af
Rebased and changed calls as needed
AdityaKane2001 Jan 16, 2026
a43e714
Nits
AdityaKane2001 Jan 16, 2026
b78e692
Deduplicated in Flash FNA
AdityaKane2001 Jan 21, 2026
c777bbc
Added exception for Flash
AdityaKane2001 Feb 7, 2026
12cfab6
Added Ali Hassani copyrights
AdityaKane2001 Feb 7, 2026
a831b38
Removed unnecessary TODO
AdityaKane2001 Feb 7, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ file(GLOB TORCH_APIS ./src/*.cpp ./src/*.cu)
file(GLOB AUTOGEN_REFERENCE ./autogen/src/cuda/reference/*.cu)
file(GLOB AUTOGEN_FNA ./autogen/src/cuda/fna/*.cu)
file(GLOB AUTOGEN_FMHA ./autogen/src/cuda/fmha/*.cu)
file(GLOB AUTOGEN_FLASH_FNA ./autogen/src/cuda/flash_fna/*.cu)
file(GLOB AUTOGEN_FLASH_FNA_BWD ./autogen/src/cuda/flash_fna_bwd/*.cu)
file(GLOB AUTOGEN_FLASH_FMHA ./autogen/src/cuda/flash_fmha/*.cu)
file(GLOB AUTOGEN_FLASH_FMHA_BWD ./autogen/src/cuda/flash_fmha_bwd/*.cu)
if(${NATTEN_WITH_HOPPER_FNA})
file(GLOB AUTOGEN_HOPPER_FNA ./autogen/src/cuda/hopper_fna/*.cu ./autogen/src/cuda/hopper_fna_bwd/*.cu)
file(GLOB AUTOGEN_HOPPER_FMHA ./autogen/src/cuda/hopper_fmha/*.cu ./autogen/src/cuda/hopper_fmha_bwd/*.cu)
Expand All @@ -173,6 +177,10 @@ file(GLOB ALL_SOURCES
${AUTOGEN_REFERENCE}
${AUTOGEN_FNA}
${AUTOGEN_FMHA}
${AUTOGEN_FLASH_FNA}
${AUTOGEN_FLASH_FNA_BWD}
${AUTOGEN_FLASH_FMHA}
${AUTOGEN_FLASH_FMHA_BWD}
${AUTOGEN_BLACKWELL_FNA}
${AUTOGEN_BLACKWELL_FMHA}
${AUTOGEN_HOPPER_FNA}
Expand Down
122 changes: 122 additions & 0 deletions csrc/include/natten/cuda/flash_fmha/flash_fmha_backward.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/***************************************************************************************************
* Copyright (c) 2022-2025 Ali Hassani.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
**************************************************************************************************/

#pragma once

#include "natten/cuda/flash_fmha/flash_kernel/flash_bwd_launch_template.h"

namespace natten {
namespace cuda {
namespace flash {

// Compile-time tuning parameters for the flash backward kernel.
// Field order is significant: get_config() fills this struct positionally
// through aggregate (brace) initialization.
struct Config {
  int Stages_dO;
  int Stages_dS_or_QSm80;
  bool SdP_swapAB;
  bool dKV_swapAB;
  bool dQ_swapAB;
  int NumMmaWarpGroups;
  int AtomLayoutMSdP;
  int AtomLayoutNdKV;
  int AtomLayoutMdQ;
  bool V_in_regs;
};

// Returns the backward-pass Config for a (HeadDim, Arch) pair.
//
// Supported HeadDim values: 32, 64, 96, 128, 192, 256.
// Supported Arch values: 80, 86, 89 (86 and 89 share one entry per HeadDim).
//
// Any other combination is rejected at compile time. The Arch check is done
// up front so that an unsupported Arch paired with a supported HeadDim gets a
// clear diagnostic instead of flowing off the end of the function without
// returning a value.
//
// Initializer order in every `Config{...}` below:
//   Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB,
//   NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs
template <int HeadDim, int Arch>
constexpr Config get_config() {
  static_assert(
      Arch == 80 || Arch == 86 || Arch == 89,
      "Unsupported HeadDim/Arch combination");

  // After the assertion above, "not 86/89" can only mean Arch == 80.
  constexpr bool kIsSm86OrSm89 = (Arch == 86 || Arch == 89);

  if constexpr (HeadDim == 32) {
    if constexpr (kIsSm86OrSm89) {
      return Config{2, 2, false, false, false, 2, 2, 4, 2, true};
    } else {
      return Config{2, 2, false, false, false, 2, 4, 4, 4, false};
    }
  } else if constexpr (HeadDim == 64) {
    if constexpr (kIsSm86OrSm89) {
      return Config{2, 2, false, false, false, 2, 2, 4, 2, true};
    } else {
      return Config{2, 2, false, false, false, 2, 4, 4, 4, false};
    }
  } else if constexpr (HeadDim == 96) {
    if constexpr (kIsSm86OrSm89) {
      return Config{1, 2, false, false, false, 2, 2, 4, 2, true};
    } else {
      return Config{2, 2, false, false, false, 2, 2, 4, 2, false};
    }
  } else if constexpr (HeadDim == 128) {
    if constexpr (kIsSm86OrSm89) {
      return Config{1, 2, false, false, false, 2, 2, 2, 2, true};
    } else {
      return Config{2, 2, false, false, false, 2, 2, 2, 2, false};
    }
  } else if constexpr (HeadDim == 192) {
    if constexpr (kIsSm86OrSm89) {
      return Config{1, 1, false, false, false, 2, 2, 2, 2, true};
    } else {
      return Config{1, 2, false, true, false, 2, 4, 2, 2, false};
    }
  } else if constexpr (HeadDim == 256) {
    if constexpr (kIsSm86OrSm89) {
      return Config{1, 1, false, false, false, 2, 2, 2, 1, true};
    } else {
      return Config{1, 1, false, false, false, 2, 4, 2, 2, false};
    }
  } else {
    // Dependent on HeadDim so it only fires when this branch is instantiated.
    static_assert(HeadDim == -1, "Unsupported HeadDim/Arch combination");
  }
}


template <int Arch, typename Element, int HeadDim, int kBlockM, int kBlockN, bool Deterministic>
struct FlashBackwardKernel {

void run(Flash_bwd_params params, cudaStream_t stream) {

static constexpr Config config = get_config<HeadDim, Arch>();

auto flash_bwd = run_flash_bwd<
/* Arch= */ Arch,
/* kHeadDim= */ HeadDim,
/* kBlockM= */ kBlockM,
/* kBlockN= */ kBlockN,
/* Element= */ Element,
/* Deterministic= */ Deterministic,
/* GQA= */ false,
/* Stages_dO= */ config.Stages_dO,
/* Stages_dS_or_QSm80= */ config.Stages_dS_or_QSm80,
/* SdP_swapAB= */ config.SdP_swapAB,
/* dKV_swapAB= */ config.dKV_swapAB,
/* dQ_swapAB= */ config.dQ_swapAB,
/* NumMmaWarpGroups= */ config.NumMmaWarpGroups,
/* AtomLayoutMSdP= */ config.AtomLayoutMSdP,
/* AtomLayoutNdKV= */ config.AtomLayoutNdKV,
/* AtomLayoutMdQ= */ config.AtomLayoutMdQ,
/* V_in_regs= */ config.V_in_regs
>;

flash_bwd(params, stream);
}
};

} // namespace flash
} // namespace cuda
} // namespace natten
58 changes: 58 additions & 0 deletions csrc/include/natten/cuda/flash_fmha/flash_fmha_forward.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/***************************************************************************************************
* Copyright (c) 2022-2025 Ali Hassani.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
**************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"

#include "natten/cuda/flash_fmha/flash_kernel/flash.h"
#include "natten/cuda/flash_fmha/flash_kernel/flash_fwd_launch_template.h"

namespace natten {
namespace cuda {
namespace flash {

template <int Arch, typename Element, int HeadDim, int kBlockM, int kBlockN>
struct FlashForwardKernel {

void run(Flash_fwd_params params, cudaStream_t stream){

auto flash_fwd = run_flash_fwd<
/* Arch= */ Arch,
/* kHeadDim= */ HeadDim,
/* kHeadDimV= */ HeadDim,
/* kBlockM= */ kBlockM,
/* kBlockN= */ kBlockN,
/* Element= */ Element,
/* ElementOut= */ Element,
/* PackGQA= */ false,
/* V_colmajor= */ false
>;

flash_fwd(params, stream);
}
};

} // namespace flash
} // namespace cuda
} // namespace natten
107 changes: 107 additions & 0 deletions csrc/include/natten/cuda/flash_fmha/flash_kernel/block.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/***************************************************************************************************
* Copyright (c) 2022-2025 Ali Hassani.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/

#pragma once

namespace natten {
namespace cuda {
namespace flash {

// Helpers that compute block-index ranges over the query (M) and key (N)
// dimensions, in units of kBlockM-row and kBlockN-column tiles.
//
// Several parameters in the signatures below are accepted but not consulted
// by this implementation; they are kept so callers do not need to change.
template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool PackGQA=false>
struct BlockMN {

  // Returns {n_block_min, n_block_max} for a query block: always the full key
  // range {0, ceil_div(seqlen_k, kBlockN)}. m_block, bidb, split_idx,
  // num_splits, and qhead_per_khead_divmod are unused.
  static
  CUTLASS_DEVICE
  cute::tuple<int, int> get_n_block_min_max(
      SeqlenInfo_t const& seqlen_info,
      int const m_block, int const bidb, int const split_idx, int const num_splits,
      cutlass::FastDivmod const& qhead_per_khead_divmod) {

    int n_block_max = cute::ceil_div(seqlen_info.seqlen_k, kBlockN);
    int n_block_min = 0;
    return {n_block_min, n_block_max};
  }

  // Returns {n_block_new_min, n_block_new_max}: the range from
  // get_n_block_min_max() shifted by seqlen_k_og, clamped to
  // [0, seqlen_k_new], and converted back to block indices. When the clamped
  // element range is empty, both returned values are n_block_new_min.
  static
  CUTLASS_DEVICE
  cute::tuple<int, int> get_n_block_k_new_min_max(
      SeqlenInfo_t const& seqlen_info,
      int const m_block, int const bidb, int const split_idx, int const num_splits,
      cutlass::FastDivmod const& qhead_per_khead_divmod) {

    auto [n_block_min, n_block_max] = get_n_block_min_max(
        seqlen_info, m_block, bidb, split_idx, num_splits, qhead_per_khead_divmod);
    int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
    int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
    int const n_block_new_min = idx_k_new_min / kBlockN;
    int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
    return {n_block_new_min, n_block_new_max};
  }

  // Returns {m_block_min, m_block_max} for a key block: always the full query
  // range {0, ceil_div(seqlen_q, kBlockM)}. n_block and bidb are unused.
  static
  CUTLASS_DEVICE
  cute::tuple<int, int> get_m_block_min_max(
      SeqlenInfo_t const& seqlen_info,
      int const n_block, int const bidb) {
    int m_block_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM);
    int m_block_min = 0;
    return {m_block_min, m_block_max};
  }

  // If we have separate iterations with causal or local masking at the start,
  // where do we stop.
  // Computes the first query index covered by m_block (divided by
  // qhead_per_khead_divmod when PackGQA is set), offsets it by
  // (seqlen_k - seqlen_q), and returns the corresponding key block index,
  // never less than n_block_min.
  static
  CUTLASS_DEVICE
  int get_n_block_min_causal_local_mask(
      SeqlenInfo_t const& seqlen_info,
      int const m_block, int const n_block_min,
      cutlass::FastDivmod const& qhead_per_khead_divmod) {
    int const m_idx_min = !PackGQA ? m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM);
    int const n_idx_right = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
    return std::max(n_block_min, n_idx_right / kBlockN);
  }

  // If we have separate iterations with local masking at the end, where do we
  // stop the non-masked iterations.
  // This implementation always returns n_block_min unchanged; seqlen_info,
  // m_block, and qhead_per_khead_divmod are unused.
  static
  CUTLASS_DEVICE
  int get_n_block_min_before_local_mask(
      SeqlenInfo_t const& seqlen_info,
      int const m_block, int const n_block_min,
      cutlass::FastDivmod const& qhead_per_khead_divmod) {
    return n_block_min;
  }

};

} // namespace flash
} // namespace cuda
} // namespace natten
46 changes: 46 additions & 0 deletions csrc/include/natten/cuda/flash_fmha/flash_kernel/cuda_check.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/***************************************************************************************************
* Copyright (c) 2022-2025 Ali Hassani.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
namespace natten {
namespace cuda {
namespace flash {

// Evaluates `call` (an expression yielding cudaError_t) exactly once. If the
// result is not cudaSuccess, prints the file, line, and cudaGetErrorString()
// text to stderr and terminates the process with exit(1).
// The argument is parenthesized so that any expression expands safely.
// NOTE: macros are not scoped by the enclosing namespaces; these names are
// visible globally once this header is included.
#define FLASH_CHECK_CUDA(call)                                                                      \
  do {                                                                                              \
    cudaError_t status_ = (call);                                                                   \
    if (status_ != cudaSuccess) {                                                                   \
      fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
      exit(1);                                                                                      \
    }                                                                                               \
  } while(0)

// Applies FLASH_CHECK_CUDA to cudaGetLastError(); going by its name, intended
// to be placed right after a kernel launch.
#define FLASH_CHECK_CUDA_KERNEL_LAUNCH() FLASH_CHECK_CUDA(cudaGetLastError())

} // namespace flash
} // namespace cuda
} // namespace natten
Loading