Skip to content

Commit 69cf16d

Browse files
committed
fix hang in prev sum
1 parent 5b70696 commit 69cf16d

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,12 @@ void kernel_main() {
254254

255255
// Loop through all heads assigned to core
256256
for (uint32_t cur_head_work = 0; cur_head_work < num_heads_per_core; ++cur_head_work) {
257+
// Reset ping-pong buffer assignments at the start of each head iteration
258+
cb_cur_max = cb_max_1;
259+
cb_prev_max = cb_max_2;
260+
cb_cur_sum = cb_sum_1;
261+
cb_prev_sum = cb_sum_2;
262+
257263
/******************************************************************************
258264
* FLASH ATTENTION LOOP *
259265
******************************************************************************/
@@ -528,7 +534,6 @@ void kernel_main() {
528534
// Combine child with existing local/accumulated data
529535
// Move child's L to cb_prev_sum_2 for correction
530536
move_block<true>(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
531-
532537
// Fused Softmax Correction
533538
// * Fused Correction is a fused operation that performs the following steps:
534539
// * 1. CUR_MAX = max(PREV_MAX, WORKER_MAX)
@@ -621,6 +626,25 @@ void kernel_main() {
621626
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, sum_cb);
622627
pack_reconfig_data_format(cb_out_final);
623628

629+
// Note: sum_cb was already consumed (popped) by mul_block_bcast_cols_inplace above,
630+
// so we don't need to pop it again here.
631+
632+
// Pop the max buffer that still has data
633+
if constexpr (use_attention_sink) {
634+
// In attention sink path:
635+
// - If actual_num_children > 0: max_cb_for_sink = cb_prev_max, cb_cur_max was popped
636+
// So we need to pop cb_prev_max
637+
// - If actual_num_children == 0: max_cb_for_sink = cb_cur_max, cb_cur_max was popped
638+
// Nothing left to pop
639+
if (actual_num_children > 0) {
640+
cb_pop_front(cb_prev_max, Sq_chunk_t);
641+
}
642+
} else {
643+
// No attention sink: the max buffer was never popped
644+
uint32_t max_cb = (actual_num_children > 0) ? cb_prev_max : cb_cur_max;
645+
cb_pop_front(max_cb, Sq_chunk_t);
646+
}
647+
624648
// Untilize output to ROW MAJOR if input Q was also ROW MAJOR
625649
if constexpr (untilize_output) {
626650
// Conditionally use pack_untilize or untilize
@@ -663,10 +687,6 @@ void kernel_main() {
663687
// Move L to output CB
664688
move_block<true>(cb_prev_sum, cb_out_l, Sq_chunk_t);
665689
}
666-
667-
// Free up prev buffers if we used them
668-
cb_pop_front(cb_prev_max, Sq_chunk_t);
669-
cb_pop_front(cb_prev_sum, Sq_chunk_t);
670690
}
671691

672692
// Free up cb_q_in after Q chunks

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ SdpaDecodeProgramFactory::cached_program_t SdpaDecodeProgramFactory::create(
302302
uint32_t out_im_tiles = PNHt * vDHt;
303303
uint32_t out0_t = PNHt * vDHt;
304304
uint32_t scale_tiles = 1;
305-
uint32_t statistics_tiles = PNHt; // Single column of values in each iteration
305+
uint32_t statistics_tiles = PNHt * 2; // Single column of values in each iteration
306306

307307
// log all values
308308
log_debug(tt::LogOp, "q_tiles: {}", q_tiles);

0 commit comments

Comments
 (0)