Skip to content

Commit 0757704

Browse files
committed
clean up
1 parent 69cf16d commit 0757704

File tree

2 files changed

+35
-42
lines changed

2 files changed

+35
-42
lines changed

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -177,35 +177,28 @@ void kernel_main() {
177177
return;
178178
}
179179

180-
// Determine which children actually have data (based on chunk allocation)
181-
// A child at core_num has data if core_num < k_num_chunks
182-
uint32_t actual_num_children = 0;
183-
uint32_t actual_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
184-
uint32_t actual_my_active_rounds = 0;
180+
// Determine which children actually participate in reduction (based on chunk allocation)
181+
// A child at core_num is active (i.e. has data to reduce) if core_num < k_num_chunks
182+
// E.g. k_num_chunks = 2, num_cores_per_head = 4
183+
// | core 0 | core 1 | core 2 | core 3 |
184+
// chunk 0 chunk 1 NA NA
185+
// core 0 would have core 1 and core 2 as children, but only core 1 is active, so core 0 reduces with core 1 only
186+
uint32_t num_active_children = 0;
187+
uint32_t active_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
188+
uint32_t num_active_rounds = 0;
185189

186190
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
187191
uint32_t child_id = children_per_round[r];
188192
if (child_id != UINT32_MAX && child_id < k_num_chunks) {
189193
// This child has data
190-
actual_children_per_round[r] = child_id;
191-
actual_num_children++;
192-
actual_my_active_rounds = r + 1;
194+
active_children_per_round[r] = child_id;
195+
num_active_children++;
196+
num_active_rounds = r + 1;
193197
} else {
194-
actual_children_per_round[r] = UINT32_MAX;
198+
active_children_per_round[r] = UINT32_MAX;
195199
}
196200
}
197201

