@@ -177,35 +177,28 @@ void kernel_main() {
177177 return ;
178178 }
179179
180- // Determine which children actually have data (based on chunk allocation)
181- // A child at core_num has data if core_num < k_num_chunks
182- uint32_t actual_num_children = 0 ;
183- uint32_t actual_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
184- uint32_t actual_my_active_rounds = 0 ;
180+ // Determine which children actually participate in reduction (based on chunk allocation)
181+ // A child at core_num is active or has data if core_num < k_num_chunks
182+ // E.g. k_num_chunks = 2, num_cores_per_head = 4
183+ // | core 0 | core 1 | core 2 | core 3 |
184+ // chunk 0 chunk 1 NA NA
185+ // core 0 would have core 1 and core 2 as children, but only core 1 is active to perform reduction with
186+ uint32_t num_active_children = 0 ;
187+ uint32_t active_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
188+ uint32_t num_active_rounds = 0 ;
185189
186190 for (uint32_t r = 0 ; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
187191 uint32_t child_id = children_per_round[r];
188192 if (child_id != UINT32_MAX && child_id < k_num_chunks) {
189193 // This child has data
190- actual_children_per_round [r] = child_id;
191- actual_num_children ++;
192- actual_my_active_rounds = r + 1 ;
194+ active_children_per_round [r] = child_id;
195+ num_active_children ++;
196+ num_active_rounds = r + 1 ;
193197 } else {
194- actual_children_per_round [r] = UINT32_MAX;
198+ active_children_per_round [r] = UINT32_MAX;
195199 }
196200 }
197201
198- // Determine if we have an actual parent (parent must have data too, but root always has data)
199- // Actually, we only need to check if WE should send - parent will handle receiving
200- // We send if we have a parent AND we have data (which we do if we reach here)
201- const bool should_send_to_parent = has_parent;
202-
203- // Get number of worker cores to wait for
204- uint32_t num_cores_to_wait = num_cores_per_head - 1 ;
205- if (num_cores_per_head > k_num_chunks) {
206- num_cores_to_wait = k_num_chunks - 1 ;
207- }
208-
209202 // We tilize input Q if it is in ROW MAJOR layout
210203 if constexpr (tilize_q) {
211204 compute_kernel_hw_startup (cb_q_rm, cb_q_in);
@@ -492,7 +485,7 @@ void kernel_main() {
492485
493486 // After FA loop completes, prepare buffers for tree reduction or output
494487 // Results are in: cb_out_accumulate_im (O), cb_cur_max (M), cb_cur_sum (L)
495- if (actual_num_children > 0 || should_send_to_parent ) {
488+ if (num_active_children > 0 || has_parent ) {
496489 // Tree reduction will happen - move cur to prev buffers
497490 reconfig_data_format (cb_cur_max, cb_cur_max);
498491 pack_reconfig_data_format (cb_prev_max);
@@ -523,10 +516,10 @@ void kernel_main() {
523516 // - cb_prev_max: local M (max of logits)
524517 // - cb_prev_sum: local L (sum of exp)
525518 // Only receive from children that actually have data
526- if (actual_num_children > 0 ) {
519+ if (num_active_children > 0 ) {
527520 // Iterate through each round and receive from child if one exists AND has data
528- for (uint32_t round = 0 ; round < actual_my_active_rounds ; ++round) {
529- uint32_t child_id = actual_children_per_round [round];
521+ for (uint32_t round = 0 ; round < num_active_rounds ; ++round) {
522+ uint32_t child_id = active_children_per_round [round];
530523 if (child_id != UINT32_MAX) {
531524 // Writer kernel handles the semaphore wait and data transfer to cb_m_in, cb_l_in, cb_out_o
532525 // Data arrives in order: l, m, o
@@ -584,7 +577,7 @@ void kernel_main() {
584577
585578 // Select the correct sum buffer based on whether tree reduction happened
586579 // If tree reduction happened, sum is in cb_prev_sum; otherwise it's in cb_cur_sum
587- uint32_t sum_cb = (actual_num_children > 0 ) ? cb_prev_sum : cb_cur_sum;
580+ uint32_t sum_cb = (num_active_children > 0 ) ? cb_prev_sum : cb_cur_sum;
588581
589582 /* SUM = 1.0 / SUM */
590583 reconfig_data_format (sum_cb, sum_cb);
@@ -593,7 +586,7 @@ void kernel_main() {
593586 // Handle attention sink here
594587 if constexpr (use_attention_sink) {
595588 // Use appropriate max buffer based on tree reduction
596- uint32_t max_cb_for_sink = (actual_num_children > 0 ) ? cb_prev_max : cb_cur_max;
589+ uint32_t max_cb_for_sink = (num_active_children > 0 ) ? cb_prev_max : cb_cur_max;
597590
598591 // m_new
599592 max_block<vector_mode>(cb_attention_sink, max_cb_for_sink, cb_cur_max, Sq_chunk_t);
@@ -632,16 +625,16 @@ void kernel_main() {
632625 // Pop the max buffer that still has data
633626 if constexpr (use_attention_sink) {
634627 // In attention sink path:
635- // - If actual_num_children > 0: max_cb_for_sink = cb_prev_max, cb_cur_max was popped
628+ // - If num_active_children > 0: max_cb_for_sink = cb_prev_max, cb_cur_max was popped
636629 // So we need to pop cb_prev_max
637- // - If actual_num_children == 0: max_cb_for_sink = cb_cur_max, cb_cur_max was popped
630+ // - If num_active_children == 0: max_cb_for_sink = cb_cur_max, cb_cur_max was popped
638631 // Nothing left to pop
639- if (actual_num_children > 0 ) {
632+ if (num_active_children > 0 ) {
640633 cb_pop_front (cb_prev_max, Sq_chunk_t);
641634 }
642635 } else {
643636 // No attention sink: the max buffer was never popped
644- uint32_t max_cb = (actual_num_children > 0 ) ? cb_prev_max : cb_cur_max;
637+ uint32_t max_cb = (num_active_children > 0 ) ? cb_prev_max : cb_cur_max;
645638 cb_pop_front (max_cb, Sq_chunk_t);
646639 }
647640
@@ -672,7 +665,7 @@ void kernel_main() {
672665 move_block<true >(cb_out_accumulate_im, cb_out_final, out_chunk_tiles);
673666 }
674667
675- } else if (should_send_to_parent ) {
668+ } else if (has_parent ) {
676669 // Non-root node with parent: send intermediate results
677670 // We have data (checked at function start), so send it
678671 // After tree reduction (if any), results are in:
0 commit comments