@@ -61,6 +61,7 @@ void kernel_main() {
6161 constexpr bool use_half_tile = get_compile_time_arg_val (27 );
6262 constexpr uint32_t scale_fp32 = get_compile_time_arg_val (28 );
6363 constexpr uint32_t sliding_window_size = get_compile_time_arg_val (29 );
64+ constexpr uint32_t num_tree_reduction_rounds = get_compile_time_arg_val (30 );
6465
6566 constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt;
6667 constexpr uint32_t out_chunk_tiles = Sq_chunk_t * vDHt;
@@ -109,6 +110,19 @@ void kernel_main() {
109110 const uint32_t core_num_in_output = get_arg_val<uint32_t >(arg_idx++);
110111 const uint32_t cur_pos_arg = get_arg_val<uint32_t >(arg_idx++);
111112
113+ // Tree reduction runtime arguments
114+ const bool is_tree_root = get_arg_val<uint32_t >(arg_idx++) == 1 ;
115+ const uint32_t parent_core_in_group = get_arg_val<uint32_t >(arg_idx++);
116+ const uint32_t send_at_round = get_arg_val<uint32_t >(arg_idx++);
117+ const uint32_t num_children = get_arg_val<uint32_t >(arg_idx++);
118+ const uint32_t my_active_rounds = get_arg_val<uint32_t >(arg_idx++);
119+
120+ // Read children_per_round array
121+ uint32_t children_per_round[MAX_TREE_REDUCTION_ROUNDS];
122+ for (uint32_t r = 0 ; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
123+ children_per_round[r] = get_arg_val<uint32_t >(arg_idx++);
124+ }
125+
112126 // Idle core
113127 // get_arg_val<uint32_t>(0) can go from 0-63 for the core_num; for active cores 65 is out of range so 65 indicates
114128 // an idle_core
@@ -280,6 +294,7 @@ void kernel_main() {
280294 bool add_mask_fusion = false ;
281295 bool add_sliding_window_mask_fusion = false ;
282296#endif
297+ DPRINT << " doing fa main loop" << ENDL ();
283298
284299 /* QK = Q_CHUNK @ K_CHUNK */
285300 // Determine which mask buffer to use for fusion
@@ -399,7 +414,7 @@ void kernel_main() {
399414 cb_out_mm = cb_out_im;
400415 } else {
401416 // When there is more than 1 chunk, we perform Lazy Softmax
402-
417+ DPRINT << " doing local softmax " << ENDL ();
403418 // Reconfig register DF
404419 reconfig_data_format (cb_prev_max, cb_cur_max);
405420 pack_reconfig_data_format (cb_exp_max_diff);
@@ -428,38 +443,63 @@ void kernel_main() {
428443 add_block_inplace<true >(cb_out_accumulate_im, cb_out_im, out_chunk_tiles);
429444 }
430445
431- if (k_chunk < k_chunk_end - 1 || do_reduce) {
432- // Move intermediate sum and max values to appropriate ping pong buffers
433- reconfig_data_format (cb_cur_max, cb_cur_max);
434- pack_reconfig_data_format (cb_prev_max);
435-
436- // PREV_MAX <- CUR_MAX
437- move_block<true >(cb_cur_max, cb_prev_max, Sq_chunk_t);
438-
439- // PREV_SUM <- CUR_SUM
440- move_block<true >(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
441- } else {
442- // Write results OUT_ACC, CUR_MAX, CUR_SUM to designated
443- // Write o, m, l into cb_out
444- move_block<true >(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
445- move_block<true >(cb_cur_max, cb_out_m, Sq_chunk_t);
446- move_block<true >(cb_cur_sum, cb_out_l, Sq_chunk_t);
447- }
446+ // Move intermediate sum and max values to appropriate ping pong buffers
447+ reconfig_data_format (cb_cur_max, cb_cur_max);
448+ pack_reconfig_data_format (cb_prev_max);
449+
450+ // PREV_MAX <- CUR_MAX
451+ move_block<true >(cb_cur_max, cb_prev_max, Sq_chunk_t);
452+
453+ // PREV_SUM <- CUR_SUM
454+ move_block<true >(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
455+
456+ // After this point:
457+ // cb_out_accumulate_im contains o_1
458+ // cb_prev_max contains m_1
459+ // cb_prev_sum contains l_1
460+
461+ // else {
462+ // DPRINT << "local move for tree reduction root" << ENDL();
463+ // // Write results OUT_ACC, CUR_MAX, CUR_SUM to designated
464+ // // Write o, m, l into cb_out
465+ // move_block<true>(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
466+ // move_block<true>(cb_cur_max, cb_out_m, Sq_chunk_t);
467+ // move_block<true>(cb_cur_sum, cb_out_l, Sq_chunk_t);
468+ // }
448469 }
449470 }
450471 /* END OF FLASH ATTENTION LOOP */
451- // Perform reduction across intermediates from other cores if this is the reduction core
452- if (do_reduce) {
453- // cb_out_accumulate_im should contain o_1 (output from FA of itself's core)
454- // cb_prev_max and cb_prev_sum should contain m_1 and l_1 (max and sum of logits of itself's core)
455-
456- if (k_chunk_end - k_chunk_start < k_num_chunks) {
457- // This indicates that there are computes done by other workers.
458- // We need to wait for them and send to reducer's compute
459- // Iterate through each worker
460- for (uint32_t i = 0 ; i < num_cores_to_wait; i++) {
461- move_block<true >(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
462472
473+ /* *****************************************************************************
474+ * TREE REDUCTION LOGIC *
475+ ******************************************************************************/
476+ /* *
477+ * Tree reduction replaces the flat worker->reducer pattern with O(log n) rounds.
478+ *
479+ * For each round r (0 to my_active_rounds-1):
480+ * - If children_per_round[r] != UINT32_MAX, receive from that child
481+ * - Combine received data with local accumulator using softmax correction
482+ *
483+ * After all receives:
484+ * - If is_tree_root: finalize (1/sum normalization) and output
485+ * - Else: output intermediate results for writer to send to parent
486+ */
487+ DPRINT << " doing tree reduction" << ENDL ();
488+
489+ // Tree reduction: receive from children and combine
490+ if (num_children > 0 && k_chunk_end - k_chunk_start < k_num_chunks) {
491+ // cb_out_accumulate_im should contain o_1 (output from FA of this core)
492+ // cb_prev_max and cb_prev_sum should contain m_1 and l_1 (max and sum of logits of this core)
493+
494+ // Iterate through each round and receive from child if one exists
495+ for (uint32_t round = 0 ; round < my_active_rounds; ++round) {
496+ DPRINT << " doing tree reduction round " << round << ENDL ();
497+ uint32_t child_id = children_per_round[round];
498+ if (child_id != UINT32_MAX) {
499+ DPRINT << " doing tree reduction child " << child_id << ENDL ();
500+ // Writer kernel handles the wait and data transfer to cb_m_in, cb_l_in, cb_out_o
501+ move_block<true >(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
502+ DPRINT << " moved child sum to prev_sum_2" << ENDL ();
463503 // Fused Softmax Correction
464504 // * Fused Correction is a fused operation that performs the following steps:
465505 // * 1. CUR_MAX = max(PREV_MAX, WORKER_MAX)
@@ -468,10 +508,9 @@ void kernel_main() {
468508 // * 4. EXP_MAX_DIFF = exp((PREV_MAX - CUR_MAX)*scale)
469509 // * 5. PREV_SUM *= EXP_MAX_DIFF
470510 // * 6. CUR_SUM = PREV_SUM_2 + PREV_SUM
471- // */
472511 correction_block<scale_fp32, vector_mode>(
473- cb_m_in, // cb worker max
474- cb_prev_sum_2, // cb worker sum
512+ cb_m_in, // cb child max
513+ cb_prev_sum_2, // cb child sum
475514 cb_cur_max,
476515 cb_prev_max,
477516 cb_cur_sum,
@@ -480,11 +519,12 @@ void kernel_main() {
480519 cb_exp_max_diff_2,
481520 Sq_chunk_t);
482521
483- // OUT_ACC_2 <- WORKER_OUT
522+ DPRINT << " done correction" << ENDL ();
523+ // OUT_ACC_2 <- CHILD_OUT
484524 move_block<true >(cb_out_o, cb_out_accumulate_im_2, out_chunk_tiles);
485525
486- // OUT_ACC_2 *= EXP_MAX_DIFF
487- // OUT_ACC *= EXP_MAX_DIFF_2
526+ // OUT_ACC *= EXP_MAX_DIFF (scale local accumulator)
527+ // OUT_ACC_2 *= EXP_MAX_DIFF_2 (scale child's accumulator)
488528 mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_exp_max_diff);
489529 mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im_2, cb_exp_max_diff_2);
490530
@@ -497,9 +537,15 @@ void kernel_main() {
497537 cb_pop_front (cb_m_in, Sq_chunk_t);
498538 move_block<true >(cb_cur_max, cb_prev_max, Sq_chunk_t);
499539 move_block<true >(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
540+ DPRINT << " moved cur_max and cur_sum to prev_max and prev_sum" << ENDL ();
500541 }
501542 }
543+ }
502544
545+ // Finalize output based on tree role
546+ if (is_tree_root) {
547+ // Root node: perform final normalization and output
548+ DPRINT << " doing tree reduction root" << ENDL ();
503549 /* CUR_SUM = 1.0 / CUR_SUM */
504550 cb_push_back (cb_cur_sum, Sq_chunk_t);
505551 reconfig_data_format (cb_cur_sum, cb_cur_sum);
@@ -567,6 +613,18 @@ void kernel_main() {
567613 // Free up cb_prev_max after K chunks
568614 cb_pop_front (cb_prev_max, Sq_chunk_t);
569615 cb_pop_front (cb_prev_sum, Sq_chunk_t);
616+ DPRINT << " root done math" << ENDL ();
617+ } else if (parent_core_in_group != UINT32_MAX) {
618+ // Non-root node: output intermediate results for writer to send to parent
619+ // Writer will read from cb_out_worker (cb_out_o), cb_out_m, cb_out_l
620+ DPRINT << " doing tree reduction non-root" << ENDL ();
621+ move_block<true >(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
622+ DPRINT << " moved out im to out_o" << ENDL ();
623+ move_block<true >(cb_prev_max, cb_out_m, Sq_chunk_t);
624+ DPRINT << " moved prev_max to out_m" << ENDL ();
625+ move_block<true >(cb_prev_sum, cb_out_l, Sq_chunk_t);
626+ DPRINT << " moved prev_sum to out_l" << ENDL ();
627+ DPRINT << " non-root done math" << ENDL ();
570628 }
571629 }
572630
0 commit comments