198-
// Determine if we have an actual parent (parent must have data too, but root always has data)
199-
// Actually, we only need to check if WE should send - parent will handle receiving
200-
// We send if we have a parent AND we have data (which we do if we reach here)
201-
const bool should_send_to_parent = has_parent;
202-
203-
// Get number of worker cores to wait for
204-
uint32_t num_cores_to_wait = num_cores_per_head - 1;
205-
if (num_cores_per_head > k_num_chunks) {
206-
num_cores_to_wait = k_num_chunks - 1;
207-
}
208-
209202
// We tilize input Q if it is in ROW MAJOR layout
210203
if constexpr (tilize_q) {
211204
compute_kernel_hw_startup(cb_q_rm, cb_q_in);
@@ -492,7 +485,7 @@ void kernel_main() {
492485

493486
// After FA loop completes, prepare buffers for tree reduction or output
494487
// Results are in: cb_out_accumulate_im (O), cb_cur_max (M), cb_cur_sum (L)
495-
if (actual_num_children > 0 || should_send_to_parent) {
488+
if (num_active_children > 0 || has_parent) {
496489
// Tree reduction will happen - move cur to prev buffers
497490
reconfig_data_format(cb_cur_max, cb_cur_max);
498491
pack_reconfig_data_format(cb_prev_max);
@@ -523,10 +516,10 @@ void kernel_main() {
523516
// - cb_prev_max: local M (max of logits)
524517
// - cb_prev_sum: local L (sum of exp)
525518
// Only receive from children that actually have data
526-
if (actual_num_children > 0) {
519+
if (num_active_children > 0) {
527520
// Iterate through each round and receive from child if one exists AND has data
528-
for (uint32_t round = 0; round < actual_my_active_rounds; ++round) {
529-
uint32_t child_id = actual_children_per_round[round];
521+
for (uint32_t round = 0; round < num_active_rounds; ++round) {
522+
uint32_t child_id = active_children_per_round[round];
530523
if (child_id != UINT32_MAX) {
531524
// Writer kernel handles the semaphore wait and data transfer to cb_m_in, cb_l_in, cb_out_o
532525
// Data arrives in order: l, m, o
@@ -584,7 +577,7 @@ void kernel_main() {
584577

585578
// Select the correct sum buffer based on whether tree reduction happened
586579
// If tree reduction happened, sum is in cb_prev_sum; otherwise it's in cb_cur_sum
587-
uint32_t sum_cb = (actual_num_children > 0) ? cb_prev_sum : cb_cur_sum;
580+
uint32_t sum_cb = (num_active_children > 0) ? cb_prev_sum : cb_cur_sum;
588581

589582
/* SUM = 1.0 / SUM */
590583
reconfig_data_format(sum_cb, sum_cb);
@@ -593,7 +586,7 @@ void kernel_main() {
593586
// Handle attention sink here
594587
if constexpr (use_attention_sink) {
595588
// Use appropriate max buffer based on tree reduction
596-
uint32_t max_cb_for_sink = (actual_num_children > 0) ? cb_prev_max : cb_cur_max;
589+
uint32_t max_cb_for_sink = (num_active_children > 0) ? cb_prev_max : cb_cur_max;
597590

598591
// m_new
599592
max_block<vector_mode>(cb_attention_sink, max_cb_for_sink, cb_cur_max, Sq_chunk_t);
@@ -632,16 +625,16 @@ void kernel_main() {
632625
// Pop the max buffer that still has data
633626
if constexpr (use_attention_sink) {
634627
// In attention sink path:
635-
// - If actual_num_children > 0: max_cb_for_sink = cb_prev_max, cb_cur_max was popped
628+
// - If num_active_children > 0: max_cb_for_sink = cb_prev_max, cb_cur_max was popped
636629
// So we need to pop cb_prev_max
637-
// - If actual_num_children == 0: max_cb_for_sink = cb_cur_max, cb_cur_max was popped
630+
// - If num_active_children == 0: max_cb_for_sink = cb_cur_max, cb_cur_max was popped
638631
// Nothing left to pop
639-
if (actual_num_children > 0) {
632+
if (num_active_children > 0) {
640633
cb_pop_front(cb_prev_max, Sq_chunk_t);
641634
}
642635
} else {
643636
// No attention sink: the max buffer was never popped
644-
uint32_t max_cb = (actual_num_children > 0) ? cb_prev_max : cb_cur_max;
637+
uint32_t max_cb = (num_active_children > 0) ? cb_prev_max : cb_cur_max;
645638
cb_pop_front(max_cb, Sq_chunk_t);
646639
}
647640

@@ -672,7 +665,7 @@ void kernel_main() {
672665
move_block<true>(cb_out_accumulate_im, cb_out_final, out_chunk_tiles);
673666
}
674667

675-
} else if (should_send_to_parent) {
668+
} else if (has_parent) {
676669
// Non-root node with parent: send intermediate results
677670
// We have data (checked at function start), so send it
678671
// After tree reduction (if any), results are in:

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,21 +129,21 @@ void kernel_main() {
129129
return;
130130
}
131131

132-
// Determine which children actually have data (based on chunk allocation)
132+
// Determine which children actually participate in reduction (based on chunk allocation)
133133
// A child at core_num is active (i.e. has data to reduce) if core_num < k_num_chunks
134-
uint32_t actual_num_children = 0;
135-
uint32_t actual_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
136-
uint32_t actual_my_active_rounds = 0;
134+
uint32_t num_active_children = 0;
135+
uint32_t active_children_per_round[MAX_TREE_REDUCTION_ROUNDS];
136+
uint32_t num_active_rounds = 0;
137137

138138
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
139139
uint32_t child_id = children_per_round[r];
140140
if (child_id != UINT32_MAX && child_id < k_num_chunks) {
141141
// This child has data
142-
actual_children_per_round[r] = child_id;
143-
actual_num_children++;
144-
actual_my_active_rounds = r + 1;
142+
active_children_per_round[r] = child_id;
143+
num_active_children++;
144+
num_active_rounds = r + 1;
145145
} else {
146-
actual_children_per_round[r] = UINT32_MAX;
146+
active_children_per_round[r] = UINT32_MAX;
147147
}
148148
}
149149

@@ -235,12 +235,12 @@ void kernel_main() {
235235
// Each round, we wait for one child (if any), read remote_sum, remote_max, remote_output, and push to CBs
236236
// The compute kernel processes each child's data before we move to the next round
237237
// Only receive from children that actually have data
238-
if (actual_num_children > 0) {
238+
if (num_active_children > 0) {
239239
ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers
240240

241241
// Process each round sequentially
242-
for (uint32_t round = 0; round < actual_my_active_rounds; ++round) {
243-
uint32_t child_id = actual_children_per_round[round];
242+
for (uint32_t round = 0; round < num_active_rounds; ++round) {
243+
uint32_t child_id = active_children_per_round[round];
244244

245245
if (child_id != UINT32_MAX) {
246246
// Wait for this specific child to send its results

0 commit comments

Comments
 (0)