Skip to content

Commit aafa675

Browse files
committed
tree reduce not hanging no more
1 parent 3d6d9a0 commit aafa675

File tree

5 files changed

+648
-94
lines changed

5 files changed

+648
-94
lines changed

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp

Lines changed: 93 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ void kernel_main() {
6161
constexpr bool use_half_tile = get_compile_time_arg_val(27);
6262
constexpr uint32_t scale_fp32 = get_compile_time_arg_val(28);
6363
constexpr uint32_t sliding_window_size = get_compile_time_arg_val(29);
64+
constexpr uint32_t num_tree_reduction_rounds = get_compile_time_arg_val(30);
6465

6566
constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt;
6667
constexpr uint32_t out_chunk_tiles = Sq_chunk_t * vDHt;
@@ -109,6 +110,19 @@ void kernel_main() {
109110
const uint32_t core_num_in_output = get_arg_val<uint32_t>(arg_idx++);
110111
const uint32_t cur_pos_arg = get_arg_val<uint32_t>(arg_idx++);
111112

113+
// Tree reduction runtime arguments
114+
const bool is_tree_root = get_arg_val<uint32_t>(arg_idx++) == 1;
115+
const uint32_t parent_core_in_group = get_arg_val<uint32_t>(arg_idx++);
116+
const uint32_t send_at_round = get_arg_val<uint32_t>(arg_idx++);
117+
const uint32_t num_children = get_arg_val<uint32_t>(arg_idx++);
118+
const uint32_t my_active_rounds = get_arg_val<uint32_t>(arg_idx++);
119+
120+
// Read children_per_round array
121+
uint32_t children_per_round[MAX_TREE_REDUCTION_ROUNDS];
122+
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
123+
children_per_round[r] = get_arg_val<uint32_t>(arg_idx++);
124+
}
125+
112126
// Idle core
113127
// get_arg_val<uint32_t>(0) can go from 0-63 for the core_num; for active cores 65 is out of range so 65 indicates
114128
// an idle_core
@@ -280,6 +294,7 @@ void kernel_main() {
280294
bool add_mask_fusion = false;
281295
bool add_sliding_window_mask_fusion = false;
282296
#endif
297+
DPRINT << "doing fa main loop" << ENDL();
283298

284299
/* QK = Q_CHUNK @ K_CHUNK */
285300
// Determine which mask buffer to use for fusion
@@ -399,7 +414,7 @@ void kernel_main() {
399414
cb_out_mm = cb_out_im;
400415
} else {
401416
// When there is more than 1 chunk, we perform Lazy Softmax
402-
417+
DPRINT << "doing local softmax" << ENDL();
403418
// Reconfig register DF
404419
reconfig_data_format(cb_prev_max, cb_cur_max);
405420
pack_reconfig_data_format(cb_exp_max_diff);
@@ -428,38 +443,63 @@ void kernel_main() {
428443
add_block_inplace<true>(cb_out_accumulate_im, cb_out_im, out_chunk_tiles);
429444
}
430445

431-
if (k_chunk < k_chunk_end - 1 || do_reduce) {
432-
// Move intermediate sum and max values to appropriate ping pong buffers
433-
reconfig_data_format(cb_cur_max, cb_cur_max);
434-
pack_reconfig_data_format(cb_prev_max);
435-
436-
// PREV_MAX <- CUR_MAX
437-
move_block<true>(cb_cur_max, cb_prev_max, Sq_chunk_t);
438-
439-
// PREV_SUM <- CUR_SUM
440-
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
441-
} else {
442-
// Write results OUT_ACC, CUR_MAX, CUR_SUM to designated
443-
// Write o, m, l into cb_out
444-
move_block<true>(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
445-
move_block<true>(cb_cur_max, cb_out_m, Sq_chunk_t);
446-
move_block<true>(cb_cur_sum, cb_out_l, Sq_chunk_t);
447-
}
446+
// Move intermediate sum and max values to appropriate ping pong buffers
447+
reconfig_data_format(cb_cur_max, cb_cur_max);
448+
pack_reconfig_data_format(cb_prev_max);
449+
450+
// PREV_MAX <- CUR_MAX
451+
move_block<true>(cb_cur_max, cb_prev_max, Sq_chunk_t);
452+
453+
// PREV_SUM <- CUR_SUM
454+
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
455+
456+
// After this point:
457+
// cb_out_accumulate_im contains o_1
458+
// cb_prev_max contains m_1
459+
// cb_prev_sum contains l_1
460+
461+
// else {
462+
// DPRINT << "local move for tree reduction root" << ENDL();
463+
// // Write results OUT_ACC, CUR_MAX, CUR_SUM to designated
464+
// // Write o, m, l into cb_out
465+
// move_block<true>(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
466+
// move_block<true>(cb_cur_max, cb_out_m, Sq_chunk_t);
467+
// move_block<true>(cb_cur_sum, cb_out_l, Sq_chunk_t);
468+
// }
448469
}
449470
}
450471
/* END OF FLASH ATTENTION LOOP */
451-
// Perform reduction across intermediates from other cores if this is the reduction core
452-
if (do_reduce) {
453-
// cb_out_accumulate_im should contain o_1 (output from FA of itself's core)
454-
// cb_prev_max and cb_prev_sum should contain m_1 and l_1 (max and sum of logits of itself's core)
455-
456-
if (k_chunk_end - k_chunk_start < k_num_chunks) {
457-
// This indicates that there are computes done by other workers.
458-
// We need to wait for them and send to reducer's compute
459-
// Iterate through each worker
460-
for (uint32_t i = 0; i < num_cores_to_wait; i++) {
461-
move_block<true>(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
462472

473+
/******************************************************************************
474+
* TREE REDUCTION LOGIC *
475+
******************************************************************************/
476+
/**
477+
* Tree reduction replaces the flat worker->reducer pattern with O(log n) rounds.
478+
*
479+
* For each round r (0 to my_active_rounds-1):
480+
* - If children_per_round[r] != UINT32_MAX, receive from that child
481+
* - Combine received data with local accumulator using softmax correction
482+
*
483+
* After all receives:
484+
* - If is_tree_root: finalize (1/sum normalization) and output
485+
* - Else: output intermediate results for writer to send to parent
486+
*/
487+
DPRINT << "doing tree reduction" << ENDL();
488+
489+
// Tree reduction: receive from children and combine
490+
if (num_children > 0 && k_chunk_end - k_chunk_start < k_num_chunks) {
491+
// cb_out_accumulate_im should contain o_1 (output from FA of this core)
492+
// cb_prev_max and cb_prev_sum should contain m_1 and l_1 (max and sum of logits of this core)
493+
494+
// Iterate through each round and receive from child if one exists
495+
for (uint32_t round = 0; round < my_active_rounds; ++round) {
496+
DPRINT << "doing tree reduction round " << round << ENDL();
497+
uint32_t child_id = children_per_round[round];
498+
if (child_id != UINT32_MAX) {
499+
DPRINT << "doing tree reduction child " << child_id << ENDL();
500+
// Writer kernel handles the wait and data transfer to cb_m_in, cb_l_in, cb_out_o
501+
move_block<true>(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
502+
DPRINT << "moved child sum to prev_sum_2" << ENDL();
463503
// Fused Softmax Correction
464504
// * Fused Correction is a fused operation that performs the following steps:
465505
// * 1. CUR_MAX = max(PREV_MAX, WORKER_MAX)
@@ -468,10 +508,9 @@ void kernel_main() {
468508
// * 4. EXP_MAX_DIFF = exp((PREV_MAX - CUR_MAX)*scale)
469509
// * 5. PREV_SUM *= EXP_MAX_DIFF
470510
// * 6. CUR_SUM = PREV_SUM_2 + PREV_SUM
471-
// */
472511
correction_block<scale_fp32, vector_mode>(
473-
cb_m_in, // cb worker max
474-
cb_prev_sum_2, // cb worker sum
512+
cb_m_in, // cb child max
513+
cb_prev_sum_2, // cb child sum
475514
cb_cur_max,
476515
cb_prev_max,
477516
cb_cur_sum,
@@ -480,11 +519,12 @@ void kernel_main() {
480519
cb_exp_max_diff_2,
481520
Sq_chunk_t);
482521

483-
// OUT_ACC_2 <- WORKER_OUT
522+
DPRINT << "done correction" << ENDL();
523+
// OUT_ACC_2 <- CHILD_OUT
484524
move_block<true>(cb_out_o, cb_out_accumulate_im_2, out_chunk_tiles);
485525

486-
// OUT_ACC_2 *= EXP_MAX_DIFF
487-
// OUT_ACC *= EXP_MAX_DIFF_2
526+
// OUT_ACC *= EXP_MAX_DIFF (scale local accumulator)
527+
// OUT_ACC_2 *= EXP_MAX_DIFF_2 (scale child's accumulator)
488528
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_exp_max_diff);
489529
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im_2, cb_exp_max_diff_2);
490530

@@ -497,9 +537,15 @@ void kernel_main() {
497537
cb_pop_front(cb_m_in, Sq_chunk_t);
498538
move_block<true>(cb_cur_max, cb_prev_max, Sq_chunk_t);
499539
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
540+
DPRINT << "moved cur_max and cur_sum to prev_max and prev_sum" << ENDL();
500541
}
501542
}
543+
}
502544

545+
// Finalize output based on tree role
546+
if (is_tree_root) {
547+
// Root node: perform final normalization and output
548+
DPRINT << "doing tree reduction root" << ENDL();
503549
/* CUR_SUM = 1.0 / CUR_SUM */
504550
cb_push_back(cb_cur_sum, Sq_chunk_t);
505551
reconfig_data_format(cb_cur_sum, cb_cur_sum);
@@ -567,6 +613,18 @@ void kernel_main() {
567613
// Free up cb_prev_max after K chunks
568614
cb_pop_front(cb_prev_max, Sq_chunk_t);
569615
cb_pop_front(cb_prev_sum, Sq_chunk_t);
616+
DPRINT << "root done math" << ENDL();
617+
} else if (parent_core_in_group != UINT32_MAX) {
618+
// Non-root node: output intermediate results for writer to send to parent
619+
// Writer will read from cb_out_worker (cb_out_o), cb_out_m, cb_out_l
620+
DPRINT << "doing tree reduction non-root" << ENDL();
621+
move_block<true>(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
622+
DPRINT << "moved out im to out_o" << ENDL();
623+
move_block<true>(cb_prev_max, cb_out_m, Sq_chunk_t);
624+
DPRINT << "moved prev_max to out_m" << ENDL();
625+
move_block<true>(cb_prev_sum, cb_out_l, Sq_chunk_t);
626+
DPRINT << "moved prev_sum to out_l" << ENDL();
627+
DPRINT << "non-root done math" << ENDL();
570628
}
571629
}
572630

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,112 @@ void worker_compute(
402402
cb_pop_front(cb_out_l, PNHt);
403403
}
404404

405+
/******************************************************************************
406+
* Tree Reduction Worker Functions *
407+
******************************************************************************/
408+
409+
/**
410+
* Tree reduction send function: sends intermediate results (o, m, l) to parent core
411+
*
412+
* @param parent_noc_x Physical X coordinate of parent core
413+
* @param parent_noc_y Physical Y coordinate of parent core
414+
* @param semaphore_addr Local semaphore address
415+
* @param round The current round in tree reduction (determines write offset)
416+
*/
417+
template <
418+
uint32_t out_chunk_tiles,
419+
uint32_t cb_out,
420+
uint32_t cb_out_m,
421+
uint32_t cb_out_l,
422+
uint32_t cb_intermed_out,
423+
uint32_t PNHt>
424+
void tree_reduction_send_to_parent(
425+
uint32_t parent_noc_x, uint32_t parent_noc_y, uint32_t semaphore_addr, uint32_t round) {
426+
// Wait for compute to deliver output chunk
427+
DPRINT << "waiting for compute to deliver out and send" << ENDL();
428+
429+
cb_wait_front(cb_out, out_chunk_tiles);
430+
cb_wait_front(cb_out_m, PNHt);
431+
cb_wait_front(cb_out_l, PNHt);
432+
DPRINT << "compute produced data, gonna send" << ENDL();
433+
// In tree reduction, each round has a specific offset in the parent's intermediate buffer
434+
// Round 0 children write at offset 0, round 1 children write at offset 1, etc.
435+
constexpr uint32_t tile_bytes = get_tile_size(cb_out);
436+
uint32_t block_offset = round * (out_chunk_tiles + 2 * PNHt) * tile_bytes;
437+
constexpr uint32_t o_write_size = out_chunk_tiles * tile_bytes;
438+
constexpr uint32_t ml_write_size = PNHt * tile_bytes;
439+
440+
uint64_t output_write_addr =
441+
get_noc_addr(parent_noc_x, parent_noc_y, get_write_ptr(cb_intermed_out)) + block_offset;
442+
443+
// send m, l, o to parent (same order as original worker_compute)
444+
noc_async_write(get_read_ptr(cb_out_m), output_write_addr, ml_write_size);
445+
output_write_addr += ml_write_size;
446+
noc_async_write(get_read_ptr(cb_out_l), output_write_addr, ml_write_size);
447+
output_write_addr += ml_write_size;
448+
noc_async_write(get_read_ptr(cb_out), output_write_addr, o_write_size);
449+
450+
// increment parent's semaphore
451+
noc_async_write_barrier();
452+
uint64_t parent_semaphore_noc_addr = get_noc_addr(parent_noc_x, parent_noc_y, semaphore_addr);
453+
noc_semaphore_inc(parent_semaphore_noc_addr, 1);
454+
DPRINT << "incremented parent sem" << ENDL();
455+
// pop front
456+
cb_pop_front(cb_out, out_chunk_tiles);
457+
cb_pop_front(cb_out_m, PNHt);
458+
cb_pop_front(cb_out_l, PNHt);
459+
DPRINT << "sent to parent" << ENDL();
460+
}
461+
462+
/**
463+
* Tree reduction receive function: receives intermediate results from a child core
464+
* Data is read from this core's intermediate buffer and pushed to compute CBs
465+
*
466+
* @param round The round from which the child is sending (determines read offset)
467+
*/
468+
template <
469+
uint32_t out_chunk_tiles,
470+
uint32_t cb_out_o,
471+
uint32_t cb_m_in,
472+
uint32_t cb_l_in,
473+
uint32_t cb_intermed_out,
474+
uint32_t PNHt>
475+
void tree_reduction_receive_from_child(uint32_t round) {
476+
constexpr uint32_t tile_bytes_intermed = get_tile_size(cb_intermed_out);
477+
constexpr uint32_t o_read_size = out_chunk_tiles * tile_bytes_intermed;
478+
constexpr uint32_t ml_read_size = PNHt * tile_bytes_intermed;
479+
480+
// Calculate offset based on round
481+
uint32_t block_offset = round * (out_chunk_tiles + 2 * PNHt) * tile_bytes_intermed;
482+
uint64_t intermed_l1_read_addr = get_noc_addr(get_read_ptr(cb_intermed_out)) + block_offset;
483+
484+
// Reserve and read m, l, o (same order as send)
485+
DPRINT << "reserving mlo1 for round " << round << ENDL();
486+
cb_reserve_back(cb_m_in, PNHt);
487+
DPRINT << "reserving mlo2 for round " << round << ENDL();
488+
cb_reserve_back(cb_l_in, PNHt);
489+
DPRINT << "reserving mlo3 for round " << round << ENDL();
490+
cb_reserve_back(cb_out_o, out_chunk_tiles);
491+
DPRINT << "reserved mlo for round " << round << ENDL();
492+
uint32_t m_write_ptr = get_read_ptr(cb_m_in);
493+
noc_async_read(intermed_l1_read_addr, m_write_ptr, ml_read_size);
494+
intermed_l1_read_addr += ml_read_size;
495+
noc_async_read_barrier();
496+
cb_push_back(cb_m_in, PNHt);
497+
498+
uint32_t l_write_ptr = get_read_ptr(cb_l_in);
499+
noc_async_read(intermed_l1_read_addr, l_write_ptr, ml_read_size);
500+
intermed_l1_read_addr += ml_read_size;
501+
noc_async_read_barrier();
502+
cb_push_back(cb_l_in, PNHt);
503+
504+
uint32_t o_write_ptr = get_read_ptr(cb_out_o);
505+
noc_async_read(intermed_l1_read_addr, o_write_ptr, o_read_size);
506+
noc_async_read_barrier();
507+
cb_push_back(cb_out_o, out_chunk_tiles);
508+
DPRINT << "data is ready for another reduction" << round << ENDL();
509+
}
510+
405511
template <uint32_t cb_out, uint32_t out_chunk_tiles, uint32_t barrier_threshold, typename WriterType>
406512
uint32_t write_tiles_to_memory(uint32_t& out_tile_id, const WriterType& out_writer, uint32_t& barrier_count) {
407513
constexpr uint32_t tile_bytes = get_tile_size(cb_out);

0 commit comments

Comments
 (0)