Commit 6317ce0

Deduplicated in Flash FNA
1 parent: 42addc8

28 files changed, +145 -4304 lines changed

csrc/include/natten/cuda/flash_fmha/flash_kernel/block.h

Lines changed: 0 additions & 3 deletions
@@ -22,8 +22,6 @@ struct BlockMN {
     int const seqlen_q = seqlen_info.seqlen_q;
     int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
     int n_block_min = 0;
-    // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
-    // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
     return {n_block_min, n_block_max};
   }

@@ -40,7 +38,6 @@ struct BlockMN {
     int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
     int const n_block_new_min = idx_k_new_min / kBlockN;
     int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
-    // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
     return {n_block_new_min, n_block_new_max};
   }
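The retained lines in the first hunk also show how the K-block range is derived: n_block_max is the ceiling of seqlen_k over the tile size kBlockN, with n_block_min starting at 0. A standalone restatement of that arithmetic in plain C++ (illustrative only, with made-up sizes; not the kernel code):

#include <cstdio>

// ceil_div as used via cute::ceil_div: smallest number of blocks covering n.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  int const seqlen_k = 1000;   // hypothetical key sequence length
  int const kBlockN  = 128;    // hypothetical KV tile size
  int const n_block_min = 0;
  int const n_block_max = ceil_div(seqlen_k, kBlockN);  // = 8 blocks cover 1000 keys
  std::printf("K blocks: [%d, %d)\n", n_block_min, n_block_max);
  return 0;
}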

csrc/include/natten/cuda/flash_fmha/flash_kernel/epilogue_bwd.hpp

Lines changed: 0 additions & 9 deletions
@@ -9,14 +9,10 @@
 #include "cute/tensor.hpp"

 #include "cutlass/gemm/collective/builders/sm90_common.inl"
-// #include "copy_sm90_bulk_reduce.hpp"

 #include "seqlen.h"
 #include "named_barrier.hpp"
 #include "utils.h"
-// #include "natten/cuda/flash_fmha/seqlen.h"
-// #include "natten/cuda/flash_fmha/named_barrier.hpp"
-// #include "natten/cuda/flash_fmha/utils.h"

 namespace natten {
 namespace cuda {

@@ -191,7 +187,6 @@ struct CollectiveEpilogueBwd {
     flash::convert_type_out(tdKrdK, tdKrdK_out);
     Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
-    // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
     Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)

@@ -461,12 +456,10 @@ struct CollectiveEpilogueBwdGQA {
     int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
     using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;

-    // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}

     if constexpr (Deterministic) {
       Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
     }
-    // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
     // if constexpr (Use_TMA) {
     //   cutlass::arch::fence_view_async_shared();
     //   cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

@@ -492,12 +485,10 @@
       cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
     }
     lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
-    // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}

     if constexpr (Deterministic) {
       Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
     }
-    // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
     // if constexpr (Use_TMA) {
     //   cutlass::arch::fence_view_async_shared();
     //   cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
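The lines kept in the GQA epilogue show the deterministic path: before accumulating its dK or dV partial, each block waits on a per-(batch, KV-head) semaphore so contributions land in a fixed order. A generic sketch of that ordering idea with plain CUDA atomics, not the cutlass::GenericBarrier used above (all names below are illustrative, not from this commit):

#include <cuda_runtime.h>

// Sketch: serialize per-block accumulation so floating-point adds happen in a
// fixed order (deterministic backward). `semaphore` starts at 0; a block may
// only accumulate once the counter reaches its assigned turn `my_turn`.
__device__ void accumulate_in_order(float* dst, float const* partial, int n,
                                    int* semaphore, int my_turn) {
  if (threadIdx.x == 0) {
    while (atomicAdd(semaphore, 0) != my_turn) { /* spin until it is our turn */ }
  }
  __syncthreads();
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    dst[i] += partial[i];
  }
  __threadfence();            // make the accumulated values visible device-wide
  __syncthreads();
  if (threadIdx.x == 0) {
    atomicAdd(semaphore, 1);  // release the next block in the fixed order
  }
}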

csrc/include/natten/cuda/flash_fmha/flash_kernel/epilogue_fwd.hpp

Lines changed: 0 additions & 2 deletions
@@ -289,7 +289,6 @@ struct CollectiveEpilogueFwd {
     Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)),
                               params.shape_LSE_packed,
                               params.stride_LSE_packed)(_, bidh, bidb, 0);
-    // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
     if (!LargeHeadDimV || warp_group_idx == 0) {
       if constexpr (!PackGQA) {
         #pragma unroll

@@ -327,7 +326,6 @@
     // if (!is_split) {
       Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, bidb, _0{});
       Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)
-      // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
       GmemTiledCopyO gmem_tiled_copy_O;
       auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
       Tensor tOsO = gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)

csrc/include/natten/cuda/flash_fmha/flash_kernel/flash.h

Lines changed: 0 additions & 50 deletions
@@ -72,17 +72,6 @@ struct Flash_fwd_params : public Qkv_params {

   // The scaling factors for the kernel.
   float scale_softmax;
-  // float softcap;
-
-  // array of length b+1 holding starting offset of each sequence.
-  // int * __restrict__ cu_seqlens_q;
-  // int * __restrict__ cu_seqlens_k;
-  // int * __restrict__ cu_seqlens_knew;
-  // int * __restrict__ leftpad_k;
-
-  // If provided, the actual length of each q/k sequence.
-  // int *__restrict__ seqused_q;
-  // int *__restrict__ seqused_k;

   // The stride between rows of Oaccum.
   index_t oaccum_split_stride;

@@ -95,38 +84,6 @@
   index_t lseaccum_batch_stride;
   index_t lseaccum_head_stride;

-  // The K_new and V_new matrices.
-  // void * __restrict__ knew_ptr;
-  // void * __restrict__ vnew_ptr;
-
-  // The stride between rows of the Q, K and V matrices.
-  // index_t knew_batch_stride;
-  // index_t vnew_batch_stride;
-  // index_t knew_row_stride;
-  // index_t vnew_row_stride;
-  // index_t knew_head_stride;
-  // index_t vnew_head_stride;
-
-  // void *__restrict__ qv_ptr;
-  // index_t qv_batch_stride;
-  // index_t qv_row_stride;
-  // index_t qv_head_stride;
-
-  // The cos and sin matrices for rotary embedding.
-  // void * __restrict__ rotary_cos_ptr;
-  // void * __restrict__ rotary_sin_ptr;
-  // int *__restrict__ seqlens_rotary;
-
-  // The indices to index into the KV cache.
-  // int * __restrict__ kv_batch_idx;
-
-  // Paged KV cache
-  // int * __restrict__ page_table;
-  // index_t page_table_batch_stride;
-  // int page_size;
-  // int num_pages;
-  // bool pagedkv_tma;
-
   // The dropout probability (probability of keeping an activation).
   float p_dropout;
   // uint32_t p_dropout_in_uint;

@@ -221,10 +178,3 @@ struct Flash_bwd_params : public Flash_fwd_params {
 } // namespace natten
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-// template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
-// void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-// void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
-// template <int Arch, typename T, int kHeadDim, bool Has_softcap>
-// void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
-// template <typename T, typename Tpartial, int kBlockK>
-// void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);

csrc/include/natten/cuda/flash_fmha/flash_kernel/flash_bwd_launch_template.h

Lines changed: 0 additions & 1 deletion
@@ -221,7 +221,6 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
   // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
   // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
   // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
-  // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
   void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
   if constexpr (size(ClusterShape{}) > 1) {
     if (smem_size >= 48 * 1024) {
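The hunk cuts off right at the shared-memory check; in FlashAttention-style launchers such a check typically leads into opting the kernel into more than 48 KB of dynamic shared memory. A minimal host-side sketch of that standard CUDA call (illustrative, not taken from this commit; kernel and smem_size stand in for the launcher's values):

#include <cuda_runtime.h>

// Kernels requesting more than the default 48 KB of dynamic shared memory
// must opt in explicitly, once per kernel, before launch.
inline cudaError_t raise_smem_limit(void const* kernel, int smem_size) {
  if (smem_size >= 48 * 1024) {
    return cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }
  return cudaSuccess;
}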

csrc/include/natten/cuda/flash_fna/flash_kernel/bwd_mask.h

Lines changed: 5 additions & 6 deletions
@@ -8,9 +8,8 @@

 #include "cutlass/fast_math.h" // For cutlass::FastDivmod

-#include "utils.h"
 #include "na_utils.h"
-// #include "natten/cuda/flash_fmha/utils.h"
+#include "natten/cuda/flash_fmha/flash_kernel/utils.h"

 namespace natten {
 namespace cuda {

@@ -73,8 +72,8 @@ struct BwdNAMask {
     Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
     Tensor tScS = thread_mma.partition_C(cS);

-    Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
-    Tensor tScS_rowcol = make_tensor(tScS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));
+    Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
+    Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));

     if constexpr (!SwapAB) {
       tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * m_block * kBlockM + E<0>{} * size(q_blk_offset);

@@ -165,8 +164,8 @@ struct BwdNAMask {
     Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
     Tensor tScS = thread_mma.partition_C(cS);

-    Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
-    Tensor tScS_rowcol = make_tensor(tScS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));
+    Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
+    Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));

     if constexpr (!SwapAB) {
       tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * m_block * kBlockM + E<0>{} * size(q_blk_offset);
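bwd_mask.h is the only file in this excerpt with additions, and it shows the deduplication directly: the FNA backward mask drops its local utils.h copy, includes the shared flash_fmha kernel header, and calls the layout helper through the flash namespace. A compressed sketch of that pattern, mirroring the lines above (the wrapper name acc_to_rowcol and the surrounding tensor definitions are illustrative, not part of the commit):

// After this commit, the FNA backward mask reuses the helper that already
// lives in the flash_fmha kernel headers instead of a duplicated local copy.
#include "natten/cuda/flash_fmha/flash_kernel/utils.h"  // provides flash::convert_layout_acc_rowcol

template <bool SwapAB, typename AccTensor>
__device__ auto acc_to_rowcol(AccTensor const& tSrS) {
  // Reinterpret the MMA accumulator layout as (row, col) so the neighborhood
  // attention mask can be applied element by element, as in BwdNAMask above.
  return cute::make_tensor(tSrS.data(),
                           flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
}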

csrc/include/natten/cuda/flash_fna/flash_kernel/cuda_check.h

Lines changed: 0 additions & 26 deletions
This file was deleted.
