Skip to content

SDPA Decode Optimization: Tree Reduce#37004

Open
alingTT wants to merge 4 commits intomainfrom
aling/tree-reduce
Open

SDPA Decode Optimization: Tree Reduce#37004
alingTT wants to merge 4 commits intomainfrom
aling/tree-reduce

Conversation

@alingTT
Copy link
Contributor

@alingTT alingTT commented Feb 2, 2026

Ticket

NA

Problem description

SDPA optimization needed for DS and Llama

What's changed

Reduction was previously O(n-1) time where n was the number of cores in a reducer group.
We can optimize this by using tree reduction where pairs of cores perform reduction. So complexity reduced to O(log(n)). On llama 70b galaxy shapes we see 8.3 us -> 7.4us improvement.

Checklist

@alingTT alingTT marked this pull request as ready for review February 6, 2026 02:42
@alingTT alingTT requested review from a team as code owners February 6, 2026 02:42
Copilot AI review requested due to automatic review settings February 6, 2026 02:42
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request implements a tree reduction optimization for SDPA (Scaled Dot-Product Attention) decode operations, improving the reduction complexity from O(n-1) to O(log n) where n is the number of cores in a reduction group. The change replaces the flat worker-to-reducer pattern with a binary tree reduction where cores hierarchically combine their attention results.

Changes:

  • Introduced tree reduction helper functions (count_trailing_zeros, ceil_log2, get_tree_reduction_params) to compute binary tree structure
  • Modified program factory to compute and pass tree reduction parameters to each core
  • Updated writer and compute kernels to perform round-by-round tree reduction with proper synchronization
  • Added semaphore encoding scheme using 4-bit nibbles per round for fine-grained synchronization

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 10 comments.

File Description
sdpa_decode_program_factory.cpp Adds tree reduction parameter calculation and passes them to kernels; builds physical core coordinate arrays for tree communication
writer_decode_all.cpp Implements tree reduction receiving and sending logic with round-based semaphore synchronization
sdpa_flash_decode.cpp Modifies compute flow to combine child results in tree pattern and handle root vs non-root finalization
reader_decode_all.cpp Minor cleanup (blank line removal, assert include)

}
// Count trailing ones in vid
uint32_t trailing_ones = count_trailing_zeros(~vid);
// Root in vid-space is 0 → physical core 0
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment "Root in vid-space is 0 → physical core 0" is incorrect. Based on the vid mapping (vid = num_cores - 1 - core_id) and root_vid = num_cores - 1, the root has vid = num_cores - 1, not vid = 0. For example, with 8 cores: core 0 has vid=7 (root), core 7 has vid=0 (leaf). The comment should say "Root in vid-space is (num_cores-1) → physical core 0".

Suggested change
// Root in vid-space is 0 → physical core 0
// Root in vid-space is (num_cores_in_group - 1) → physical core 0

Copilot uses AI. Check for mistakes.
Comment on lines 245 to 246
ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers

Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ASSERT checking num_heads_per_core == 1 when actual_num_children > 0 may be too restrictive. This assertion fails if a core processes multiple heads and any of those heads require tree reduction. However, the tree reduction logic operates per-head (inside the cur_head loop starting at line 236), so it should support multiple heads per core. Consider removing this assertion or moving it outside the head loop if the intent is to ensure heads aren't split across cores within a single reduction group.

Suggested change
ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers

Copilot uses AI. Check for mistakes.
cb_wait_front(cb_index_id, 1);
cur_pos = read_tile_value(cb_index_id, 0, cur_batch / q_heads_parallel_factor);
cb_pop_front(cb_index_id, 1);
// cb_pop_front(cb_index_id, 1);
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cb_pop_front for cb_index_id is commented out (line 149), which may cause the circular buffer to fill up if multiple operations try to use it. The comment references issue #27979, suggesting this is a known workaround for a mailbox-based synchronization issue. However, if this buffer is meant to be reused across multiple calls or heads, not popping it will eventually exhaust the buffer. Consider documenting why this pop is commented out and whether it needs to be addressed when the referenced issue is resolved.

Copilot uses AI. Check for mistakes.
Comment on lines +335 to +336
// Senders can return, dont need to participate
return;
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Early return at line 336 skips processing of remaining heads when num_heads_per_core > 1. The return statement is inside the head loop (lines 236-238), so sending to parent causes the kernel to exit completely rather than continuing to the next head. This breaks multi-head processing. The return should be replaced with a break or continue, or the send logic should be moved outside the head loop if each core only processes one head during tree reduction (which the ASSERT at line 245 suggests).

