66
77#define REDUCE_OP (PoolType::MAX)
88#define REDUCE_DIM (ReduceDim::REDUCE_ROW)
9+ #define MAX_TREE_REDUCTION_ROUNDS 6
910
1011#include " compute_kernel_api.h"
1112#include " compute_kernel_api/eltwise_binary.h"
@@ -61,6 +62,7 @@ void kernel_main() {
6162 constexpr bool use_half_tile = get_compile_time_arg_val (27 );
6263 constexpr uint32_t scale_fp32 = get_compile_time_arg_val (28 );
6364 constexpr uint32_t sliding_window_size = get_compile_time_arg_val (29 );
65+ constexpr uint32_t num_tree_reduction_rounds = get_compile_time_arg_val (30 );
6466
6567 constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt;
6668 constexpr uint32_t out_chunk_tiles = Sq_chunk_t * vDHt;
@@ -109,6 +111,20 @@ void kernel_main() {
109111 const uint32_t core_num_in_output = get_arg_val<uint32_t >(arg_idx++);
110112 const uint32_t cur_pos_arg = get_arg_val<uint32_t >(arg_idx++);
111113
114+ // Tree reduction runtime arguments
115+ const bool is_tree_root = get_arg_val<uint32_t >(arg_idx++) == 1 ;
116+ const uint32_t parent_core_in_group = get_arg_val<uint32_t >(arg_idx++);
117+ const uint32_t send_at_round = get_arg_val<uint32_t >(arg_idx++);
118+ const uint32_t num_children = get_arg_val<uint32_t >(arg_idx++);
119+ const uint32_t my_active_rounds = get_arg_val<uint32_t >(arg_idx++);
120+ const bool has_parent = parent_core_in_group != UINT32_MAX;
121+
122+ // Read children_per_round array
123+ uint32_t children_per_round[MAX_TREE_REDUCTION_ROUNDS];
124+ for (uint32_t r = 0 ; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
125+ children_per_round[r] = get_arg_val<uint32_t >(arg_idx++);
126+ }
127+
112128 // Idle core
113129 // get_arg_val<uint32_t>(0) can go from 0-63 for the core_num; for active cores 65 is out of range so 65 indicates
114130 // an idle_core
@@ -130,7 +146,7 @@ void kernel_main() {
130146
131147 cb_wait_front (cb_index_id, 1 );
132148 cur_pos = read_tile_value (cb_index_id, 0 , cur_batch / q_heads_parallel_factor);
133- cb_pop_front (cb_index_id, 1 );
149+ // cb_pop_front(cb_index_id, 1);
134150 }
135151 if (cur_pos == UINT32_MAX) {
136152 // cur_pos of -1 indicates that the user should be skipped
@@ -150,31 +166,64 @@ void kernel_main() {
150166 num_cores_per_head,
151167 k_chunk_size_dynamic,
152168 sliding_window_size > 0 ? std::optional<uint32_t >(sliding_window_size) : std::nullopt );
153- if (k_chunk_start == k_chunk_end) {
154- return ; // early exit because no computes needs to be done
169+
170+ // Check if this core has local data to process
171+ const bool has_local_data = (k_chunk_start != k_chunk_end);
172+
173+ // Cores without data don't participate in tree reduction at all
174+ // They just exit early - no sending, no receiving
175+ if (!has_local_data) {
176+ return ;
177+ }
178+
179+ // Determine which children actually have data (based on chunk allocation)
180+ // A child at core_num has data if core_num < k_num_chunks
181+ uint32_t actual_num_children = 0 ;
182+ uint32_t actual_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
183+ uint32_t actual_my_active_rounds = 0 ;
184+
185+ for (uint32_t r = 0 ; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
186+ uint32_t child_id = children_per_round[r];
187+ if (child_id != UINT32_MAX && child_id < k_num_chunks) {
188+ // This child has data
189+ actual_children_per_round[r] = child_id;
190+ actual_num_children++;
191+ actual_my_active_rounds = r + 1 ;
192+ } else {
193+ actual_children_per_round[r] = UINT32_MAX;
194+ }
155195 }
156196
197+ // Decide whether this core sends its partial results to its parent.
198+ // Only the sending side is decided here - the parent handles receiving (the root always has data).
199+ // This core has data whenever this point is reached, so having a parent is sufficient to send.
200+ const bool should_send_to_parent = has_parent;
201+
157202 // Get number of worker cores to wait for
158203 uint32_t num_cores_to_wait = num_cores_per_head - 1 ;
159204 if (num_cores_per_head > k_num_chunks) {
160205 num_cores_to_wait = k_num_chunks - 1 ;
161206 }
162207
163- // We tilize input Q if it is in ROW MAJOR layout
164- if constexpr (tilize_q) {
165- compute_kernel_hw_startup (cb_q_rm, cb_q_in);
166- tilize_init (cb_q_rm, q_chunk_tiles, cb_q_in);
167- cb_wait_front (cb_q_rm, q_chunk_tiles);
168- cb_reserve_back (cb_q_in, q_chunk_tiles);
169- tilize_block (cb_q_rm, q_chunk_tiles, cb_q_in);
170- tilize_uninit (cb_q_rm, cb_q_in);
171- cb_push_back (cb_q_in, q_chunk_tiles);
172- cb_pop_front (cb_q_rm, q_chunk_tiles);
173- mm_init_short (cb_q_in, cb_k_in);
174- } else {
175- mm_init (cb_q_in, cb_k_in, cb_qk_im);
208+ // Process Q and initialize matmul for the local data.
209+ // Cores without local data have already returned above, so has_local_data is always true here.
210+ if (has_local_data) {
211+ // We tilize input Q if it is in ROW MAJOR layout
212+ if constexpr (tilize_q) {
213+ compute_kernel_hw_startup (cb_q_rm, cb_q_in);
214+ tilize_init (cb_q_rm, q_chunk_tiles, cb_q_in);
215+ cb_wait_front (cb_q_rm, q_chunk_tiles);
216+ cb_reserve_back (cb_q_in, q_chunk_tiles);
217+ tilize_block (cb_q_rm, q_chunk_tiles, cb_q_in);
218+ tilize_uninit (cb_q_rm, cb_q_in);
219+ cb_push_back (cb_q_in, q_chunk_tiles);
220+ cb_pop_front (cb_q_rm, q_chunk_tiles);
221+ mm_init_short (cb_q_in, cb_k_in);
222+ } else {
223+ mm_init (cb_q_in, cb_k_in, cb_qk_im);
224+ }
225+ cb_wait_front (cb_q_in, q_chunk_tiles);
176226 }
177- cb_wait_front (cb_q_in, q_chunk_tiles);
178227
179228 // Define dynamic matmul configs
180229#ifdef DYNAMIC_CHUNK_SIZE
@@ -184,7 +233,6 @@ void kernel_main() {
184233 const uint32_t qk_in1_num_subblocks_dynamic = 1 ;
185234 const uint32_t out_in0_block_w_dynamic = Sk_chunk_t_dynamic;
186235 const uint32_t out_num_blocks_dynamic = 1 ;
187-
188236 const uint32_t qk_chunk_tiles_dynamic = Sq_chunk_t * Sk_chunk_t_dynamic;
189237#else
190238 constexpr uint32_t qk_subblock_h_dynamic = qk_subblock_h;
@@ -193,7 +241,6 @@ void kernel_main() {
193241 constexpr uint32_t qk_in1_num_subblocks_dynamic = qk_in1_num_subblocks;
194242 constexpr uint32_t out_in0_block_w_dynamic = out_in0_block_w;
195243 constexpr uint32_t out_num_blocks_dynamic = out_num_blocks;
196-
197244 constexpr uint32_t qk_chunk_tiles_dynamic = Sq_chunk_t * Sk_chunk_t;
198245#endif
199246
@@ -262,7 +309,8 @@ void kernel_main() {
262309 * @param out_chunk_tiles - Number of output chunk tiles
263310 */
264311 /* START OF FLASH ATTENTION LOOP */
265- {
312+ // Only run FA loop if this core has local data to process
313+ if (has_local_data) {
266314 uint32_t cb_out_mm = cb_out_accumulate_im;
267315
268316 // Loop through all K chunks
@@ -399,7 +447,6 @@ void kernel_main() {
399447 cb_out_mm = cb_out_im;
400448 } else {
401449 // When there is more than 1 chunk, we perform Lazy Softmax
402-
403450 // Reconfig register DF
404451 reconfig_data_format (cb_prev_max, cb_cur_max);
405452 pack_reconfig_data_format (cb_exp_max_diff);
@@ -428,8 +475,10 @@ void kernel_main() {
428475 add_block_inplace<true >(cb_out_accumulate_im, cb_out_im, out_chunk_tiles);
429476 }
430477
431- if (k_chunk < k_chunk_end - 1 || do_reduce) {
432- // Move intermediate sum and max values to appropriate ping pong buffers
478+ // Move intermediate sum and max values to the ping-pong (prev) buffers only while
479+ // more local chunks remain - the last chunk's results are handled after the FA loop
480+ if (k_chunk < k_chunk_end - 1 ) {
481+ // More local chunks to process - move to ping-pong buffers
433482 reconfig_data_format (cb_cur_max, cb_cur_max);
434483 pack_reconfig_data_format (cb_prev_max);
435484
@@ -438,26 +487,53 @@ void kernel_main() {
438487
439488 // PREV_SUM <- CUR_SUM
440489 move_block<true >(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
441- } else {
442- // Write results OUT_ACC, CUR_MAX, CUR_SUM to designated
443- // Write o, m, l into cb_out
444- move_block<true >(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
445- move_block<true >(cb_cur_max, cb_out_m, Sq_chunk_t);
446- move_block<true >(cb_cur_sum, cb_out_l, Sq_chunk_t);
447490 }
448491 }
492+
493+ // After FA loop completes, prepare buffers for tree reduction or output
494+ // Results are in: cb_out_accumulate_im (O), cb_cur_max (M), cb_cur_sum (L)
495+ if (actual_num_children > 0 || should_send_to_parent) {
496+ // Tree reduction will happen - move cur to prev buffers
497+ reconfig_data_format (cb_cur_max, cb_cur_max);
498+ pack_reconfig_data_format (cb_prev_max);
499+ move_block<true >(cb_cur_max, cb_prev_max, Sq_chunk_t);
500+ move_block<true >(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
501+ }
502+ // If root with no children, keep in cur buffers for finalization
449503 }
450504 /* END OF FLASH ATTENTION LOOP */
451- // Perform reduction across intermediates from other cores if this is the reduction core
452- if (do_reduce) {
453- // cb_out_accumulate_im should contain o_1 (output from FA of itself's core)
454- // cb_prev_max and cb_prev_sum should contain m_1 and l_1 (max and sum of logits of itself's core)
455-
456- if (k_chunk_end - k_chunk_start < k_num_chunks) {
457- // This indicates that there are computes done by other workers.
458- // We need to wait for them and send to reducer's compute
459- // Iterate through each worker
460- for (uint32_t i = 0 ; i < num_cores_to_wait; i++) {
505+
506+ /* *****************************************************************************
507+ * TREE REDUCTION LOGIC *
508+ ******************************************************************************/
509+ /* *
510+ * Tree reduction replaces the flat worker->reducer pattern with O(log n) rounds.
511+ *
512+ * For each round r (0 to actual_my_active_rounds-1):
513+ * - If actual_children_per_round[r] != UINT32_MAX (child exists and has data), receive from that child
514+ * - Combine received data with local accumulator using softmax correction
515+ *
516+ * After all receives:
517+ * - If is_tree_root: finalize (1/sum normalization) and output
518+ * - Else: output intermediate results for writer to send to parent
519+ */
520+
521+ // Tree reduction: receive from children and combine
522+ // Buffer state entering tree reduction:
523+ // - cb_out_accumulate_im: local O (output accumulator)
524+ // - cb_prev_max: local M (max of logits)
525+ // - cb_prev_sum: local L (sum of exp)
526+ // Only receive from children that actually have data
527+ if (actual_num_children > 0 ) {
528+ // Iterate through each round and receive from child if one exists AND has data
529+ for (uint32_t round = 0 ; round < actual_my_active_rounds; ++round) {
530+ uint32_t child_id = actual_children_per_round[round];
531+ if (child_id != UINT32_MAX) {
532+ // Writer kernel handles the semaphore wait and data transfer to cb_m_in, cb_l_in, cb_out_o
533+ // Data arrives in order: m, l, o
534+
535+ // Combine child with existing local/accumulated data
536+ // Move child's L to cb_prev_sum_2 for correction
461537 move_block<true >(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
462538
463539 // Fused Softmax Correction
@@ -468,10 +544,9 @@ void kernel_main() {
468544 // * 4. EXP_MAX_DIFF = exp((PREV_MAX - CUR_MAX)*scale)
469545 // * 5. PREV_SUM *= EXP_MAX_DIFF
470546 // * 6. CUR_SUM = PREV_SUM_2 + PREV_SUM
471- // */
472547 correction_block<scale_fp32, vector_mode>(
473- cb_m_in, // cb worker max
474- cb_prev_sum_2, // cb worker sum
548+ cb_m_in, // cb child max
549+ cb_prev_sum_2, // cb child sum
475550 cb_cur_max,
476551 cb_prev_max,
477552 cb_cur_sum,
@@ -480,17 +555,18 @@ void kernel_main() {
480555 cb_exp_max_diff_2,
481556 Sq_chunk_t);
482557
483- // OUT_ACC_2 <- WORKER_OUT
558+ // OUT_ACC_2 <- CHILD_OUT
484559 move_block<true >(cb_out_o, cb_out_accumulate_im_2, out_chunk_tiles);
485560
486- // OUT_ACC_2 *= EXP_MAX_DIFF
487- // OUT_ACC *= EXP_MAX_DIFF_2
561+ // OUT_ACC *= EXP_MAX_DIFF (scale local accumulator)
562+ // OUT_ACC_2 *= EXP_MAX_DIFF_2 (scale child's accumulator)
488563 mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_exp_max_diff);
489564 mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im_2, cb_exp_max_diff_2);
490565
491566 // OUT_ACC = OUT_ACC + OUT_ACC_2
492567 add_block_inplace<true >(cb_out_accumulate_im, cb_out_accumulate_im_2, out_chunk_tiles);
493568
569+ // Update prev buffers for next round
494570 // PREV_MAX <- CUR_MAX
495571 // PREV_SUM <- CUR_SUM
496572 cb_pop_front (cb_prev_max, Sq_chunk_t);
@@ -499,43 +575,57 @@ void kernel_main() {
499575 move_block<true >(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
500576 }
501577 }
578+ }
579+
580+ // Finalize output based on tree role
581+ if (is_tree_root) {
582+ // Root node: perform final normalization and output
583+ // Determine which sum/max buffer to use based on whether we did tree reduction
584+ // If we had children with data, results are in cb_prev_sum/cb_prev_max after tree reduction
585+ // If single core (no children with data), results are in cb_cur_sum/cb_cur_max from FA loop
502586
503- /* CUR_SUM = 1.0 / CUR_SUM */
504- cb_push_back (cb_cur_sum, Sq_chunk_t);
505- reconfig_data_format (cb_cur_sum, cb_cur_sum);
506- pack_reconfig_data_format (cb_cur_sum);
587+ // Select the correct sum buffer based on whether tree reduction happened
588+ // If tree reduction happened, sum is in cb_prev_sum; otherwise it's in cb_cur_sum
589+ uint32_t sum_cb = (actual_num_children > 0 ) ? cb_prev_sum : cb_cur_sum;
590+
591+ /* SUM = 1.0 / SUM */
592+ reconfig_data_format (sum_cb, sum_cb);
593+ pack_reconfig_data_format (sum_cb);
507594
508595 // Handle attention sink here
509596 if constexpr (use_attention_sink) {
597+ // Use appropriate max buffer based on tree reduction
598+ uint32_t max_cb_for_sink = (actual_num_children > 0 ) ? cb_prev_max : cb_cur_max;
599+
510600 // m_new
511- max_block<vector_mode>(cb_attention_sink, cb_prev_max , cb_cur_max, Sq_chunk_t);
601+ max_block<vector_mode>(cb_attention_sink, max_cb_for_sink , cb_cur_max, Sq_chunk_t);
512602
513603 // exp(m - m_new)
514- sub_exp_block<scale_fp32>(cb_prev_max , cb_cur_max, cb_exp_max_diff, Sq_chunk_t);
604+ sub_exp_block<scale_fp32>(max_cb_for_sink , cb_cur_max, cb_exp_max_diff, Sq_chunk_t);
515605
516606 // l -> l * exp(m - m_new)
517- mul_block_inplace (cb_cur_sum , cb_exp_max_diff, Sq_chunk_t);
607+ mul_block_inplace (sum_cb , cb_exp_max_diff, Sq_chunk_t);
518608
519609 // exp(sink - m_new)
520610 sub_exp_block<scale_fp32>(cb_attention_sink, cb_cur_max, cb_exp_max_diff_2, Sq_chunk_t);
521611 cb_pop_front (cb_cur_max, Sq_chunk_t);
522612
523613 // l -> l + exp(sink - m_new)
524- add_block_inplace<true >(cb_cur_sum , cb_exp_max_diff_2, Sq_chunk_t);
614+ add_block_inplace<true >(sum_cb , cb_exp_max_diff_2, Sq_chunk_t);
525615
526616 // O -> O * exp(m - m_new)
527617 mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_exp_max_diff);
528618 }
529619
530- reconfig_data_format (cb_cur_sum, cb_cur_sum );
531- pack_reconfig_data_format (cb_cur_sum );
532- recip_block_inplace (cb_cur_sum , Sq_chunk_t);
620+ reconfig_data_format (sum_cb, sum_cb );
621+ pack_reconfig_data_format (sum_cb );
622+ recip_block_inplace (sum_cb , Sq_chunk_t);
533623
534- /* OUT_ACC *= CUR_SUM */
535- reconfig_data_format (cb_out_accumulate_im, cb_cur_sum );
624+ /* OUT_ACC *= 1/SUM */
625+ reconfig_data_format (cb_out_accumulate_im, sum_cb );
536626 pack_reconfig_data_format (cb_out_accumulate_im);
537627
538- mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_cur_sum );
628+ mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, sum_cb );
539629 pack_reconfig_data_format (cb_out_final);
540630
541631 // Untilize output to ROW MAJOR if input Q was also ROW MAJOR
@@ -564,12 +654,31 @@ void kernel_main() {
564654 // Move output to buffer for the writer
565655 move_block<true >(cb_out_accumulate_im, cb_out_final, out_chunk_tiles);
566656 }
567- // Free up cb_prev_max after K chunks
568- cb_pop_front (cb_prev_max, Sq_chunk_t);
569- cb_pop_front (cb_prev_sum, Sq_chunk_t);
657+
658+ // Free up prev buffers if we used them
659+ if (actual_num_children > 0 ) {
660+ cb_pop_front (cb_prev_max, Sq_chunk_t);
661+ }
662+
663+ } else if (should_send_to_parent) {
664+ // Non-root node with parent: send intermediate results
665+ // We have data (checked at function start), so send it
666+ // After tree reduction (if any), results are in:
667+ // - cb_out_accumulate_im: O
668+ // - cb_prev_max: M
669+ // - cb_prev_sum: L
670+
671+ // Move O to output CB
672+ move_block<true >(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
673+ // Move M to output CB
674+ move_block<true >(cb_prev_max, cb_out_m, Sq_chunk_t);
675+ // Move L to output CB
676+ move_block<true >(cb_prev_sum, cb_out_l, Sq_chunk_t);
570677 }
571678 }
572679
573- // Free up cb_q_in after Q chunks
574- cb_pop_front (cb_q_in, q_chunk_tiles);
680+ // Free up cb_q_in after Q chunks (only if we had local data)
681+ if (has_local_data) {
682+ cb_pop_front (cb_q_in, q_chunk_tiles);
683+ }
575684}
0 commit comments