Skip to content

Commit 175c336

Browse files
committed
tree reduce working lfg
1 parent 3d6d9a0 commit 175c336

File tree

4 files changed

+566
-128
lines changed

4 files changed

+566
-128
lines changed

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp

Lines changed: 173 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#define REDUCE_OP (PoolType::MAX)
88
#define REDUCE_DIM (ReduceDim::REDUCE_ROW)
9+
#define MAX_TREE_REDUCTION_ROUNDS 6
910

1011
#include "compute_kernel_api.h"
1112
#include "compute_kernel_api/eltwise_binary.h"
@@ -61,6 +62,7 @@ void kernel_main() {
6162
constexpr bool use_half_tile = get_compile_time_arg_val(27);
6263
constexpr uint32_t scale_fp32 = get_compile_time_arg_val(28);
6364
constexpr uint32_t sliding_window_size = get_compile_time_arg_val(29);
65+
constexpr uint32_t num_tree_reduction_rounds = get_compile_time_arg_val(30);
6466

6567
constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt;
6668
constexpr uint32_t out_chunk_tiles = Sq_chunk_t * vDHt;
@@ -109,6 +111,20 @@ void kernel_main() {
109111
const uint32_t core_num_in_output = get_arg_val<uint32_t>(arg_idx++);
110112
const uint32_t cur_pos_arg = get_arg_val<uint32_t>(arg_idx++);
111113

114+
// Tree reduction runtime arguments
115+
const bool is_tree_root = get_arg_val<uint32_t>(arg_idx++) == 1;
116+
const uint32_t parent_core_in_group = get_arg_val<uint32_t>(arg_idx++);
117+
const uint32_t send_at_round = get_arg_val<uint32_t>(arg_idx++);
118+
const uint32_t num_children = get_arg_val<uint32_t>(arg_idx++);
119+
const uint32_t my_active_rounds = get_arg_val<uint32_t>(arg_idx++);
120+
const bool has_parent = parent_core_in_group != UINT32_MAX;
121+
122+
// Read children_per_round array
123+
uint32_t children_per_round[MAX_TREE_REDUCTION_ROUNDS];
124+
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
125+
children_per_round[r] = get_arg_val<uint32_t>(arg_idx++);
126+
}
127+
112128
// Idle core
113129
// get_arg_val<uint32_t>(0) can go from 0-63 for the core_num; for active cores 65 is out of range so 65 indicates
114130
// an idle_core
@@ -130,7 +146,7 @@ void kernel_main() {
130146

131147
cb_wait_front(cb_index_id, 1);
132148
cur_pos = read_tile_value(cb_index_id, 0, cur_batch / q_heads_parallel_factor);
133-
cb_pop_front(cb_index_id, 1);
149+
// cb_pop_front(cb_index_id, 1);
134150
}
135151
if (cur_pos == UINT32_MAX) {
136152
// cur_pos of -1 indicates that the user should be skipped
@@ -150,31 +166,64 @@ void kernel_main() {
150166
num_cores_per_head,
151167
k_chunk_size_dynamic,
152168
sliding_window_size > 0 ? std::optional<uint32_t>(sliding_window_size) : std::nullopt);
153-
if (k_chunk_start == k_chunk_end) {
154-
return; // early exit because no computes needs to be done
169+
170+
// Check if this core has local data to process
171+
const bool has_local_data = (k_chunk_start != k_chunk_end);
172+
173+
// Cores without data don't participate in tree reduction at all
174+
// They just exit early - no sending, no receiving
175+
if (!has_local_data) {
176+
return;
177+
}
178+
179+
// Determine which children actually have data (based on chunk allocation)
180+
// A child at core_num has data if core_num < k_num_chunks
181+
uint32_t actual_num_children = 0;
182+
uint32_t actual_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
183+
uint32_t actual_my_active_rounds = 0;
184+
185+
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
186+
uint32_t child_id = children_per_round[r];
187+
if (child_id != UINT32_MAX && child_id < k_num_chunks) {
188+
// This child has data
189+
actual_children_per_round[r] = child_id;
190+
actual_num_children++;
191+
actual_my_active_rounds = r + 1;
192+
} else {
193+
actual_children_per_round[r] = UINT32_MAX;
194+
}
155195
}
156196

197+
// Determine if we have an actual parent (parent must have data too, but root always has data)
198+
// We only need to decide whether THIS core should send - the parent handles receiving
199+
// We send if we have a parent AND we have data (which we do if we reach here)
200+
const bool should_send_to_parent = has_parent;
201+
157202
// Get number of worker cores to wait for
158203
uint32_t num_cores_to_wait = num_cores_per_head - 1;
159204
if (num_cores_per_head > k_num_chunks) {
160205
num_cores_to_wait = k_num_chunks - 1;
161206
}
162207

163-
// We tilize input Q if it is in ROW MAJOR layout
164-
if constexpr (tilize_q) {
165-
compute_kernel_hw_startup(cb_q_rm, cb_q_in);
166-
tilize_init(cb_q_rm, q_chunk_tiles, cb_q_in);
167-
cb_wait_front(cb_q_rm, q_chunk_tiles);
168-
cb_reserve_back(cb_q_in, q_chunk_tiles);
169-
tilize_block(cb_q_rm, q_chunk_tiles, cb_q_in);
170-
tilize_uninit(cb_q_rm, cb_q_in);
171-
cb_push_back(cb_q_in, q_chunk_tiles);
172-
cb_pop_front(cb_q_rm, q_chunk_tiles);
173-
mm_init_short(cb_q_in, cb_k_in);
174-
} else {
175-
mm_init(cb_q_in, cb_k_in, cb_qk_im);
208+
// Only process Q and initialize matmul if we have local data
209+
// NOTE: cores without local data have already returned early above, so has_local_data is always true here
210+
if (has_local_data) {
211+
// We tilize input Q if it is in ROW MAJOR layout
212+
if constexpr (tilize_q) {
213+
compute_kernel_hw_startup(cb_q_rm, cb_q_in);
214+
tilize_init(cb_q_rm, q_chunk_tiles, cb_q_in);
215+
cb_wait_front(cb_q_rm, q_chunk_tiles);
216+
cb_reserve_back(cb_q_in, q_chunk_tiles);
217+
tilize_block(cb_q_rm, q_chunk_tiles, cb_q_in);
218+
tilize_uninit(cb_q_rm, cb_q_in);
219+
cb_push_back(cb_q_in, q_chunk_tiles);
220+
cb_pop_front(cb_q_rm, q_chunk_tiles);
221+
mm_init_short(cb_q_in, cb_k_in);
222+
} else {
223+
mm_init(cb_q_in, cb_k_in, cb_qk_im);
224+
}
225+
cb_wait_front(cb_q_in, q_chunk_tiles);
176226
}
177-
cb_wait_front(cb_q_in, q_chunk_tiles);
178227

179228
// Define dynamic matmul configs
180229
#ifdef DYNAMIC_CHUNK_SIZE
@@ -184,7 +233,6 @@ void kernel_main() {
184233
const uint32_t qk_in1_num_subblocks_dynamic = 1;
185234
const uint32_t out_in0_block_w_dynamic = Sk_chunk_t_dynamic;
186235
const uint32_t out_num_blocks_dynamic = 1;
187-
188236
const uint32_t qk_chunk_tiles_dynamic = Sq_chunk_t * Sk_chunk_t_dynamic;
189237
#else
190238
constexpr uint32_t qk_subblock_h_dynamic = qk_subblock_h;
@@ -193,7 +241,6 @@ void kernel_main() {
193241
constexpr uint32_t qk_in1_num_subblocks_dynamic = qk_in1_num_subblocks;
194242
constexpr uint32_t out_in0_block_w_dynamic = out_in0_block_w;
195243
constexpr uint32_t out_num_blocks_dynamic = out_num_blocks;
196-
197244
constexpr uint32_t qk_chunk_tiles_dynamic = Sq_chunk_t * Sk_chunk_t;
198245
#endif
199246

@@ -262,7 +309,8 @@ void kernel_main() {
262309
* @param out_chunk_tiles - Number of output chunk tiles
263310
*/
264311
/* START OF FLASH ATTENTION LOOP */
265-
{
312+
// Only run FA loop if this core has local data to process
313+
if (has_local_data) {
266314
uint32_t cb_out_mm = cb_out_accumulate_im;
267315

268316
// Loop through all K chunks
@@ -399,7 +447,6 @@ void kernel_main() {
399447
cb_out_mm = cb_out_im;
400448
} else {
401449
// When there is more than 1 chunk, we perform Lazy Softmax
402-
403450
// Reconfig register DF
404451
reconfig_data_format(cb_prev_max, cb_cur_max);
405452
pack_reconfig_data_format(cb_exp_max_diff);
@@ -428,8 +475,10 @@ void kernel_main() {
428475
add_block_inplace<true>(cb_out_accumulate_im, cb_out_im, out_chunk_tiles);
429476
}
430477

431-
if (k_chunk < k_chunk_end - 1 || do_reduce) {
432-
// Move intermediate sum and max values to appropriate ping pong buffers
478+
// Move intermediate sum and max values to appropriate ping pong buffers
479+
// Only done between local chunks; the last chunk's results stay in the cur buffers and are handled after the loop
480+
if (k_chunk < k_chunk_end - 1) {
481+
// More local chunks to process - move to ping-pong buffers
433482
reconfig_data_format(cb_cur_max, cb_cur_max);
434483
pack_reconfig_data_format(cb_prev_max);
435484

@@ -438,26 +487,53 @@ void kernel_main() {
438487

439488
// PREV_SUM <- CUR_SUM
440489
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
441-
} else {
442-
// Write results OUT_ACC, CUR_MAX, CUR_SUM to designated
443-
// Write o, m, l into cb_out
444-
move_block<true>(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
445-
move_block<true>(cb_cur_max, cb_out_m, Sq_chunk_t);
446-
move_block<true>(cb_cur_sum, cb_out_l, Sq_chunk_t);
447490
}
448491
}
492+
493+
// After FA loop completes, prepare buffers for tree reduction or output
494+
// Results are in: cb_out_accumulate_im (O), cb_cur_max (M), cb_cur_sum (L)
495+
if (actual_num_children > 0 || should_send_to_parent) {
496+
// Tree reduction will happen - move cur to prev buffers
497+
reconfig_data_format(cb_cur_max, cb_cur_max);
498+
pack_reconfig_data_format(cb_prev_max);
499+
move_block<true>(cb_cur_max, cb_prev_max, Sq_chunk_t);
500+
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
501+
}
502+
// If root with no children, keep in cur buffers for finalization
449503
}
450504
/* END OF FLASH ATTENTION LOOP */
451-
// Perform reduction across intermediates from other cores if this is the reduction core
452-
if (do_reduce) {
453-
// cb_out_accumulate_im should contain o_1 (output from FA of itself's core)
454-
// cb_prev_max and cb_prev_sum should contain m_1 and l_1 (max and sum of logits of itself's core)
455-
456-
if (k_chunk_end - k_chunk_start < k_num_chunks) {
457-
// This indicates that there are computes done by other workers.
458-
// We need to wait for them and send to reducer's compute
459-
// Iterate through each worker
460-
for (uint32_t i = 0; i < num_cores_to_wait; i++) {
505+
506+
/******************************************************************************
507+
* TREE REDUCTION LOGIC *
508+
******************************************************************************/
509+
/**
510+
* Tree reduction replaces the flat worker->reducer pattern with O(log n) rounds.
511+
*
512+
* For each round r (0 to actual_my_active_rounds-1):
513+
* - If actual_children_per_round[r] != UINT32_MAX (child exists and has data), receive from that child
514+
* - Combine received data with local accumulator using softmax correction
515+
*
516+
* After all receives:
517+
* - If is_tree_root: finalize (1/sum normalization) and output
518+
* - Else: output intermediate results for writer to send to parent
519+
*/
520+
521+
// Tree reduction: receive from children and combine
522+
// Buffer state entering tree reduction:
523+
// - cb_out_accumulate_im: local O (output accumulator)
524+
// - cb_prev_max: local M (max of logits)
525+
// - cb_prev_sum: local L (sum of exp)
526+
// Only receive from children that actually have data
527+
if (actual_num_children > 0) {
528+
// Iterate through each round and receive from child if one exists AND has data
529+
for (uint32_t round = 0; round < actual_my_active_rounds; ++round) {
530+
uint32_t child_id = actual_children_per_round[round];
531+
if (child_id != UINT32_MAX) {
532+
// Writer kernel handles the semaphore wait and data transfer to cb_m_in, cb_l_in, cb_out_o
533+
// Data arrives in order: m, l, o
534+
535+
// Combine child with existing local/accumulated data
536+
// Move child's L to cb_prev_sum_2 for correction
461537
move_block<true>(cb_l_in, cb_prev_sum_2, Sq_chunk_t);
462538

463539
// Fused Softmax Correction
@@ -468,10 +544,9 @@ void kernel_main() {
468544
// * 4. EXP_MAX_DIFF = exp((PREV_MAX - CUR_MAX)*scale)
469545
// * 5. PREV_SUM *= EXP_MAX_DIFF
470546
// * 6. CUR_SUM = PREV_SUM_2 + PREV_SUM
471-
// */
472547
correction_block<scale_fp32, vector_mode>(
473-
cb_m_in, // cb worker max
474-
cb_prev_sum_2, // cb worker sum
548+
cb_m_in, // cb child max
549+
cb_prev_sum_2, // cb child sum
475550
cb_cur_max,
476551
cb_prev_max,
477552
cb_cur_sum,
@@ -480,17 +555,18 @@ void kernel_main() {
480555
cb_exp_max_diff_2,
481556
Sq_chunk_t);
482557

483-
// OUT_ACC_2 <- WORKER_OUT
558+
// OUT_ACC_2 <- CHILD_OUT
484559
move_block<true>(cb_out_o, cb_out_accumulate_im_2, out_chunk_tiles);
485560

486-
// OUT_ACC_2 *= EXP_MAX_DIFF
487-
// OUT_ACC *= EXP_MAX_DIFF_2
561+
// OUT_ACC *= EXP_MAX_DIFF (scale local accumulator)
562+
// OUT_ACC_2 *= EXP_MAX_DIFF_2 (scale child's accumulator)
488563
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_exp_max_diff);
489564
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im_2, cb_exp_max_diff_2);
490565

491566
// OUT_ACC = OUT_ACC + OUT_ACC_2
492567
add_block_inplace<true>(cb_out_accumulate_im, cb_out_accumulate_im_2, out_chunk_tiles);
493568

569+
// Update prev buffers for next round
494570
// PREV_MAX <- CUR_MAX
495571
// PREV_SUM <- CUR_SUM
496572
cb_pop_front(cb_prev_max, Sq_chunk_t);
@@ -499,43 +575,57 @@ void kernel_main() {
499575
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
500576
}
501577
}
578+
}
579+
580+
// Finalize output based on tree role
581+
if (is_tree_root) {
582+
// Root node: perform final normalization and output
583+
// Determine which sum/max buffer to use based on whether we did tree reduction
584+
// If we had children with data, results are in cb_prev_sum/cb_prev_max after tree reduction
585+
// If single core (no children with data), results are in cb_cur_sum/cb_cur_max from FA loop
502586

503-
/* CUR_SUM = 1.0 / CUR_SUM */
504-
cb_push_back(cb_cur_sum, Sq_chunk_t);
505-
reconfig_data_format(cb_cur_sum, cb_cur_sum);
506-
pack_reconfig_data_format(cb_cur_sum);
587+
// Select the correct sum buffer based on whether tree reduction happened
588+
// If tree reduction happened, sum is in cb_prev_sum; otherwise it's in cb_cur_sum
589+
uint32_t sum_cb = (actual_num_children > 0) ? cb_prev_sum : cb_cur_sum;
590+
591+
/* SUM = 1.0 / SUM */
592+
reconfig_data_format(sum_cb, sum_cb);
593+
pack_reconfig_data_format(sum_cb);
507594

508595
// Handle attention sink here
509596
if constexpr (use_attention_sink) {
597+
// Use appropriate max buffer based on tree reduction
598+
uint32_t max_cb_for_sink = (actual_num_children > 0) ? cb_prev_max : cb_cur_max;
599+
510600
// m_new
511-
max_block<vector_mode>(cb_attention_sink, cb_prev_max, cb_cur_max, Sq_chunk_t);
601+
max_block<vector_mode>(cb_attention_sink, max_cb_for_sink, cb_cur_max, Sq_chunk_t);
512602

513603
// exp(m - m_new)
514-
sub_exp_block<scale_fp32>(cb_prev_max, cb_cur_max, cb_exp_max_diff, Sq_chunk_t);
604+
sub_exp_block<scale_fp32>(max_cb_for_sink, cb_cur_max, cb_exp_max_diff, Sq_chunk_t);
515605

516606
// l -> l * exp(m - m_new)
517-
mul_block_inplace(cb_cur_sum, cb_exp_max_diff, Sq_chunk_t);
607+
mul_block_inplace(sum_cb, cb_exp_max_diff, Sq_chunk_t);
518608

519609
// exp(sink - m_new)
520610
sub_exp_block<scale_fp32>(cb_attention_sink, cb_cur_max, cb_exp_max_diff_2, Sq_chunk_t);
521611
cb_pop_front(cb_cur_max, Sq_chunk_t);
522612

523613
// l -> l + exp(sink - m_new)
524-
add_block_inplace<true>(cb_cur_sum, cb_exp_max_diff_2, Sq_chunk_t);
614+
add_block_inplace<true>(sum_cb, cb_exp_max_diff_2, Sq_chunk_t);
525615

526616
// O -> O * exp(m - m_new)
527617
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_exp_max_diff);
528618
}
529619

530-
reconfig_data_format(cb_cur_sum, cb_cur_sum);
531-
pack_reconfig_data_format(cb_cur_sum);
532-
recip_block_inplace(cb_cur_sum, Sq_chunk_t);
620+
reconfig_data_format(sum_cb, sum_cb);
621+
pack_reconfig_data_format(sum_cb);
622+
recip_block_inplace(sum_cb, Sq_chunk_t);
533623

534-
/* OUT_ACC *= CUR_SUM */
535-
reconfig_data_format(cb_out_accumulate_im, cb_cur_sum);
624+
/* OUT_ACC *= 1/SUM */
625+
reconfig_data_format(cb_out_accumulate_im, sum_cb);
536626
pack_reconfig_data_format(cb_out_accumulate_im);
537627

538-
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, cb_cur_sum);
628+
mul_block_bcast_cols_inplace<Sq_chunk_t, vDHt>(cb_out_accumulate_im, sum_cb);
539629
pack_reconfig_data_format(cb_out_final);
540630

541631
// Untilize output to ROW MAJOR if input Q was also ROW MAJOR
@@ -564,12 +654,31 @@ void kernel_main() {
564654
// Move output to buffer for the writer
565655
move_block<true>(cb_out_accumulate_im, cb_out_final, out_chunk_tiles);
566656
}
567-
// Free up cb_prev_max after K chunks
568-
cb_pop_front(cb_prev_max, Sq_chunk_t);
569-
cb_pop_front(cb_prev_sum, Sq_chunk_t);
657+
658+
// Free up cb_prev_max if tree reduction used it (cb_prev_sum served as sum_cb above; NOTE(review): confirm it is consumed there)
659+
if (actual_num_children > 0) {
660+
cb_pop_front(cb_prev_max, Sq_chunk_t);
661+
}
662+
663+
} else if (should_send_to_parent) {
664+
// Non-root node with parent: send intermediate results
665+
// We have data (checked at function start), so send it
666+
// After tree reduction (if any), results are in:
667+
// - cb_out_accumulate_im: O
668+
// - cb_prev_max: M
669+
// - cb_prev_sum: L
670+
671+
// Move O to output CB
672+
move_block<true>(cb_out_accumulate_im, cb_out_o, out_chunk_tiles);
673+
// Move M to output CB
674+
move_block<true>(cb_prev_max, cb_out_m, Sq_chunk_t);
675+
// Move L to output CB
676+
move_block<true>(cb_prev_sum, cb_out_l, Sq_chunk_t);
570677
}
571678
}
572679

573-
// Free up cb_q_in after Q chunks
574-
cb_pop_front(cb_q_in, q_chunk_tiles);
680+
// Free up cb_q_in after Q chunks (only if we had local data)
681+
if (has_local_data) {
682+
cb_pop_front(cb_q_in, q_chunk_tiles);
683+
}
575684
}

0 commit comments

Comments
 (0)