@@ -254,6 +254,12 @@ void kernel_main() {
254254
255255 // Loop through all heads assigned to core
256256 for (uint32_t cur_head_work = 0 ; cur_head_work < num_heads_per_core; ++cur_head_work) {
257+ // Reset ping-pong buffer assignments at the start of each head iteration
258+ cb_cur_max = cb_max_1;
259+ cb_prev_max = cb_max_2;
260+ cb_cur_sum = cb_sum_1;
261+ cb_prev_sum = cb_sum_2;
262+
257263 /* *****************************************************************************
258264 * FLASH ATTENTION LOOP *
259265 ******************************************************************************/
@@ -528,7 +534,6 @@ void kernel_main() {
528534 // Combine child with existing local/accumulated data
529535 // Move child's L to cb_prev_sum_2 for correction
530536 move_block<true >(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
531-
532537 // Fused Softmax Correction
533538 // * Fused Correction is a fused operation that performs the following steps:
534539 // * 1. CUR_MAX = max(PREV_MAX, WORKER_MAX)
@@ -621,6 +626,25 @@ void kernel_main() {
621626 mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, sum_cb);
622627 pack_reconfig_data_format (cb_out_final);
623628
629+ // Note: sum_cb was already consumed (popped) by mul_block_bcast_cols_inplace above,
630+ // so we don't need to pop it again here.
631+
632+ // Pop the max buffer that still has data
633+ if constexpr (use_attention_sink) {
634+ // In attention sink path:
635+ // - If actual_num_children > 0: max_cb_for_sink = cb_prev_max, cb_cur_max was popped
636+ // So we need to pop cb_prev_max
637+ // - If actual_num_children == 0: max_cb_for_sink = cb_cur_max, cb_cur_max was popped
638+ // Nothing left to pop
639+ if (actual_num_children > 0 ) {
640+ cb_pop_front (cb_prev_max, Sq_chunk_t);
641+ }
642+ } else {
643+ // No attention sink: the max buffer was never popped
644+ uint32_t max_cb = (actual_num_children > 0 ) ? cb_prev_max : cb_cur_max;
645+ cb_pop_front (max_cb, Sq_chunk_t);
646+ }
647+
624648 // Untilize output to ROW MAJOR if input Q was also ROW MAJOR
625649 if constexpr (untilize_output) {
626650 // Conditionally use pack_untilize or untilize
@@ -663,10 +687,6 @@ void kernel_main() {
663687 // Move L to output CB
664688 move_block<true >(cb_prev_sum, cb_out_l, Sq_chunk_t);
665689 }
666-
667- // Free up prev buffers if we used them
668- cb_pop_front (cb_prev_max, Sq_chunk_t);
669- cb_pop_front (cb_prev_sum, Sq_chunk_t);
670690 }
671691
672692 // Free up cb_q_in after Q chunks
0 commit comments