Commit 4be9fd3

Flash FNA bwd allclosed, added Python frontend, all tests pass
Parent: 95c7250

28 files changed: +3172 -581 lines

csrc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -161,6 +161,7 @@ file(GLOB AUTOGEN_REFERENCE ./autogen/src/cuda/reference/*.cu)
 file(GLOB AUTOGEN_FNA ./autogen/src/cuda/fna/*.cu)
 file(GLOB AUTOGEN_FMHA ./autogen/src/cuda/fmha/*.cu)
 file(GLOB AUTOGEN_FLASH_FNA ./autogen/src/cuda/flash_fna/*.cu)
+file(GLOB AUTOGEN_FLASH_FNA_BWD ./autogen/src/cuda/flash_fna_bwd/*.cu)
 file(GLOB AUTOGEN_FLASH_FMHA ./autogen/src/cuda/flash_fmha/*.cu)
 file(GLOB AUTOGEN_FLASH_FMHA_BWD ./autogen/src/cuda/flash_fmha_bwd/*.cu)
 if(${NATTEN_WITH_HOPPER_FNA})

@@ -177,6 +178,7 @@ file(GLOB ALL_SOURCES
   ${AUTOGEN_FNA}
   ${AUTOGEN_FMHA}
   ${AUTOGEN_FLASH_FNA}
+  ${AUTOGEN_FLASH_FNA_BWD}
   ${AUTOGEN_FLASH_FMHA}
   ${AUTOGEN_FLASH_FMHA_BWD}
   ${AUTOGEN_BLACKWELL_FNA}
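These globs pull the new backward-pass autogen directory into the build, mirroring the existing FLASH_FMHA_BWD entries. The generated sources themselves are not part of this diff; a hypothetical sketch of what one such unit might contain, where the function name, tile shapes, and Causal encoding are all assumptions, not taken from this commit:

    // Hypothetical autogen instantiation unit under autogen/src/cuda/flash_fna_bwd/.
    #include <natten/cuda/flash_fna/flash_fna_backward.cuh>

    namespace natten::cuda::flash_fna {

    void flash_fna_bwd_2d_sm80_half_headdim64( // assumed naming scheme
        Flash_fna_bwd_params<cute::tuple<int, int>> params, cudaStream_t stream) {
      FlashFnaBackwardKernel<
          /*Arch=*/80, /*Element=*/cutlass::half_t, /*HeadDim=*/64,
          /*kBlockM=*/64, /*kBlockN=*/128,
          /*NADim=*/cute::tuple<int, int>,                      // 2-D problem (assumed)
          /*QTileShape=*/cute::Shape<cute::_8, cute::_8>,       // 8x8 = 64 queries (assumed)
          /*KVTileShape=*/cute::Shape<cute::_8, cute::_16>,     // 8x16 = 128 keys (assumed)
          /*Causal=*/cute::tuple<cute::false_type, cute::false_type>, // per-dim flags (assumed)
          /*Deterministic=*/false>{}
          .run(params, stream);
    }

    } // namespace natten::cuda::flash_fna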

csrc/include/natten/cuda/flash_fna/flash_fna_backward.cuh

Lines changed: 8 additions & 2 deletions
@@ -86,10 +86,12 @@ constexpr Config get_config() {
 }


-template <int Arch, typename Element, int HeadDim, int kBlockM, int kBlockN, bool Deterministic>
+template <int Arch, typename Element, int HeadDim, int kBlockM, int kBlockN,
+          class NADim, class QTileShape, class KVTileShape, class Causal,
+          bool Deterministic>
 struct FlashFnaBackwardKernel {

-  void run(Flash_fna_bwd_params params, cudaStream_t stream) {
+  void run(Flash_fna_bwd_params<NADim> params, cudaStream_t stream) {

     static constexpr Config config = get_config<HeadDim, Arch>();

@@ -99,6 +101,10 @@ struct FlashFnaBackwardKernel {
       /* kBlockM= */ kBlockM,
       /* kBlockN= */ kBlockN,
       /* Element= */ Element,
+      /* NADim= */ NADim,
+      /* QTileShape= */ QTileShape,
+      /* KVTileShape= */ KVTileShape,
+      /* Causal= */ Causal,
       /* Deterministic= */ Deterministic,
       /* GQA= */ false,
       /* Stages_dO= */ config.Stages_dO,
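The backward kernel is now specialized on the neighborhood-attention geometry at compile time. A minimal sketch of what the four new parameters could bind to for a 2-D problem (all concrete choices below are assumptions for illustration; the real ones come from the autogen layer):

    // Illustrative bindings only; not taken from this commit.
    using NADim       = cute::tuple<int, int>;            // runtime extent per spatial dim
    using QTileShape  = cute::Shape<cute::_8, cute::_8>;  // 8x8 = 64 queries per M-tile
    using KVTileShape = cute::Shape<cute::_8, cute::_16>; // 8x16 = 128 keys per N-tile
    using Causal      = cute::tuple<cute::false_type,     // compile-time causal flag
                                    cute::false_type>;    // per spatial dimension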

csrc/include/natten/cuda/flash_fna/flash_kernel/block.h

Lines changed: 61 additions & 0 deletions
@@ -58,6 +58,67 @@ struct NABlockMN {

     return {kv_start, kv_diff_tiles};
   }
+
+  static
+  CUTLASS_DEVICE
+  cute::tuple<NADim, NADim> get_m_block_min_max(
+      SeqlenInfo_t const& seqlen_info,
+      int const n_block, int const bidb,
+      cutlass::FastDivmod const& qhead_per_khead_divmod,
+      // NA Args
+      NADim kv_shape, NADim qkv_shape,
+      NADim window_size, NADim window_left, NADim window_right, NADim stride
+  ) {
+
+    auto stride_group_offset = get_bwd_stride_offset(stride);
+
+    auto q_tile_shape = QTileShape{};
+    auto kv_tile_shape = KVTileShape{};
+
+    auto kv_tiled = ceil_div(kv_shape, kv_tile_shape);
+
+    // Map KV index back to coord
+    auto kv_tile_coord = idx2crd(n_block, kv_tiled);
+    auto kv_coord = tuple_mul(kv_tile_coord, kv_tile_shape);
+
+    auto kv_tile_offset_last = idx2crd(size(kv_tile_shape) - 1, kv_tile_shape);
+    auto kv_coord_last = tuple_add(kv_coord, kv_tile_offset_last);
+
+    // Q start and end instead of KV, like in the forward pass
+    auto q_start_actual = get_bwd_window_start<Causal>(
+        kv_coord,
+        stride_group_offset,
+        window_left,
+        window_right,
+        window_size,
+        stride,
+        qkv_shape);
+
+    auto last_q_start_actual = get_bwd_window_start<Causal>(
+        kv_coord_last,
+        stride_group_offset,
+        window_left,
+        window_right,
+        window_size,
+        stride,
+        qkv_shape);
+    auto q_end_actual = get_bwd_window_end<Causal>(
+        kv_coord_last,
+        stride_group_offset,
+        window_left,
+        window_right,
+        window_size,
+        stride,
+        qkv_shape);
+
+    auto q_start = floor_tuple(q_start_actual, q_tile_shape);
+    auto q_end = ceil_tuple(q_end_actual, q_tile_shape);
+
+    auto q_diff = tuple_sub(q_end, q_start);
+    auto q_diff_tiles = ceil_div(q_diff, q_tile_shape);
+
+    return {q_start, q_diff_tiles};
+  }
 };

 template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool PackGQA=false>
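In words: each backward N-block (a KV tile) must visit exactly the Q tiles whose attention windows reach it, with the true start and end rounded outward to tile boundaries. A 1-D, stride-1, non-causal scalar sketch of the same computation, with illustrative names, ignoring dilation and the window recentring real neighborhood attention applies near boundaries:

    #include <algorithm>

    struct QBlockRange {
      int q_start_tile;  // first Q tile that touches this KV tile
      int num_q_tiles;   // how many Q tiles to visit
    };

    QBlockRange m_block_min_max_1d(int n_block, int kv_tile, int q_tile,
                                   int window_left, int window_right, int seqlen) {
      int kv_first = n_block * kv_tile;                            // first KV index in tile
      int kv_last  = std::min(kv_first + kv_tile - 1, seqlen - 1); // last KV index in tile
      // Backward inverts the forward window: KV position k receives contributions
      // from queries in [k - window_right, k + window_left], clamped to the sequence.
      int q_start = std::max(kv_first - window_right, 0);
      int q_end   = std::min(kv_last + window_left, seqlen - 1);
      int first_tile = q_start / q_tile;           // floor_tuple analogue
      int last_tile  = (q_end + q_tile) / q_tile;  // ceil_tuple analogue (exclusive bound)
      return {first_tile, last_tile - first_tile};
    }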
Lines changed: 234 additions & 0 deletions (new file)
@@ -0,0 +1,234 @@
+
+
+#pragma once
+
+#include <cute/tensor.hpp>
+#include "cute/util/debug.hpp"
+#include "cute/util/print.hpp"
+
+#include "cutlass/fast_math.h" // For cutlass::FastDivmod
+
+#include "utils.h"
+#include "na_utils.h"
+// #include "natten/cuda/flash_fmha/utils.h"
+
+namespace natten {
+namespace cuda {
+namespace flash_fna {
+
+using namespace cute;
+
+template <int kBlockM, int kBlockN, class NADim, class QTileShape, class KVTileShape, class Causal,
+          bool PackGQA, typename TiledMma, class IterMapType, bool SwapAB=false>
+struct BwdNAMask {
+  static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB");
+  int const thread_idx;
+  int const seqlen_q, seqlen_k;
+
+  NADim window_size;
+  NADim window_left;
+  NADim window_right;
+  NADim stride;
+  NADim qkv_shape;
+  NADim q_shape;
+  NADim kv_shape;
+  NADim q_blk_offset;
+  NADim q_diff_tiles;
+
+  bool is_fully_block_sparse;
+  bool has_q_padding;
+
+  IterMapType iter_to_tile_map;
+  cutlass::FastDivmod const qhead_per_khead_divmod;
+
+  CUTLASS_DEVICE
+  BwdNAMask(const int thread_idx, const int seqlen_q, const int seqlen_k,
+            cutlass::FastDivmod const& qhead_per_khead_divmod,
+            NADim window_size, NADim window_left, NADim window_right, NADim stride,
+            NADim qkv_shape, NADim q_shape, NADim kv_shape,
+            NADim q_blk_offset, NADim q_diff_tiles, IterMapType iter_to_tile_map,
+            bool is_fully_block_sparse, bool has_q_padding) :
+      thread_idx(thread_idx),
+      seqlen_q(seqlen_q),
+      seqlen_k(seqlen_k),
+      qhead_per_khead_divmod(qhead_per_khead_divmod),
+      window_size(window_size),
+      window_left(window_left),
+      window_right(window_right),
+      stride(stride),
+      qkv_shape(qkv_shape),
+      q_shape(q_shape),
+      kv_shape(kv_shape),
+      q_blk_offset(q_blk_offset),
+      q_diff_tiles(q_diff_tiles),
+      iter_to_tile_map(iter_to_tile_map),
+      is_fully_block_sparse(is_fully_block_sparse),
+      has_q_padding(has_q_padding) {}
+
+  template <typename Engine, typename Layout>
+  CUTLASS_DEVICE
+  void apply_na_mask(Tensor<Engine, Layout>& tSrS, const int m_block, const int n_block) {
+    auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);
+
+    Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
+    Tensor tScS = thread_mma.partition_C(cS);
+
+    Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
+    Tensor tScS_rowcol = make_tensor(tScS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));
+
+    if constexpr (!SwapAB) {
+      tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * m_block * kBlockM + E<0>{} * size(q_blk_offset);
+      tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * n_block * kBlockN;
+    }
+    else {
+      tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * m_block * kBlockM + E<1>{} * size(q_blk_offset);
+      tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * n_block * kBlockN;
+    }
+    // tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * m_block * kBlockM + E<0>{} * size(q_blk_offset);
+    // tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * n_block * kBlockN;
+
+    auto q_tile_shape = QTileShape{};
+    auto kv_tile_shape = KVTileShape{};
+
+    auto stride_group_offset = get_bwd_stride_offset(stride);
+
+    auto kv_tiled = ceil_div(kv_shape, kv_tile_shape);
+
+    auto [q_idx_first, kv_idx_first] = tScS_rowcol(0);
+    if constexpr (SwapAB) {
+      auto tmp = q_idx_first;
+      q_idx_first = kv_idx_first;
+      kv_idx_first = tmp;
+    }
+
+    // KV coord remap
+    int kv_tile_idx = kv_idx_first / size(kv_tile_shape);
+    auto kv_tile_coord = idx2crd(kv_tile_idx, kv_tiled);
+    auto kv_tile_offset = tuple_mul(kv_tile_coord, kv_tile_shape);
+    int kv_idx_first_in_tile = kv_tile_idx * size(kv_tile_shape);
+    auto kv_ctr = make_identity_tensor(kv_tile_shape);
+    auto kv_ctr_offset = domain_offset(kv_tile_offset, kv_ctr);
+
+    // Q coord remap
+    int q_tile_idx = m_block;
+    auto q_tile_coord = idx2crd(q_tile_idx, q_diff_tiles);
+    auto q_tile_offset = tuple_add(q_blk_offset, tuple_mul(q_tile_coord, q_tile_shape));
+    int q_idx_first_in_tile = (q_tile_idx * size(q_tile_shape)) + size(q_blk_offset);
+
+    auto q_ctr = make_identity_tensor(q_tile_shape);
+    auto q_ctr_offset = domain_offset(q_tile_offset, q_ctr);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < size(tSrS_rowcol); i++) {
+      auto [q_idx, kv_idx] = tScS_rowcol(i);
+
+      if constexpr (SwapAB) {
+        auto tmp = q_idx;
+        q_idx = kv_idx;
+        kv_idx = tmp;
+      }
+
+      auto q_coord = q_ctr_offset(q_idx - q_idx_first_in_tile);
+      auto kv_coord = kv_ctr_offset(kv_idx - kv_idx_first_in_tile);
+
+      auto q_start = get_bwd_window_start<Causal>(
+          kv_coord,
+          stride_group_offset,
+          window_left,
+          window_right,
+          window_size,
+          stride,
+          qkv_shape);
+      auto q_end = get_bwd_window_end<Causal>(
+          kv_coord,
+          stride_group_offset,
+          window_left,
+          window_right,
+          window_size,
+          stride,
+          qkv_shape);
+
+      bool is_neigh = is_neighbor(q_coord, q_start, q_end);
+      if (not is_neigh) {
+        tSrS_rowcol(i) = -INFINITY;
+      }
+    }
+  }
+
+  template <typename Engine, typename Layout>
+  CUTLASS_DEVICE
+  void apply_padding(Tensor<Engine, Layout>& tSrS, const int m_block, const int n_block) {
+
+    // Q coord remap
+    auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);
+
+    Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
+    Tensor tScS = thread_mma.partition_C(cS);
+
+    Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
+    Tensor tScS_rowcol = make_tensor(tScS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));
+
+    if constexpr (!SwapAB) {
+      tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * m_block * kBlockM + E<0>{} * size(q_blk_offset);
+      tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * n_block * kBlockN;
+    }
+    else {
+      tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * m_block * kBlockM + E<1>{} * size(q_blk_offset);
+      tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * n_block * kBlockN;
+    }
+
+    auto q_tile_shape = QTileShape{};
+
+    auto stride_group_offset = get_bwd_stride_offset(stride);
+
+    auto [q_idx_first, kv_idx_first] = tScS_rowcol(0);
+    if constexpr (SwapAB) {
+      auto tmp = q_idx_first;
+      q_idx_first = kv_idx_first;
+      kv_idx_first = tmp;
+    }
+
+    int q_tile_idx = m_block;
+    auto q_tile_coord = idx2crd(q_tile_idx, q_diff_tiles);
+    // auto q_tile_offset = idx2crd(q_tile_res, q_tile_shape);
+    auto q_tile_offset = tuple_add(q_blk_offset, tuple_mul(q_tile_coord, q_tile_shape));
+    int q_idx_first_in_tile = (q_tile_idx * size(q_tile_shape)) + size(q_blk_offset);
+
+    auto q_ctr = make_identity_tensor(q_tile_shape);
+    auto q_ctr_offset = domain_offset(q_tile_offset, q_ctr);
+
+    // int q_tile_idx = q_idx_first / size(q_tile_shape);
+    // int q_tile_res = q_idx_first % size(q_tile_shape);
+
+    // auto q_tile_coord = idx2crd(q_tile_idx, q_diff_tiles);
+    // auto q_tile_offset = idx2crd(q_tile_res, q_tile_shape);
+    // auto q_thread_offset = tuple_add(
+    //     q_tile_offset,
+    //     tuple_add(q_blk_offset, tuple_mul(q_tile_coord, q_tile_shape)));
+
+    // auto q_ctr = make_identity_tensor(q_tile_shape);
+    // auto q_ctr_offset = domain_offset(q_thread_offset, q_ctr);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < size(tSrS_rowcol); i++) {
+      auto [q_idx, kv_idx] = tScS_rowcol(i);
+
+      if constexpr (SwapAB) {
+        auto tmp = q_idx;
+        q_idx = kv_idx;
+        kv_idx = tmp;
+      }
+
+      auto q_coord = q_ctr_offset(q_idx - q_idx_first_in_tile);
+
+      if (not is_within_bounds(q_coord, qkv_shape)) {
+        tSrS_rowcol(i) = -INFINITY;
+      }
+    }
+  }
+
+};
+
+} // namespace flash_fna
+} // namespace cuda
+} // namespace natten
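BwdNAMask is the per-element companion to the tile-level bounds above: within each (m_block, n_block) pair it recovers multi-dimensional Q/KV coordinates for every score element and masks out non-neighbors. A 1-D scalar sketch of the predicate it evaluates, under the same simplifying assumptions as before (stride 1, non-causal, no dilation, no boundary recentring); get_bwd_window_start/end generalize this to NADim coordinates:

    // Illustrative scalar analogue of the apply_na_mask predicate.
    bool is_neighbor_1d(int q, int k, int window_left, int window_right, int seqlen) {
      // Forward: query q attends to keys in [q - window_left, q + window_right].
      // Backward flips the relation: key k is touched by queries in
      // [k - window_right, k + window_left], clamped to the sequence.
      int q_start = k - window_right > 0 ? k - window_right : 0;
      int q_end   = k + window_left < seqlen ? k + window_left : seqlen - 1;
      return q >= q_start && q <= q_end;
    }

Scores for (q, k) pairs failing this predicate are set to -INFINITY before the softmax-gradient computation, exactly as tSrS_rowcol(i) = -INFINITY does above.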

csrc/include/natten/cuda/flash_fna/flash_kernel/flash_bwd_kernel_sm80.h

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 #pragma once

 #include "cute/tensor.hpp"
+#include "cute/util/debug.hpp"

 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>

csrc/include/natten/cuda/flash_fna/flash_kernel/flash_bwd_launch_template.h

Lines changed: 11 additions & 3 deletions
@@ -38,7 +38,7 @@ namespace flash_fna {
 using namespace cute;

 template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
-          class NADim,
+          class NADim, class QTileShape, class KVTileShape, class Causal,
           bool Deterministic, bool GQA,
           int Stages_dO=2, int Stages_dS_or_QSm80=2,
           bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,

@@ -104,7 +104,7 @@ void run_flash_bwd(Flash_fna_bwd_params<NADim> &params, cudaStream_t stream) {
 //     SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
 // >;
 using CollectiveMainloop = flash_fna::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
-    Deterministic,
+    Deterministic, NADim, QTileShape, KVTileShape, Causal,
     SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>;
 using CollectiveEpilogue = std::conditional_t<
     !GQA,
148148
{_1{}, seqlen_q_rounded, params.h * params.seqlen_q_rounded}, // stride_dPsum
149149
params.scale_softmax,
150150
params.b,
151-
params.dq_semaphore
151+
params.dq_semaphore,
152+
// NA Args
153+
params.qkv_shape,
154+
params.q_shape,
155+
params.kv_shape,
156+
params.window_size,
157+
params.stride,
158+
params.dilation,
159+
params.num_heads_actual
152160
};
153161
// The case work with GQA is ugly but idk how to fix it.
154162
typename CollectiveEpilogue::Arguments epilogue_args {
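The mainloop arguments now carry the neighborhood-attention geometry down from the params struct. A sketch of the NA-specific fields Flash_fna_bwd_params<NADim> presumably exposes after this change (the field names match the argument list above; the struct layout itself is an assumption):

    // Hypothetical view of the NA fields, for illustration only.
    template <class NADim>
    struct Flash_fna_bwd_params_na_fields {
      // ... the usual FMHA backward fields (pointers, strides, seqlens) ...
      NADim qkv_shape;    // full spatial extent of the token layout
      NADim q_shape;      // tiled/padded Q extent
      NADim kv_shape;     // tiled/padded KV extent
      NADim window_size;  // neighborhood window per dimension
      NADim stride;       // attention stride per dimension
      NADim dilation;     // dilation per dimension
      int num_heads_actual;
    };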
