Skip to content

Commit 290eeab

Browse files
Rebased and changed calls as needed
1 parent e07e564 commit 290eeab

File tree

14 files changed

+172
-150
lines changed

14 files changed

+172
-150
lines changed

csrc/include/natten/cuda/flash_fna/flash_kernel/flash.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct NA_params {
2222
NADim stride;
2323
NADim dilation;
2424

25-
int num_heads_actual;
25+
int batch_size_actual;
2626

2727
};
2828

csrc/include/natten/cuda/flash_fna/flash_kernel/flash_bwd_launch_template.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ void run_flash_bwd(Flash_fna_bwd_params<NADim> &params, cudaStream_t stream) {
156156
params.window_size,
157157
params.stride,
158158
params.dilation,
159-
params.num_heads_actual
159+
params.batch_size_actual
160160
};
161161
// The case work with GQA is ugly but idk how to fix it.
162162
typename CollectiveEpilogue::Arguments epilogue_args {

csrc/include/natten/cuda/flash_fna/flash_kernel/flash_fwd_launch_template.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ void run_flash_fwd(Flash_fna_fwd_params<NADim> &params, cudaStream_t stream) {
9393
{params.v_descale_batch_stride, params.v_descale_head_stride},
9494
params.num_splits,
9595
// NA Args
96-
params.qkv_shape, params.q_shape, params.kv_shape, params.window_size, params.stride, params.dilation, params.num_heads_actual
96+
params.qkv_shape, params.q_shape, params.kv_shape, params.window_size, params.stride, params.dilation, params.batch_size_actual
9797
};
9898
typename CollectiveEpilogue::Arguments epilogue_args {
9999
static_cast<ElementOut*>(params.o_ptr),

csrc/include/natten/cuda/flash_fna/flash_kernel/mainloop_bwd_sm80.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ struct CollectiveMainloopBwdSm80 {
323323
NADim window_size;
324324
NADim stride;
325325
NADim dilation;
326-
int num_heads_actual;
326+
int batch_size_actual;
327327
};
328328

329329
// Device side kernel params
@@ -365,7 +365,7 @@ struct CollectiveMainloopBwdSm80 {
365365
NADim window_right;
366366
NADim stride;
367367
NADim dilation;
368-
int num_heads_actual;
368+
int batch_size_actual;
369369
bool is_fully_block_sparse;
370370
bool has_q_padding;
371371
bool requires_qkv_fixup;
@@ -420,7 +420,7 @@ struct CollectiveMainloopBwdSm80 {
420420
// !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
421421
args.num_batch, args.dq_semaphore,
422422
args.qkv_shape, args.q_shape, args.kv_shape, args.window_size, window_left, window_right,
423-
args.stride, args.dilation, args.num_heads_actual,
423+
args.stride, args.dilation, args.batch_size_actual,
424424
is_fully_block_sparse_, has_q_padding_, requires_qkv_fixup_, is_dilated_
425425
// args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k
426426
};
@@ -445,13 +445,12 @@ struct CollectiveMainloopBwdSm80 {
445445
// params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
446446
};
447447

448-
int head_idx = bidh;
449448
auto qkv_shape = params.qkv_shape;
450449
bool is_fully_block_sparse = params.is_fully_block_sparse;
451450
bool has_q_padding = params.has_q_padding;
452451

453452
if (params.requires_qkv_fixup) {
454-
qkv_shape = correct_qkv_shape(params.qkv_shape, head_idx, params.dilation, params.num_heads_actual);
453+
qkv_shape = correct_qkv_shape(params.qkv_shape, bidb, params.dilation, params.batch_size_actual);
455454
is_fully_block_sparse = fully_block_sparse<Causal>(
456455
qkv_shape,
457456
params.window_size,

csrc/include/natten/cuda/flash_fna/flash_kernel/mainloop_fwd_sm80.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ struct CollectiveMainloopFwdSm80 {
196196
NADim window_size;
197197
NADim stride;
198198
NADim dilation;
199-
int num_heads_actual;
199+
int batch_size_actual;
200200
};
201201

202202
// Device side kernel params
@@ -227,7 +227,7 @@ struct CollectiveMainloopFwdSm80 {
227227
NADim window_right;
228228
NADim stride;
229229
NADim dilation;
230-
int num_heads_actual;
230+
int batch_size_actual;
231231
bool is_fully_block_sparse;
232232
bool has_kv_padding;
233233
bool requires_qkv_fixup;
@@ -277,7 +277,7 @@ struct CollectiveMainloopFwdSm80 {
277277
args.stride_q_descale, args.stride_k_descale, args.stride_v_descale,
278278
1 /* args.num_splits */,
279279
args.qkv_shape, args.q_shape, args.kv_shape, args.window_size, window_left, window_right,
280-
args.stride, args.dilation, args.num_heads_actual,
280+
args.stride, args.dilation, args.batch_size_actual,
281281
is_fully_block_sparse, has_kv_padding, requires_qkv_fixup, is_dilated_
282282
};
283283
}
@@ -304,12 +304,11 @@ struct CollectiveMainloopFwdSm80 {
304304
int const split_idx = get<3>(block_coord);
305305
int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
306306

307-
int head_idx = bidh;
308307
auto qkv_shape = params.qkv_shape;
309308
bool is_fully_block_sparse = params.is_fully_block_sparse;
310309
bool has_kv_padding = params.has_kv_padding;
311310
if (params.requires_qkv_fixup) {
312-
qkv_shape = correct_qkv_shape(params.qkv_shape, head_idx, params.dilation, params.num_heads_actual);
311+
qkv_shape = correct_qkv_shape(params.qkv_shape, bidb, params.dilation, params.batch_size_actual);
313312
is_fully_block_sparse = fully_block_sparse<Causal>(
314313
qkv_shape,
315314
params.window_size,

csrc/include/natten/cuda/flash_fna/flash_kernel/na_utils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,11 @@ CUTLASS_DEVICE auto correct_qkv_shape(
434434
NADim const& qkv_shape, // this is pre-padding, pre-token permute, just
435435
// the original shape of the sequence mode in
436436
// the self attention
437-
int head_idx,
437+
int batch_idx,
438438
NADim const& dilation,
439-
int num_heads_actual) {
439+
int batch_size_actual) {
440440

441-
auto dilation_group_idx = head_idx / num_heads_actual;
441+
auto dilation_group_idx = batch_idx % product(dilation);
442442
auto dilation_group_crd = idx2crd(dilation_group_idx, dilation);
443443

444444
return correct_qkv_shape_wrt_dilation(

csrc/include/natten/cuda/flash_fna/flash_kernel/param_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Flash_fna_fwd_params<NADim> set_flash_fna_fwd_params(
8686
params.window_size = window_size;
8787
params.stride = stride;
8888
params.dilation = dilation;
89-
params.num_heads_actual = H / product(dilation);
89+
params.batch_size_actual = B / product(dilation);
9090

9191
params.is_bf16 = query.scalar_type() == torch::kBFloat16;
9292
params.is_e4m3 = false;
@@ -283,7 +283,7 @@ Flash_fna_bwd_params<NADim> set_flash_fna_bwd_params(
283283
params.window_size = window_size;
284284
params.stride = stride;
285285
params.dilation = dilation;
286-
params.num_heads_actual = H / product(dilation);
286+
params.batch_size_actual = B / product(dilation);
287287

288288
params.q_ptr = static_cast<void*>(query.data_ptr());
289289
params.k_ptr = static_cast<void*>(key.data_ptr());

setup.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,9 @@
6363
AUTOGEN_POLICY = AUTOGEN_POLICY if AUTOGEN_POLICY != "" else "default"
6464

6565
tmp_dir = tempfile.TemporaryDirectory()
66-
print(f"***************** {tmp_dir=}")
6766
NATTEN_BUILD_DIR = os.getenv("NATTEN_BUILD_DIR", tmp_dir.name)
6867
if not os.path.isdir(NATTEN_BUILD_DIR):
6968
NATTEN_BUILD_DIR = tmp_dir.name
70-
print(f"***************** {NATTEN_BUILD_DIR=}")
7169

7270
DEFAULT_N_WORKERS = max(1, (multiprocessing.cpu_count() // 4))
7371
try:

src/natten/_libnatten/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@
122122
"blackwell_na3d_backward",
123123
"blackwell_na3d_forward",
124124
"compute_delta",
125+
"flash_fmha_backward",
126+
"flash_fmha_forward",
125127
"fmha_backward",
126128
"fmha_forward",
127129
"hopper_fmha_backward",
@@ -132,6 +134,12 @@
132134
"hopper_na2d_forward",
133135
"hopper_na3d_backward",
134136
"hopper_na3d_forward",
137+
"flash_na1d_backward",
138+
"flash_na1d_forward",
139+
"flash_na2d_backward",
140+
"flash_na2d_forward",
141+
"flash_na3d_backward",
142+
"flash_na3d_forward",
135143
"na1d_backward",
136144
"na1d_forward",
137145
"na2d_backward",

0 commit comments

Comments
 (0)