#pragma once

#include <cute/tensor.hpp>
#include <cute/util/debug.hpp>
#include <cute/util/print.hpp>

#include "cutlass/fast_math.h" // For cutlass::FastDivmod

#include "utils.h"
#include "na_utils.h"

namespace natten {
namespace cuda {
namespace flash_fna {

using namespace cute;

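// Applies the neighborhood attention (NA) mask to attention score tiles in the
// backward pass. Accumulator coordinates are mapped back to multi-dimensional
// Q/KV token coordinates, and scores whose Q coordinate falls outside the NA
// window of the corresponding KV coordinate are masked with -inf. When SwapAB
// is set, the row (M) and column (N) modes of the accumulator trade roles.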
template<int kBlockM, int kBlockN, class NADim, class QTileShape, class KVTileShape, class Causal,
         bool PackGQA, typename TiledMma, class IterMapType, bool SwapAB=false>
struct BwdNAMask {
    static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB");
    int const thread_idx;
    int const seqlen_q, seqlen_k;

    // Neighborhood attention geometry (all multi-dimensional, of type NADim).
    NADim window_size;
    NADim window_left;
    NADim window_right;
    NADim stride;
    NADim qkv_shape;
    NADim q_shape;
    NADim kv_shape;
    NADim q_blk_offset;   // token-space offset of this kernel's first Q tile
    NADim q_diff_tiles;   // shape of the Q tile grid (decodes m_block into a tile coord)

    bool is_fully_block_sparse;
    bool has_q_padding;

    IterMapType iter_to_tile_map;
    cutlass::FastDivmod const qhead_per_khead_divmod;

    CUTLASS_DEVICE
    BwdNAMask(const int thread_idx, const int seqlen_q, const int seqlen_k,
              cutlass::FastDivmod const &qhead_per_khead_divmod,
              NADim window_size, NADim window_left, NADim window_right, NADim stride,
              NADim qkv_shape, NADim q_shape, NADim kv_shape,
              NADim q_blk_offset, NADim q_diff_tiles, IterMapType iter_to_tile_map,
              bool is_fully_block_sparse, bool has_q_padding):
      // Initializers follow member declaration order to avoid -Wreorder warnings.
      thread_idx(thread_idx),
      seqlen_q(seqlen_q),
      seqlen_k(seqlen_k),
      window_size(window_size),
      window_left(window_left),
      window_right(window_right),
      stride(stride),
      qkv_shape(qkv_shape),
      q_shape(q_shape),
      kv_shape(kv_shape),
      q_blk_offset(q_blk_offset),
      q_diff_tiles(q_diff_tiles),
      is_fully_block_sparse(is_fully_block_sparse),
      has_q_padding(has_q_padding),
      iter_to_tile_map(iter_to_tile_map),
      qhead_per_khead_divmod(qhead_per_khead_divmod) {}

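    // Masks out score entries in tSrS whose (Q, KV) token pair is not a
    // neighborhood-attention neighbor. m_block / n_block select the Q / KV
    // tile; thread coordinates are shifted into global index space first.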
    template <typename Engine, typename Layout>
    CUTLASS_DEVICE
    void apply_na_mask(Tensor<Engine, Layout> &tSrS, const int m_block, const int n_block) {
      auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);

      Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
      Tensor tScS = thread_mma.partition_C(cS);

      Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
      Tensor tScS_rowcol = make_tensor(tScS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));

      // Shift thread-local coordinates into global (Q, KV) index space.
      if constexpr (!SwapAB) {
        tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * m_block * kBlockM + E<0>{} * size(q_blk_offset);
        tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * n_block * kBlockN;
      }
      else {
        tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * m_block * kBlockM + E<1>{} * size(q_blk_offset);
        tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * n_block * kBlockN;
      }

      auto q_tile_shape = QTileShape{};
      auto kv_tile_shape = KVTileShape{};

      auto stride_group_offset = get_bwd_stride_offset(stride);

      auto kv_tiled = ceil_div(kv_shape, kv_tile_shape);

      // First linear KV index owned by this thread (column mode, or row mode
      // when SwapAB).
      int kv_idx_first = !SwapAB ? get<1>(tScS_rowcol(0)) : get<0>(tScS_rowcol(0));

      // KV coord remap: recover the multi-dimensional KV token coordinate
      // from the first linear KV index this thread owns.
      int kv_tile_idx = kv_idx_first / size(kv_tile_shape);
      auto kv_tile_coord = idx2crd(kv_tile_idx, kv_tiled);
      auto kv_tile_offset = tuple_mul(kv_tile_coord, kv_tile_shape);
      int kv_idx_first_in_tile = kv_tile_idx * size(kv_tile_shape);
      auto kv_ctr = make_identity_tensor(kv_tile_shape);
      auto kv_ctr_offset = domain_offset(kv_tile_offset, kv_ctr);

      // Q coord remap: same idea, but the Q tile index is given directly by
      // m_block, shifted by this kernel's Q block offset.
      int q_tile_idx = m_block;
      auto q_tile_coord = idx2crd(q_tile_idx, q_diff_tiles);
      auto q_tile_offset = tuple_add(q_blk_offset, tuple_mul(q_tile_coord, q_tile_shape));
      int q_idx_first_in_tile = (q_tile_idx * size(q_tile_shape)) + size(q_blk_offset);

      auto q_ctr = make_identity_tensor(q_tile_shape);
      auto q_ctr_offset = domain_offset(q_tile_offset, q_ctr);

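      // For each score element: recover multi-dim (Q, KV) coordinates, compute
      // the range of Q tokens that attend to this KV token in the backward
      // pass, and mask the element if its Q coordinate falls outside it.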
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(tSrS_rowcol); i++) {
        auto [q_idx, kv_idx] = tScS_rowcol(i);

        if constexpr (SwapAB) {
          auto tmp = q_idx;
          q_idx = kv_idx;
          kv_idx = tmp;
        }

        auto q_coord = q_ctr_offset(q_idx - q_idx_first_in_tile);
        auto kv_coord = kv_ctr_offset(kv_idx - kv_idx_first_in_tile);

        auto q_start = get_bwd_window_start<Causal>(
            kv_coord,
            stride_group_offset,
            window_left,
            window_right,
            window_size,
            stride,
            qkv_shape);
        auto q_end = get_bwd_window_end<Causal>(
            kv_coord,
            stride_group_offset,
            window_left,
            window_right,
            window_size,
            stride,
            qkv_shape);

        if (not is_neighbor(q_coord, q_start, q_end)) {
          tSrS_rowcol(i) = -INFINITY;
        }
      }
    }

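    // Masks out score entries whose Q token coordinate falls outside the
    // actual input shape (qkv_shape). Only needed when the Q extent is padded
    // up to a multiple of the tile shape (has_q_padding).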
    template <typename Engine, typename Layout>
    CUTLASS_DEVICE
    void apply_padding(Tensor<Engine, Layout> &tSrS, const int m_block, const int n_block) {
      // Same accumulator-to-(Q, KV) coordinate setup as apply_na_mask.
      auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);

      Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
      Tensor tScS = thread_mma.partition_C(cS);

      Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
      Tensor tScS_rowcol = make_tensor(tScS.data(), flash_fna::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));

      if constexpr (!SwapAB) {
        tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * m_block * kBlockM + E<0>{} * size(q_blk_offset);
        tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * n_block * kBlockN;
      }
      else {
        tScS_rowcol.data() = tScS_rowcol.data() + E<1>{} * m_block * kBlockM + E<1>{} * size(q_blk_offset);
        tScS_rowcol.data() = tScS_rowcol.data() + E<0>{} * n_block * kBlockN;
      }

      auto q_tile_shape = QTileShape{};

      // Q coord remap: recover the multi-dimensional Q token coordinate, as
      // in apply_na_mask. KV coordinates and the NA window are not needed
      // here; only Q bounds are checked.
      int q_tile_idx = m_block;
      auto q_tile_coord = idx2crd(q_tile_idx, q_diff_tiles);
      auto q_tile_offset = tuple_add(q_blk_offset, tuple_mul(q_tile_coord, q_tile_shape));
      int q_idx_first_in_tile = (q_tile_idx * size(q_tile_shape)) + size(q_blk_offset);

      auto q_ctr = make_identity_tensor(q_tile_shape);
      auto q_ctr_offset = domain_offset(q_tile_offset, q_ctr);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(tSrS_rowcol); i++) {
        auto [q_idx, kv_idx] = tScS_rowcol(i);

        if constexpr (SwapAB) {
          auto tmp = q_idx;
          q_idx = kv_idx;
          kv_idx = tmp;
        }

        auto q_coord = q_ctr_offset(q_idx - q_idx_first_in_tile);

        // Padded Q rows lie outside the input shape; mask their scores.
        if (not is_within_bounds(q_coord, qkv_shape)) {
          tSrS_rowcol(i) = -INFINITY;
        }
      }
    }

};
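
// Illustrative usage (a sketch only; the surrounding kernel code and variable
// names here are hypothetical, not part of this header): a backward kernel
// would construct the mask once per thread and apply it to each S = Q K^T
// accumulator tile before the softmax-gradient step.
//
//   BwdNAMask<kBlockM, kBlockN, NADim, QTileShape, KVTileShape, Causal,
//             /*PackGQA=*/false, TiledMma, IterMapType> mask(
//       thread_idx, seqlen_q, seqlen_k, qhead_per_khead_divmod,
//       window_size, window_left, window_right, stride,
//       qkv_shape, q_shape, kv_shape, q_blk_offset, q_diff_tiles,
//       iter_to_tile_map, is_fully_block_sparse, has_q_padding);
//
//   if (!mask.is_fully_block_sparse) { mask.apply_na_mask(tSrS, m_block, n_block); }
//   if (mask.has_q_padding)          { mask.apply_padding(tSrS, m_block, n_block); }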

} // namespace flash_fna
} // namespace cuda
} // namespace natten