Suggested change
// Senders can return, dont need to participate
return;
// Senders can stop participating in tree reduction for this head
break;

Copilot uses AI. Check for mistakes.
Comment on lines +255 to +262
while (true) {
invalidate_l1_cache();
uint32_t sem_val = *in0_receiver_semaphore_addr_ptr;
uint8_t step_sem = (sem_val >> step_semaphore_shift[round]) & 0x0F;
if (step_sem >= 1) {
break;
}
}
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semaphore is polled but never reset within or after the tree reduction rounds. If num_heads_per_core > 1, the outer loop iterates multiple times (lines 236-238), and the semaphore state from the first head will cause the second head to incorrectly think data has arrived. The semaphore needs to be reset after processing all children, or the encoding scheme needs to account for multiple heads per core.

Copilot uses AI. Check for mistakes.
@@ -499,43 +575,57 @@ void kernel_main() {
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing CB pop operations after consuming child data. The code reads from cb_l_in (line 537) and cb_out_o (line 559) but only pops cb_m_in (line 573). The cb_l_in and cb_out_o circular buffers need cb_pop_front calls to free the consumed tiles, otherwise the CBs will fill up and block on subsequent rounds or heads. Add: cb_pop_front(cb_l_in, Sq_chunk_t) and cb_pop_front(cb_out_o, out_chunk_tiles) after line 575.

Suggested change
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t);
// Release consumed child L and O tiles from their circular buffers.
cb_pop_front(cb_l_in, Sq_chunk_t);
cb_pop_front(cb_out_o, out_chunk_tiles);

Copilot uses AI. Check for mistakes.

// Free up prev buffers if we used them
if (actual_num_children > 0) {
cb_pop_front(cb_prev_max, Sq_chunk_t);
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing cb_pop_front for sum buffer in root finalization. At line 660, only cb_prev_max is popped, but cb_prev_sum is also consumed (used in lines 607, 614, 622, 628) and needs to be freed. When actual_num_children == 0, cb_cur_sum is used instead but also not popped. Add cb_pop_front(sum_cb, Sq_chunk_t) after line 660 to properly release the consumed sum buffer.

Suggested change
cb_pop_front(cb_prev_max, Sq_chunk_t);
cb_pop_front(cb_prev_max, Sq_chunk_t);
// Also free the sum buffer (prev or current), which was consumed via sum_cb
cb_pop_front(sum_cb, Sq_chunk_t);

Copilot uses AI. Check for mistakes.
Comment on lines +1220 to +1221
// + reducer coords + output coords
std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0);
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The writer runtime args size for idle cores doesn't account for reducer and output core coordinates. The active cores receive additional args via insert operations at lines 1180-1183, appending reduce_core_physical_xs, reduce_core_physical_ys, output_core_physical_xs, and output_core_physical_ys. However, the idle core size calculation at line 1221 only accounts for base args, children_per_round, and group coords. The size should be: 16 + MAX_TREE_REDUCTION_ROUNDS + 2num_cores_per_head + 2num_reducer_cores + 2*num_output_cores.

Suggested change
// + reducer coords + output coords
std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0);
// + reducer coords (2*num_reducer_cores) + output coords (2*num_output_cores)
std::vector<uint32_t> writer_rt_args(
16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head + 2 * num_reducer_cores + 2 * num_output_cores,
0);

Copilot uses AI. Check for mistakes.
Comment on lines +65 to +70
// Semaphore encoding: each round uses a 4-bit field (nibble) in the semaphore value
// Round 0: bits 0-3, Round 1: bits 4-7, Round 2: bits 8-11, etc.
// step_semaphore_inc[r] = 1 << (r * 4) is the value to add to increment round r's counter
constexpr uint32_t step_semaphore_inc[6] = {1, 16, 256, 4096, 65536, 1048576};
// step_semaphore_shift[r] = r * 4 is the bit position to read round r's counter
constexpr uint32_t step_semaphore_shift[6] = {0, 4, 8, 12, 16, 20};
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semaphore encoding uses 4-bit nibbles per round (lines 65-70), which limits each round's counter to 0-15. This means a parent can receive from at most 15 children per round. With the binary tree structure, each parent receives from at most 1 child per round, so this is sufficient. However, if the tree structure changes in the future to allow multiple children per round, this encoding would fail. Consider adding a compile-time assertion or comment explaining this constraint.

