 #include <stdint.h>
 #include "api/dataflow/dataflow_api.h"
 #include <vector>
+#include "api/debug/assert.h"

 #include "ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp"
 #include "dataflow_common.hpp"
@@ -97,7 +98,6 @@ void kernel_main() {
         noc_async_read(tensor_index_noc_addr, index_cb_wr_ptr, index_stick_size_B);
         noc_async_read_barrier();
     }
-
     cb_push_back(cb_index_id, 1);
     volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(index_cb_wr_ptr);
     cur_pos = index_ptr[cur_batch / q_heads_parallel_factor];
@@ -113,13 +113,14 @@ void kernel_main() {
     auto k_chunk_size_dynamic = Sk_chunk_t_dynamic * tt::constants::TILE_HEIGHT;

     // Sequence length assignment
-    auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk] = get_runtime_args(
-        cur_pos,
-        cur_batch,
-        core_num_in_reduce,
-        num_cores_per_head,
-        k_chunk_size_dynamic,
-        sliding_window_size > 0 ? std::optional<uint32_t>(sliding_window_size) : std::nullopt);
+    auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk] =
+        get_workload_for_core(
+            cur_pos,
+            cur_batch,
+            core_num_in_reduce,
+            num_cores_per_head,
+            k_chunk_size_dynamic,
+            sliding_window_size > 0 ? std::optional<uint32_t>(sliding_window_size) : std::nullopt);

     if (k_chunk_start == k_chunk_end) {
         return;  // early exit because no compute needs to be done
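Note on the renamed helper: this diff only shows the call site, so the semantics of get_workload_for_core (declared in rt_args_common.hpp) have to be read off the unpacked tuple. The sketch below is a plausible reconstruction, not the actual implementation: the real helper may pad, clamp, and balance chunks differently, cur_batch is accepted only for signature parity, and <tuple> is assumed available alongside the <optional> the call site already uses.

// Hypothetical sketch, not the implementation in rt_args_common.hpp.
std::tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t> get_workload_for_core(
    uint32_t cur_pos,
    uint32_t cur_batch,  // unused in this simplified sketch
    uint32_t core_num_in_reduce,
    uint32_t num_cores_per_head,
    uint32_t k_chunk_size,
    std::optional<uint32_t> sliding_window_size) {
    // Pad the current position (inclusive) up to a whole number of K chunks.
    const uint32_t k_num_chunks = (cur_pos + 1 + k_chunk_size - 1) / k_chunk_size;
    const uint32_t PSt = k_num_chunks * (k_chunk_size / tt::constants::TILE_HEIGHT);
    // With a sliding window, chunks entirely before the window start are skipped.
    uint32_t window_start_unaligned = 0;
    uint32_t window_start_chunk = 0;
    if (sliding_window_size.has_value() && cur_pos + 1 > *sliding_window_size) {
        window_start_unaligned = cur_pos + 1 - *sliding_window_size;
        window_start_chunk = window_start_unaligned / k_chunk_size;
    }
    // Split the active chunks evenly across the cores reducing this head.
    const uint32_t active_chunks = k_num_chunks - window_start_chunk;
    const uint32_t chunks_per_core = (active_chunks + num_cores_per_head - 1) / num_cores_per_head;
    uint32_t k_chunk_start = window_start_chunk + core_num_in_reduce * chunks_per_core;
    uint32_t k_chunk_end = k_chunk_start + chunks_per_core;
    if (k_chunk_start > k_num_chunks) { k_chunk_start = k_num_chunks; }
    if (k_chunk_end > k_num_chunks) { k_chunk_end = k_num_chunks; }
    return {PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk};
}

The clamping at the end is what would make the early exit above fire: a core whose start chunk lands past the padded sequence gets k_chunk_start == k_chunk_end and returns immediately.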
@@ -154,7 +155,9 @@ void kernel_main() {
     constexpr uint32_t barrier_threshold = get_barrier_read_threshold<q_tile_bytes, num_cores>();
     uint32_t barrier_count = 0;

-    // First, read Q entirely, it could be interleaved or sharded
+    // Read Q entirely - always read into cb_q_in
+    // When tilize_q is true, compute will tilize in-place back to cb_q_in
+    // When tilize_q is false, Q is already tilized
     uint32_t q_batch_offset = cur_batch * q_chunk_tiles;

     if constexpr (is_q_sharded) {
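For context on the comment above, here is a minimal sketch of what the interleaved (non-sharded) branch of this Q read plausibly looks like, reusing the CB and batched-barrier idioms visible elsewhere in this kernel; cb_q_in, q_reader, q_tile_bytes, and the tile numbering are assumptions rather than lines from this file.

// Hypothetical sketch of the interleaved Q read; not copied from the kernel.
cb_reserve_back(cb_q_in, q_chunk_tiles);
uint32_t q_write_ptr = get_write_ptr(cb_q_in);
barrier_count = 0;
for (uint32_t tile = 0; tile < q_chunk_tiles; ++tile) {
    // Q for this batch starts at q_batch_offset; tiles are read back-to-back.
    noc_async_read_tile(q_batch_offset + tile, q_reader, q_write_ptr);
    q_write_ptr += q_tile_bytes;
    if (++barrier_count == barrier_threshold) {
        noc_async_read_barrier();  // cap the number of NoC reads in flight
        barrier_count = 0;
    }
}
noc_async_read_barrier();
cb_push_back(cb_q_in, q_chunk_tiles);  // compute tilizes in place when tilize_q is set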
@@ -273,33 +276,22 @@ void kernel_main() {
         {
             // Read K chunk in row-major order (to simplify page mapping). Write tiles to CB in transposed
             // order.
-            cb_reserve_back(cb_k_in, k_chunk_tiles);
-            uint32_t k_write_ptr = get_write_ptr(cb_k_in);
-            k_base_read_ptr = get_noc_addr(k_write_ptr);
-            barrier_count = 0;
-            for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
-                uint32_t k_write_ptr_col = k_write_ptr + row * k_tile_bytes;
-                uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row;
-
-                uint32_t physical_k_tile_id =
-                    (is_page_table_sharded)
-                        ? virtual_seq_tile_id_to_physical_tile_id<uint16_t, num_kv_heads, block_size_t, DHt>(
-                              virtual_k_tile_row_num, cur_head, page_table_ptr_u16)
-                        : virtual_seq_tile_id_to_physical_tile_id<uint32_t, num_kv_heads, block_size_t, DHt>(
-                              virtual_k_tile_row_num, cur_head, page_table_ptr_u32);
-                for (uint32_t col = 0; col < DHt; ++col) {
-                    noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col);
-                    physical_k_tile_id += 1;                               // Go to next tile in row
-                    k_write_ptr_col += Sk_chunk_t_dynamic * k_tile_bytes;  // Go to next column in CB
-
-                    if (++barrier_count == barrier_threshold) {
-                        noc_async_read_barrier();
-                        barrier_count = 0;
-                    }
-                }
-            }
-            noc_async_read_barrier();
-            cb_push_back(cb_k_in, k_chunk_tiles);
+            k_base_read_ptr = read_k<
+                cb_k_in,
+                DHt,
+                num_kv_heads,
+                block_size_t,
+                k_tile_bytes,
+                barrier_threshold,
+                is_page_table_sharded>(
+                k_chunk_tiles,
+                cur_head,
+                Sk_chunk_t_dynamic,
+                k_chunk_start_row_num,
+                k_reader,
+                page_table_ptr_u16,
+                page_table_ptr_u32,
+                barrier_count);
         }

         if constexpr (use_attention_mask) {
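Only read_k's call site appears in this hunk; the helper body below is reconstructed from the inline loop deleted above, so the logic is grounded in this diff, while the exact template signature in dataflow_common.hpp (and the page-table pointer qualifiers) are assumptions.

// Sketch reconstructed from the removed inline loop; the signature is inferred
// from the call site, not copied from dataflow_common.hpp.
template <
    uint32_t cb_k_in,
    uint32_t DHt,
    uint32_t num_kv_heads,
    uint32_t block_size_t,
    uint32_t k_tile_bytes,
    uint32_t barrier_threshold,
    bool is_page_table_sharded,
    typename KReader>  // deduced from k_reader at the call site
uint64_t read_k(
    uint32_t k_chunk_tiles,
    uint32_t cur_head,
    uint32_t Sk_chunk_t_dynamic,
    uint32_t k_chunk_start_row_num,
    const KReader& k_reader,
    volatile tt_l1_ptr uint16_t* page_table_ptr_u16,
    volatile tt_l1_ptr uint32_t* page_table_ptr_u32,
    uint32_t& barrier_count) {
    cb_reserve_back(cb_k_in, k_chunk_tiles);
    uint32_t k_write_ptr = get_write_ptr(cb_k_in);
    uint64_t k_base_read_ptr = get_noc_addr(k_write_ptr);  // returned so the V read can reuse K from L1
    barrier_count = 0;
    for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
        uint32_t k_write_ptr_col = k_write_ptr + row * k_tile_bytes;
        uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row;
        uint32_t physical_k_tile_id =
            is_page_table_sharded
                ? virtual_seq_tile_id_to_physical_tile_id<uint16_t, num_kv_heads, block_size_t, DHt>(
                      virtual_k_tile_row_num, cur_head, page_table_ptr_u16)
                : virtual_seq_tile_id_to_physical_tile_id<uint32_t, num_kv_heads, block_size_t, DHt>(
                      virtual_k_tile_row_num, cur_head, page_table_ptr_u32);
        for (uint32_t col = 0; col < DHt; ++col) {
            noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col);
            physical_k_tile_id += 1;                               // next tile in the K row
            k_write_ptr_col += Sk_chunk_t_dynamic * k_tile_bytes;  // next column of the transposed CB
            if (++barrier_count == barrier_threshold) {
                noc_async_read_barrier();
                barrier_count = 0;
            }
        }
    }
    noc_async_read_barrier();
    cb_push_back(cb_k_in, k_chunk_tiles);
    return k_base_read_ptr;
}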
@@ -308,60 +300,26 @@ void kernel_main() {
         }

         {
-            if constexpr (reuse_k) {
-                // Read V chunk (transpose of K), from K's L1 buffer
-                cb_reserve_back(cb_v_in, v_chunk_tiles);
-                uint32_t v_write_ptr = get_write_ptr(cb_v_in);
-                uint64_t k_read_ptr = k_base_read_ptr;
-
-                for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {  // Row of V
-                    k_read_ptr = k_base_read_ptr + row * k_tile_bytes;  // Increment across K's col
-
-                    for (uint32_t col = 0; col < vDHt; ++col) {  // Col of V
-                        noc_async_read(k_read_ptr, v_write_ptr, v_tile_bytes);
-
-                        v_write_ptr += v_tile_bytes;
-                        k_read_ptr += Sk_chunk_t_dynamic * k_tile_bytes;  // Stride across K's width
-                    }
-                }
-            } else {
-                // Read V chunk in row-major order, write in row-major order
-                // V is an independent tensor with its own layout (width = vDHt, not DHt)
-                cb_reserve_back(cb_v_in, v_chunk_tiles);
-                uint32_t v_write_ptr = get_write_ptr(cb_v_in);
-                barrier_count = 0;
-
-                for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
-                    uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row;
-                    // Use vDHt for V tensor's width since V is independent
-                    uint32_t physical_v_tile_id =
-                        (is_page_table_sharded)
-                            ? virtual_seq_tile_id_to_physical_tile_id<
-                                  uint16_t,
-                                  num_kv_heads,
-                                  block_size_t,
-                                  vDHt>(virtual_v_tile_row_num, cur_head, page_table_ptr_u16)
-                            : virtual_seq_tile_id_to_physical_tile_id<
-                                  uint32_t,
-                                  num_kv_heads,
-                                  block_size_t,
-                                  vDHt>(virtual_v_tile_row_num, cur_head, page_table_ptr_u32);
-                    for (uint32_t col = 0; col < vDHt; ++col) {
-                        noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr);
-                        physical_v_tile_id += 1;
-                        v_write_ptr += v_tile_bytes;
-
-                        if (++barrier_count == barrier_threshold) {
-                            noc_async_read_barrier();
-                            barrier_count = 0;
-                        }
-                    }
-                    // No padding to skip - V is an independent tensor with contiguous layout
-                }
-            }
-
-            noc_async_read_barrier();
-            cb_push_back(cb_v_in, v_chunk_tiles);
+            // Read V chunk - either from DRAM or from K's L1 buffer (transpose) when reuse_k is true
+            read_v<
+                cb_v_in,
+                vDHt,
+                num_kv_heads,
+                block_size_t,
+                v_tile_bytes,
+                barrier_threshold,
+                is_page_table_sharded,
+                reuse_k>(
+                v_chunk_tiles,
+                cur_head,
+                Sk_chunk_t_dynamic,
+                k_chunk_start_row_num,
+                v_reader,
+                page_table_ptr_u16,
+                page_table_ptr_u32,
+                barrier_count,
+                k_base_read_ptr,
+                k_tile_bytes);
         }

     }
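As with read_k, only read_v's call site appears in the diff. Below is a sketch of the helper, reconstructed from the two removed branches (the reuse_k transpose-from-L1 path and the independent row-major DRAM path); the signature is inferred from the call site and is not guaranteed to match dataflow_common.hpp.

// Sketch reconstructed from the removed inline branches; signature inferred
// from the call site, not copied from dataflow_common.hpp.
template <
    uint32_t cb_v_in,
    uint32_t vDHt,
    uint32_t num_kv_heads,
    uint32_t block_size_t,
    uint32_t v_tile_bytes,
    uint32_t barrier_threshold,
    bool is_page_table_sharded,
    bool reuse_k,
    typename VReader>  // deduced from v_reader at the call site
void read_v(
    uint32_t v_chunk_tiles,
    uint32_t cur_head,
    uint32_t Sk_chunk_t_dynamic,
    uint32_t k_chunk_start_row_num,
    const VReader& v_reader,
    volatile tt_l1_ptr uint16_t* page_table_ptr_u16,
    volatile tt_l1_ptr uint32_t* page_table_ptr_u32,
    uint32_t& barrier_count,
    uint64_t k_base_read_ptr,
    uint32_t k_tile_bytes) {
    cb_reserve_back(cb_v_in, v_chunk_tiles);
    uint32_t v_write_ptr = get_write_ptr(cb_v_in);
    if constexpr (reuse_k) {
        // V chunk is the transpose of the K chunk already resident in L1.
        for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
            uint64_t k_read_ptr = k_base_read_ptr + row * k_tile_bytes;
            for (uint32_t col = 0; col < vDHt; ++col) {
                noc_async_read(k_read_ptr, v_write_ptr, v_tile_bytes);
                v_write_ptr += v_tile_bytes;
                k_read_ptr += Sk_chunk_t_dynamic * k_tile_bytes;  // stride across K's width
            }
        }
    } else {
        // V is an independent tensor (width vDHt, not DHt); read it row-major.
        barrier_count = 0;
        for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
            uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row;
            uint32_t physical_v_tile_id =
                is_page_table_sharded
                    ? virtual_seq_tile_id_to_physical_tile_id<uint16_t, num_kv_heads, block_size_t, vDHt>(
                          virtual_v_tile_row_num, cur_head, page_table_ptr_u16)
                    : virtual_seq_tile_id_to_physical_tile_id<uint32_t, num_kv_heads, block_size_t, vDHt>(
                          virtual_v_tile_row_num, cur_head, page_table_ptr_u32);
            for (uint32_t col = 0; col < vDHt; ++col) {
                noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr);
                physical_v_tile_id += 1;
                v_write_ptr += v_tile_bytes;
                if (++barrier_count == barrier_threshold) {
                    noc_async_read_barrier();
                    barrier_count = 0;
                }
            }
        }
    }
    noc_async_read_barrier();
    cb_push_back(cb_v_in, v_chunk_tiles);
}

Folding both branches into one helper keeps the CB reserve/push and barrier bookkeeping in one place; when reuse_k is set, the page-table pointers and v_reader are simply unused.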