
Commit dc571f9

tree reduce on sdpa decode
1 parent 3d6d9a0 commit dc571f9

File tree

7 files changed: +702 −237 lines

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp

Lines changed: 162 additions & 63 deletions
Large diffs are not rendered by default.
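
Editor's note: the compute-kernel diff above is collapsed, but the commit title names the technique. Instead of chaining per-core partial attention results serially, the reduction combines them pairwise in log-depth rounds. This is legal because the flash-decode combine step (running max m, softmax denominator l, un-normalized output o) is associative. The following is a minimal host-side C++ sketch of that combine and a tree reduction over it; all names are illustrative and this is not the kernel's actual code:

// Illustrative only: how two flash-decode partials combine associatively,
// which is what allows a tree (log-depth) reduction across the cores that
// each processed a slice of the K/V cache.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct Partial {
    float m;               // running row-max of QK^T scores
    float l;               // sum of exp(score - m)
    std::vector<float> o;  // un-normalized output row, weighted by exp(score - m)
};

Partial combine(const Partial& a, const Partial& b) {
    Partial r;
    r.m = std::max(a.m, b.m);
    float sa = std::exp(a.m - r.m);  // rescale factor for a's contribution
    float sb = std::exp(b.m - r.m);  // rescale factor for b's contribution
    r.l = a.l * sa + b.l * sb;
    r.o.resize(a.o.size());
    for (size_t i = 0; i < a.o.size(); ++i) {
        r.o[i] = a.o[i] * sa + b.o[i] * sb;
    }
    return r;
}

// Tree reduce: pairwise combine in ~log2(n) rounds instead of a serial chain.
Partial tree_reduce(std::vector<Partial> parts) {
    for (size_t stride = 1; stride < parts.size(); stride *= 2) {
        for (size_t i = 0; i + stride < parts.size(); i += 2 * stride) {
            parts[i] = combine(parts[i], parts[i + stride]);
        }
    }
    return parts[0];  // final attention row is parts[0].o[i] / parts[0].l
}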

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp

Lines changed: 112 additions & 0 deletions
@@ -478,6 +478,118 @@ uint32_t write_partial_tiles_to_memory(
 /******************************************************************************
  * Reader Kernel Specific Functions *
  ******************************************************************************/
+template <
+    uint32_t cb_k_in,
+    uint32_t DHt,
+    uint32_t num_kv_heads,
+    uint32_t block_size_t,
+    uint32_t k_tile_bytes,
+    uint32_t barrier_threshold,
+    bool is_page_table_sharded,
+    typename KReaderType>
+uint64_t read_k(
+    uint32_t k_chunk_tiles,
+    uint32_t cur_head,
+    uint32_t Sk_chunk_t_dynamic,
+    uint32_t k_chunk_start_row_num,
+    const KReaderType& k_reader,
+    volatile tt_l1_ptr uint16_t* page_table_ptr_u16,
+    volatile tt_l1_ptr uint32_t* page_table_ptr_u32,
+    uint32_t& barrier_count) {
+    cb_reserve_back(cb_k_in, k_chunk_tiles);
+    uint32_t k_write_ptr = get_write_ptr(cb_k_in);
+    uint64_t k_base_read_ptr = get_noc_addr(k_write_ptr);
+    barrier_count = 0;
+    for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
+        uint32_t k_write_ptr_col = k_write_ptr + row * k_tile_bytes;
+        uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row;
+        uint32_t physical_k_tile_id =
+            (is_page_table_sharded)
+                ? virtual_seq_tile_id_to_physical_tile_id<uint16_t, num_kv_heads, block_size_t, DHt>(
+                      virtual_k_tile_row_num, cur_head, page_table_ptr_u16)
+                : virtual_seq_tile_id_to_physical_tile_id<uint32_t, num_kv_heads, block_size_t, DHt>(
+                      virtual_k_tile_row_num, cur_head, page_table_ptr_u32);
+        for (uint32_t col = 0; col < DHt; ++col) {
+            noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col);
+            physical_k_tile_id += 1;                               // Go to next tile in row
+            k_write_ptr_col += Sk_chunk_t_dynamic * k_tile_bytes;  // Go to next column in CB
+
+            if (++barrier_count == barrier_threshold) {
+                noc_async_read_barrier();
+                barrier_count = 0;
+            }
+        }
+    }
+    noc_async_read_barrier();
+    cb_push_back(cb_k_in, k_chunk_tiles);
+    return k_base_read_ptr;
+}
+
+template <
+    uint32_t cb_v_in,
+    uint32_t vDHt,
+    uint32_t num_kv_heads,
+    uint32_t block_size_t,
+    uint32_t v_tile_bytes,
+    uint32_t barrier_threshold,
+    bool is_page_table_sharded,
+    bool reuse_k,
+    typename VReaderType>
+void read_v(
+    uint32_t v_chunk_tiles,
+    uint32_t cur_head,
+    uint32_t Sk_chunk_t_dynamic,
+    uint32_t k_chunk_start_row_num,
+    const VReaderType& v_reader,
+    volatile tt_l1_ptr uint16_t* page_table_ptr_u16,
+    volatile tt_l1_ptr uint32_t* page_table_ptr_u32,
+    uint32_t& barrier_count,
+    uint64_t k_base_read_ptr = 0,
+    uint32_t k_tile_bytes = 0) {
+    cb_reserve_back(cb_v_in, v_chunk_tiles);
+    uint32_t v_write_ptr = get_write_ptr(cb_v_in);
+
+    if constexpr (reuse_k) {
+        // Read V chunk (transpose of K), from K's L1 buffer
+        uint64_t k_read_ptr = k_base_read_ptr;
+        for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {  // Row of V
+            k_read_ptr = k_base_read_ptr + row * k_tile_bytes;     // Increment across K's Col
+            for (uint32_t col = 0; col < vDHt; ++col) {            // Col of V
+                noc_async_read(k_read_ptr, v_write_ptr, v_tile_bytes);
+                v_write_ptr += v_tile_bytes;
+                k_read_ptr += Sk_chunk_t_dynamic * k_tile_bytes;   // Stride across K's width
+            }
+        }
+        noc_async_read_barrier();
+    } else {
+        // Read V chunk in row-major order, write in row-major order.
+        // V is an independent tensor with its own layout (width = vDHt, not DHt).
+        barrier_count = 0;
+        for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
+            uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row;
+            // Use vDHt for V tensor's width since V is independent
+            uint32_t physical_v_tile_id =
+                (is_page_table_sharded)
+                    ? virtual_seq_tile_id_to_physical_tile_id<uint16_t, num_kv_heads, block_size_t, vDHt>(
+                          virtual_v_tile_row_num, cur_head, page_table_ptr_u16)
+                    : virtual_seq_tile_id_to_physical_tile_id<uint32_t, num_kv_heads, block_size_t, vDHt>(
+                          virtual_v_tile_row_num, cur_head, page_table_ptr_u32);
+            for (uint32_t col = 0; col < vDHt; ++col) {
+                noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr);
+                physical_v_tile_id += 1;
+                v_write_ptr += v_tile_bytes;
+
+                if (++barrier_count == barrier_threshold) {
+                    noc_async_read_barrier();
+                    barrier_count = 0;
+                }
+            }
+            // No padding to skip - V is an independent tensor with contiguous layout
+        }
+        noc_async_read_barrier();
+    }
+    cb_push_back(cb_v_in, v_chunk_tiles);
+}
 
 template <
     uint32_t DHt,
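
Editor's note on the reuse_k path in read_v above: read_k writes K tiles to L1 in transposed order, so tile (row, col) of the Sk_chunk_t x DHt grid lands at tile offset col * Sk_chunk_t_dynamic + row. Striding by Sk_chunk_t_dynamic * k_tile_bytes therefore walks one row of V through K's buffer (only the first vDHt columns are touched when V's head dim is smaller). A standalone host-side sketch of that address arithmetic, with hypothetical sizes, not part of the diff:

// Host-side illustration of the reuse_k offset math, not device code.
// Sk, DHt, tile_bytes stand in for Sk_chunk_t_dynamic, vDHt, k_tile_bytes.
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t Sk = 4;           // tile rows in the K chunk
    constexpr uint32_t DHt = 3;          // tile columns (head dim in tiles)
    constexpr uint32_t tile_bytes = 2048;

    // read_k placed tile (row, col) at byte offset (col * Sk + row) * tile_bytes.
    // read_v's reuse_k loop recovers V's row-major tile order from that layout:
    for (uint32_t row = 0; row < Sk; ++row) {
        uint32_t off = row * tile_bytes;  // k_base_read_ptr + row * k_tile_bytes
        for (uint32_t col = 0; col < DHt; ++col) {
            // This offset equals (col * Sk + row) * tile_bytes, i.e. exactly
            // where the K writer put tile (row, col):
            printf("V tile (%u,%u) read from K offset %u\n", row, col, off);
            off += Sk * tile_bytes;       // stride across K's width
        }
    }
    return 0;
}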

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp

Lines changed: 48 additions & 90 deletions
@@ -5,6 +5,7 @@
 #include <stdint.h>
 #include "api/dataflow/dataflow_api.h"
 #include <vector>
+#include "api/debug/assert.h"
 
 #include "ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp"
 #include "dataflow_common.hpp"
@@ -97,7 +98,6 @@ void kernel_main() {
         noc_async_read(tensor_index_noc_addr, index_cb_wr_ptr, index_stick_size_B);
         noc_async_read_barrier();
     }
-
     cb_push_back(cb_index_id, 1);
     volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(index_cb_wr_ptr);
     cur_pos = index_ptr[cur_batch / q_heads_parallel_factor];
@@ -113,13 +113,14 @@ void kernel_main() {
     auto k_chunk_size_dynamic = Sk_chunk_t_dynamic * tt::constants::TILE_HEIGHT;
 
     // Sequence length assignment
-    auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk] = get_runtime_args(
-        cur_pos,
-        cur_batch,
-        core_num_in_reduce,
-        num_cores_per_head,
-        k_chunk_size_dynamic,
-        sliding_window_size > 0 ? std::optional<uint32_t>(sliding_window_size) : std::nullopt);
+    auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk] =
+        get_workload_for_core(
+            cur_pos,
+            cur_batch,
+            core_num_in_reduce,
+            num_cores_per_head,
+            k_chunk_size_dynamic,
+            sliding_window_size > 0 ? std::optional<uint32_t>(sliding_window_size) : std::nullopt);
 
     if (k_chunk_start == k_chunk_end) {
         return; // early exit because no computes needs to be done
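
Editor's note: get_workload_for_core (declared in rt_args_common.hpp, renamed from get_runtime_args here) hands each reduction core a contiguous slice of K chunks; the early exit above fires for cores whose slice is empty. The sketch below is an assumption about the shape of that split (unwindowed case only), not the helper's actual implementation:

// Assumed chunking: ceil-divide the cache (up to cur_pos, inclusive) into
// chunks of k_chunk_size rows, then give core i a contiguous [start, end)
// range of chunk indices. An empty range reproduces the
// `k_chunk_start == k_chunk_end` early exit above.
#include <algorithm>
#include <cstdint>
#include <utility>

std::pair<uint32_t, uint32_t> chunk_range_for_core(
    uint32_t cur_pos, uint32_t k_chunk_size, uint32_t core_num, uint32_t num_cores_per_head) {
    uint32_t k_num_chunks = (cur_pos + 1 + k_chunk_size - 1) / k_chunk_size;
    uint32_t per_core = (k_num_chunks + num_cores_per_head - 1) / num_cores_per_head;
    uint32_t start = std::min(core_num * per_core, k_num_chunks);
    uint32_t end = std::min(start + per_core, k_num_chunks);
    return {start, end};
}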
@@ -154,7 +155,9 @@ void kernel_main() {
     constexpr uint32_t barrier_threshold = get_barrier_read_threshold<q_tile_bytes, num_cores>();
     uint32_t barrier_count = 0;
 
-    // First, read Q entirely, it could be interleaved or sharded
+    // Read Q entirely - always read into cb_q_in
+    // When tilize_q is true, compute will tilize in-place back to cb_q_in
+    // When tilize_q is false, Q is already tilized
     uint32_t q_batch_offset = cur_batch * q_chunk_tiles;
 
     if constexpr (is_q_sharded) {
@@ -273,33 +276,22 @@ void kernel_main() {
     {
         // Read K chunk in row-major order (to simplify page mapping). Write tiles to CB in transposed
        // order.
-        cb_reserve_back(cb_k_in, k_chunk_tiles);
-        uint32_t k_write_ptr = get_write_ptr(cb_k_in);
-        k_base_read_ptr = get_noc_addr(k_write_ptr);
-        barrier_count = 0;
-        for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
-            uint32_t k_write_ptr_col = k_write_ptr + row * k_tile_bytes;
-            uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row;
-
-            uint32_t physical_k_tile_id =
-                (is_page_table_sharded)
-                    ? virtual_seq_tile_id_to_physical_tile_id<uint16_t, num_kv_heads, block_size_t, DHt>(
-                          virtual_k_tile_row_num, cur_head, page_table_ptr_u16)
-                    : virtual_seq_tile_id_to_physical_tile_id<uint32_t, num_kv_heads, block_size_t, DHt>(
-                          virtual_k_tile_row_num, cur_head, page_table_ptr_u32);
-            for (uint32_t col = 0; col < DHt; ++col) {
-                noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col);
-                physical_k_tile_id += 1;  // Go to next tile in row
-                k_write_ptr_col += Sk_chunk_t_dynamic * k_tile_bytes;  // Go to next column in CB
-
-                if (++barrier_count == barrier_threshold) {
-                    noc_async_read_barrier();
-                    barrier_count = 0;
-                }
-            }
-        }
-        noc_async_read_barrier();
-        cb_push_back(cb_k_in, k_chunk_tiles);
+        k_base_read_ptr = read_k<
+            cb_k_in,
+            DHt,
+            num_kv_heads,
+            block_size_t,
+            k_tile_bytes,
+            barrier_threshold,
+            is_page_table_sharded>(
+            k_chunk_tiles,
+            cur_head,
+            Sk_chunk_t_dynamic,
+            k_chunk_start_row_num,
+            k_reader,
+            page_table_ptr_u16,
+            page_table_ptr_u32,
+            barrier_count);
     }
 
     if constexpr (use_attention_mask) {
@@ -308,60 +300,26 @@ void kernel_main() {
     }
 
     {
-        if constexpr (reuse_k) {
-            // Read V chunk (tranpose of K), from K's L1 buffer
-            cb_reserve_back(cb_v_in, v_chunk_tiles);
-            uint32_t v_write_ptr = get_write_ptr(cb_v_in);
-            uint64_t k_read_ptr = k_base_read_ptr;
-
-            for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {  // Row of V
-                k_read_ptr = k_base_read_ptr + row * k_tile_bytes;  // Increment across K's Col
-
-                for (uint32_t col = 0; col < vDHt; ++col) {  // Col of V
-                    noc_async_read(k_read_ptr, v_write_ptr, v_tile_bytes);
-
-                    v_write_ptr += v_tile_bytes;
-                    k_read_ptr += Sk_chunk_t_dynamic * k_tile_bytes;  // Strid across K's width
-                }
-            }
-        } else {
-            // Read V chunk in row major order, write in row-major order
-            // V is an independent tensor with its own layout (width = vDHt, not DHt)
-            cb_reserve_back(cb_v_in, v_chunk_tiles);
-            uint32_t v_write_ptr = get_write_ptr(cb_v_in);
-            barrier_count = 0;
-
-            for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
-                uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row;
-                // Use vDHt for V tensor's width since V is independent
-                uint32_t physical_v_tile_id =
-                    (is_page_table_sharded)
-                        ? virtual_seq_tile_id_to_physical_tile_id<
-                              uint16_t,
-                              num_kv_heads,
-                              block_size_t,
-                              vDHt>(virtual_v_tile_row_num, cur_head, page_table_ptr_u16)
-                        : virtual_seq_tile_id_to_physical_tile_id<
-                              uint32_t,
-                              num_kv_heads,
-                              block_size_t,
-                              vDHt>(virtual_v_tile_row_num, cur_head, page_table_ptr_u32);
-                for (uint32_t col = 0; col < vDHt; ++col) {
-                    noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr);
-                    physical_v_tile_id += 1;
-                    v_write_ptr += v_tile_bytes;
-
-                    if (++barrier_count == barrier_threshold) {
-                        noc_async_read_barrier();
-                        barrier_count = 0;
-                    }
-                }
-                // No padding to skip - V is an independent tensor with contiguous layout
-            }
-        }
-
-        noc_async_read_barrier();
-        cb_push_back(cb_v_in, v_chunk_tiles);
+        // Read V chunk - either from DRAM or from K's L1 buffer (transpose) when reuse_k is true
+        read_v<
+            cb_v_in,
+            vDHt,
+            num_kv_heads,
+            block_size_t,
+            v_tile_bytes,
+            barrier_threshold,
+            is_page_table_sharded,
+            reuse_k>(
+            v_chunk_tiles,
+            cur_head,
+            Sk_chunk_t_dynamic,
+            k_chunk_start_row_num,
+            v_reader,
+            page_table_ptr_u16,
+            page_table_ptr_u32,
+            barrier_count,
+            k_base_read_ptr,
+            k_tile_bytes);
     }
 
 }
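
Editor's note: one pattern worth calling out from both reader paths above is the barrier batching via barrier_count / barrier_threshold (the threshold comes from get_barrier_read_threshold, based on tile size and core count). Reads are issued asynchronously and flushed every barrier_threshold issues, bounding the number of in-flight NoC transactions, with a final flush before the buffer is handed to the consumer. In isolation, with hypothetical generic names rather than the tt-metal API, the pattern looks like this:

// Generic sketch of the batched async-read pattern, not the tt-metal API.
#include <cstdint>

template <typename IssueFn, typename FlushFn>
void batched_async_reads(uint32_t n, uint32_t threshold, IssueFn issue, FlushFn flush) {
    uint32_t in_flight = 0;
    for (uint32_t i = 0; i < n; ++i) {
        issue(i);              // e.g. noc_async_read_tile(...)
        if (++in_flight == threshold) {
            flush();           // e.g. noc_async_read_barrier()
            in_flight = 0;
        }
    }
    flush();  // final barrier: all reads have landed before cb_push_back(...)
}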