Copilot uses AI. Check for mistakes.

// Calculate tree reduction parameters
// num_tree_reduction_rounds = ceil(log2(num_cores_per_head))
uint32_t num_tree_reduction_rounds = ceil_log2(num_cores_per_head);
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No runtime validation that num_tree_reduction_rounds doesn't exceed MAX_TREE_REDUCTION_ROUNDS. If num_cores_per_head is greater than 2^MAX_TREE_REDUCTION_ROUNDS (64), the system would silently produce incorrect results or access out-of-bounds array indices. Add a runtime check: TT_FATAL(num_tree_reduction_rounds <= MAX_TREE_REDUCTION_ROUNDS, "Tree reduction rounds {} exceeds maximum {}", num_tree_reduction_rounds, MAX_TREE_REDUCTION_ROUNDS).

Suggested change
uint32_t num_tree_reduction_rounds = ceil_log2(num_cores_per_head);
uint32_t num_tree_reduction_rounds = ceil_log2(num_cores_per_head);
TT_FATAL(
num_tree_reduction_rounds <= MAX_TREE_REDUCTION_ROUNDS,
"Tree reduction rounds {} exceeds maximum {}",
num_tree_reduction_rounds,
MAX_TREE_REDUCTION_ROUNDS);

Copilot uses AI. Check for mistakes.
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

};

inline TreeReductionParams get_tree_reduction_params(uint32_t core_id_in_group, uint32_t num_cores_in_group) {
TreeReductionParams params;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ cppcoreguidelines-pro-type-member-init ⚠️
uninitialized record type: params

Suggested change
TreeReductionParams params;
TreeReductionParams params{};

Comment on lines 94 to 95
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; r++) {
params.children_per_round[r] = UINT32_MAX;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ modernize-loop-convert ⚠️
use range-based for loop instead

Suggested change
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; r++) {
params.children_per_round[r] = UINT32_MAX;
for (unsigned int & r : params.children_per_round) {
r = UINT32_MAX;

Comment on lines +1168 to +1169
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
writer_rt_args.push_back(tree_params.children_per_round[r]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ modernize-loop-convert ⚠️
use range-based for loop instead

Suggested change
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
writer_rt_args.push_back(tree_params.children_per_round[r]);
for (unsigned int r : tree_params.children_per_round) {
writer_rt_args.push_back(r);

Comment on lines +1202 to +1203
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
compute_rt_args.push_back(tree_params.children_per_round[r]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ modernize-loop-convert ⚠️
use range-based for loop instead

Suggested change
for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) {
compute_rt_args.push_back(tree_params.children_per_round[r]);
for (unsigned int r : tree_params.children_per_round) {
compute_rt_args.push_back(r);

// writer runtime args - need to match the size with tree reduction params
// Base args (16) + children_per_round (MAX_TREE_REDUCTION_ROUNDS) + group coords (2*num_cores_per_head)
// + reducer coords + output coords
std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ readability-math-missing-parentheses ⚠️
* has higher precedence than +; add parentheses to explicitly specify the order of operations

Suggested change
std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0);
std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + (2 * num_cores_per_head), 0);

@alingTT
Copy link
Contributor Author

alingTT commented Feb 9, 2026

/codeowners ping

@tenstorrent-github-bot
Copy link

tenstorrent-github-bot commented Feb 9, 2026

CodeOwners Group Analysis

This PR requires approval from one member of each of the following groups:

Summary: 1 pending groups, 0 approved groups

Group Information:

Note: At least one approval from each group is sufficient.

@alingTT
Copy link
Contributor Author

alingTT commented Feb 9, 2026

/codeowners ping

@tenstorrent-github-bot
Copy link

🔄 CodeOwners Summary Updated

CodeOwners summary updated here

💡 Tip: Use /codeowners new to post a fresh summary comment instead of updating the existing one.

@tenstorrent-github-bot
Copy link

Hi Evan Smal (@esmalTT), Raymond Kim (@tt-rkim), this PR SDPA Decode Optimization: Tree Reduce by Ambrose Ling (@alingTT) needs your approval/review to merge this.

uint32_t child_id = actual_children_per_round[round];
if (child_id != UINT32_MAX) {
// Writer kernel handles the semaphore wait and data transfer to cb_m_in, cb_l_in, cb_out_o
// Data arrives in order: m, l, o
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does data arrive in order m, l, o, if l is processed first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the send order got messed up, should be arriving in order of l, m, o just fixed, will do a few more passes of the code to double check

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants