Commit 2e312df

#36982: create_q_heads tilizes to 8x32 tiles (#37574)
### Ticket
[Link to Github Issue](#36982)

### Problem description
SDPA input needed to be tilized in 8x32 tiles. The previous implementation gathered to the target grid and left the data in row major, and the QRoPE output was split. This needed to be changed according to the design described in the linked issue.

### What's changed
- Changed `gather_heads` --> `create_q_heads`
- No longer splitting QRoPE output, QNoPE heads now split because of max tilize dim being 256, changed noc write offsets of the heads for proper tilize order
- Integrated standalone op changes into pre-SDPA fused kernel
- Updated standalone op unit tests and pre-SDPA unit test

### Checklist
- [x] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=ssundaram/deepseek_blitz_create_q_heads)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:ssundaram/deepseek_blitz_create_q_heads)
- [x] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=ssundaram/deepseek_blitz_create_q_heads)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:ssundaram/deepseek_blitz_create_q_heads)
- [ ] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch={{branch_name}})](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:{{branch_name}})
- [ ] New/Existing tests provide coverage for changes

#### Model tests
If your changes cover model-related code, you should run tests corresponding to affected models and platforms (Single card, T3K, Galaxy). "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers `models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always. The latter extends that with additional ones - use your best judgement in deciding which is the most appropriate for your PR.

- [ ] [![(Single) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml/badge.svg?branch={{branch_name}})](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml?query=branch:{{branch_name}})
  - [ ] `models-mandatory` preset (runs: [Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) and [Frequent model and ttnn tests](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(T3K) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml/badge.svg?branch={{branch_name}})](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml?query=branch:{{branch_name}})
  - [ ] `models-mandatory` preset (runs: [Unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-model-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(Galaxy) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml/badge.svg?branch={{branch_name}})](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml?query=branch:{{branch_name}})
  - [ ] `models-mandatory` preset (runs: [Quick tests](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-perf-tests.yaml) tests)
  - [ ] other selection - specify runs

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent a4ea4a6 commit 2e312df
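
The commit message says QNoPE heads are split because the maximum tilize dimension is 256, and the kernel comments give the resulting shapes ([8, 256] twice for QNoPE, [8, 64] for QRoPE). The arithmetic behind those numbers can be sketched as follows. All constant and function names here are illustrative assumptions, not identifiers from the tt-metal headers.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative constants taken from the commit message and kernel comments.
constexpr uint32_t kTileH = 8;           // tile height of an 8x32 tile
constexpr uint32_t kTileW = 32;          // tile width of an 8x32 tile
constexpr uint32_t kMaxTilizeDim = 256;  // "max tilize dim" from the commit message
constexpr uint32_t kQNopeDim = 512;      // QNoPE elements per head
constexpr uint32_t kQRopeDim = 64;       // QRoPE elements per head

// Number of 8x32 tiles needed to cover a [rows, cols] row-major block.
constexpr uint32_t num_tiles(uint32_t rows, uint32_t cols) {
    return ((rows + kTileH - 1) / kTileH) * ((cols + kTileW - 1) / kTileW);
}

// Number of tilize passes needed when one pass handles at most kMaxTilizeDim columns.
constexpr uint32_t num_phases(uint32_t cols) {
    return (cols + kMaxTilizeDim - 1) / kMaxTilizeDim;
}

static_assert(num_phases(kQNopeDim) == 2, "QNoPE needs two 256-wide phases");
static_assert(num_phases(kQRopeDim) == 1, "QRoPE fits in a single phase");
static_assert(num_tiles(8, 256) == 8, "[8, 256] -> 8 tiles");
static_assert(num_tiles(8, 64) == 2, "[8, 64] -> 2 tiles");
```

Under these assumptions each SDPA input core would end up with 8 + 8 + 2 = 18 tiles across the three phases.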

12 files changed, +1290 -1438 lines changed

models/demos/deepseek_v3_b1/fused_ops/pre_sdpa/kernels/pre_sdpa_kernel.cpp

Lines changed: 85 additions & 73 deletions
@@ -6,15 +6,15 @@
 // Each RISC has its own CTArgs struct with different compile-time arg layout
 //
 // Implements: CCL Broadcast + RMSNorm + Mcast + Matmul + Gather + RMSNorm2 + Mcast2 + Matmul2 + Matmul3 + RoPE +
-// GatherHeads
+// CreateQHeads
 // - NCRISC: CCL Broadcast Reader + RMSNorm reader + Mcast receiver (on matmul cores), Matmul reader + Gather sender (on
 // matmul cores),
 // RMSNorm2 reader + Mcast2 receiver (on matmul2 cores), Matmul2 reader (on matmul2 cores),
-// Matmul3 reader (on qnope cores), RoPE reader (on qrope cores), GatherHeads sender (on qnope/qrope cores)
+// Matmul3 reader (on qnope cores), RoPE reader (on qrope cores), CreateQHeads sender (on qnope/qrope cores)
 // - BRISC: CCL Broadcast Writer + RMSNorm writer + Mcast sender (on input core), Matmul writer (on matmul cores),
 // Gather receiver (on
 // input core), Mcast2 sender (on input core), Matmul2 writer (on matmul2 cores),
-// GatherHeads receiver (on sdpa input cores)
+// CreateQHeads receiver (on sdpa input cores) - matching gather pattern: NCRISC sender, BRISC receiver
 // - TRISC: RMSNorm compute (on input core), Matmul compute (on matmul cores), RMSNorm2 compute (on input core),
 // Matmul2 compute (on matmul2 cores), Matmul3 compute (on qnope cores), RoPE compute (on qrope cores)
 //
@@ -33,7 +33,7 @@
 #include "../../../unified_kernels/gather.hpp"
 #include "../../../unified_kernels/gather_reduce.hpp"
 #include "../../../unified_kernels/kn_sliced_matmul.hpp"
-#include "../../../unified_kernels/gather_heads.hpp"
+#include "../../../unified_kernels/create_q_heads.hpp"
 #include "../../../unified_kernels/rope.hpp"
 #include "../../../unified_kernels/broadcast.hpp"

@@ -48,7 +48,7 @@ struct Core {
     // Qrope cores: 32 cores (4x8 grid), each handles 2 heads of 64 elements
     static constexpr bool is_qnope_core = get_named_compile_time_arg_val("is_qnope_core") == 1;
     static constexpr bool is_qrope_core = get_named_compile_time_arg_val("is_qrope_core") == 1;
-    // SDPA Input core: receives interleaved QNOPE/QROPE gather heads (4×2 grid = 8 cores)
+    // SDPA Input core: receives interleaved QNOPE/QROPE, runs create q heads (4×2 grid = 8 cores)
     static constexpr bool is_sdpa_input_core = get_named_compile_time_arg_val("is_sdpa_input_core") == 1;

     // DKV Matmul core: 9x2 grid, each core handles 1 head of 32 dim
@@ -144,12 +144,43 @@ void kernel_main() {
     deepseek_b1_ops::Matmul::ReaderArgs matmul3_args{};

     // Qrope CTArgs type alias (NCRISC uses ReaderCTArgs)
-    using QRopeCTArgs =
-        deepseek_b1_ops::Rope::ReaderCTArgs<get_named_compile_time_arg_val("Wt"), get_named_compile_time_arg_val("Ht")>;
+    using QRopeCTArgs = deepseek_b1_ops::Rope::
+        ReaderCTArgs<get_named_compile_time_arg_val("qrope_Wt"), get_named_compile_time_arg_val("qrope_Ht")>;

     // Qrope reader args (NCRISC is no-op)
     deepseek_b1_ops::Rope::ReaderArgs qrope_args{};

+    // NCRISC: Sender args for QNOPE/QROPE cores
+    // Senders write to intermediate CB, then compute tilizes to output CB
+    // 3-phase synchronization: nope_phase1, nope_phase2, rope semaphores
+    constexpr uint32_t cqh_receiver_in_cb = get_named_compile_time_arg_val("cqh_receiver_in_cb");
+    deepseek_b1_ops::CreateQHeads::SenderArgs create_q_heads_args{
+        0,  // sender_grid_start_x (logical 0)
+        0,  // sender_grid_start_y (logical 0)
+        get_named_compile_time_arg_val("cqh_qnope_data_size_bytes"),
+        get_named_compile_time_arg_val("cqh_qrope_head_size_bytes"),
+        get_named_compile_time_arg_val("cqh_head_stride_bytes"),
+        get_named_compile_time_arg_val("cqh_qnope_cols"),
+        get_named_compile_time_arg_val("cqh_qnope_src_cb"),
+        get_named_compile_time_arg_val("cqh_qrope_src_cb"),
+        Core::is_qnope_core ? get_named_compile_time_arg_val("cqh_qnope_src_num_pages")
+                            : get_named_compile_time_arg_val("cqh_qrope_src_num_pages"),
+        get_named_compile_time_arg_val("cqh_nope_phase1_semaphore_id"),
+        get_named_compile_time_arg_val("cqh_nope_phase2_semaphore_id"),
+        get_named_compile_time_arg_val("cqh_rope_semaphore_id"),
+        {
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row0"),
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row1"),
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row2"),
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row3"),
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row4"),
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row5"),
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row6"),
+            get_named_compile_time_arg_val("cqh_target_noc_coords_row7"),
+        },
+        get_write_ptr(cqh_receiver_in_cb),
+    };
+
     // Matmul CTArgs type alias (NCRISC uses ReaderCTArgs)
     using DKV_MatmulCTArgs = deepseek_b1_ops::Matmul::ReaderCTArgs;

@@ -288,6 +319,19 @@ void kernel_main() {
         get_named_compile_time_arg_val("gather_reduce_dst_num_tiles"),
     };

+    // BRISC: Receiver args for SDPA input cores
+    deepseek_b1_ops::CreateQHeads::ReceiverArgs create_q_heads_args{
+        get_named_compile_time_arg_val("cqh_nope_phase1_semaphore_id"),
+        get_named_compile_time_arg_val("cqh_nope_phase2_semaphore_id"),
+        get_named_compile_time_arg_val("cqh_rope_semaphore_id"),
+        get_named_compile_time_arg_val("cqh_num_nope_senders"),
+        get_named_compile_time_arg_val("cqh_num_rope_senders"),
+        get_named_compile_time_arg_val("cqh_receiver_in_cb"),
+        get_named_compile_time_arg_val("cqh_out_cb"),
+        get_named_compile_time_arg_val("cqh_nope_tiles"),
+        get_named_compile_time_arg_val("cqh_rope_tiles"),
+    };
+
     // Matmul2 writer args (BRISC is no-op)
     deepseek_b1_ops::Matmul::WriterArgs matmul2_args{};

@@ -321,34 +365,6 @@ void kernel_main() {
         get_write_ptr(matmul2_in0),  // Write to matmul2_in0 (loopback)
     };

-    // BRISC: Sender args for QNOPE/QROPE cores
-    // Senders write directly to output CB (allocated on sender+receiver cores)
-    constexpr uint32_t receive_cb = get_named_compile_time_arg_val("receive_cb");
-    deepseek_b1_ops::GatherHeads::SenderArgs gather_heads_args{
-        0,  // sender_grid_start_x (logical 0)
-        0,  // sender_grid_start_y (logical 0)
-        get_named_compile_time_arg_val("qnope_data_size_bytes"),
-        get_named_compile_time_arg_val("qrope_data_size_bytes"),
-        get_named_compile_time_arg_val("head_stride_bytes"),
-        get_named_compile_time_arg_val("qnope_grid_cols"),
-        get_named_compile_time_arg_val("qnope_src_cb"),
-        get_named_compile_time_arg_val("qrope_src_cb"),
-        Core::is_qnope_core ? get_named_compile_time_arg_val("qnope_src_num_pages")
-                            : get_named_compile_time_arg_val("qrope_src_num_pages"),
-        get_named_compile_time_arg_val("receiver_semaphore_id"),
-        {
-            get_named_compile_time_arg_val("target_noc_coords_row0"),
-            get_named_compile_time_arg_val("target_noc_coords_row1"),
-            get_named_compile_time_arg_val("target_noc_coords_row2"),
-            get_named_compile_time_arg_val("target_noc_coords_row3"),
-            get_named_compile_time_arg_val("target_noc_coords_row4"),
-            get_named_compile_time_arg_val("target_noc_coords_row5"),
-            get_named_compile_time_arg_val("target_noc_coords_row6"),
-            get_named_compile_time_arg_val("target_noc_coords_row7"),
-        },
-        get_write_ptr(receive_cb),  // Write directly to output CB
-    };
-
     // Matmul writer args (BRISC is no-op)
     using DKV_MatmulCTArgs = deepseek_b1_ops::Matmul::WriterCTArgs;
     deepseek_b1_ops::Matmul::WriterArgs dkv_matmul_args{};
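
The argument lists in this file replace the single `receiver_semaphore_id` of `GatherHeads` with three semaphores (`cqh_nope_phase1_semaphore_id`, `cqh_nope_phase2_semaphore_id`, `cqh_rope_semaphore_id`) and two sender counts. The diff does not show how `create_q_heads.hpp` uses them, so the following is only a host-side model of one plausible scheme: each phase has its own counter, and a phase is ready once every sender for that phase has signalled. The class, the enum, and the counts of 8 and 4 senders are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical phase indices; the real op identifies phases by semaphore id.
enum Phase : uint32_t { kNopePhase1 = 0, kNopePhase2 = 1, kRope = 2 };

// Host-side model of one SDPA input core's receiver state (not device code).
class ReceiverSync {
public:
    ReceiverSync(uint32_t num_nope_senders, uint32_t num_rope_senders)
        : num_nope_senders_(num_nope_senders), num_rope_senders_(num_rope_senders), sem_{} {}

    // A sender bumps the phase counter once its write for that phase has landed.
    void signal(Phase p) { ++sem_[p]; }

    // Both NoPE phases expect every NoPE sender; the RoPE phase expects every RoPE sender.
    uint32_t expected(Phase p) const { return p == kRope ? num_rope_senders_ : num_nope_senders_; }

    // A phase can be handed to the tilize step only when all of its senders have signalled.
    bool phase_ready(Phase p) const { return sem_[p] == expected(p); }

private:
    uint32_t num_nope_senders_;
    uint32_t num_rope_senders_;
    std::array<uint32_t, 3> sem_;  // one counter per phase, all starting at 0
};

// Returns true if phase 1 becomes ready on its own, before phase 2 and RoPE.
inline bool demo_three_phase() {
    ReceiverSync sync(8, 4);  // assumed counts: 8 NoPE senders, 4 RoPE senders per receiver
    for (int i = 0; i < 8; ++i) {
        sync.signal(kNopePhase1);
    }
    return sync.phase_ready(kNopePhase1) && !sync.phase_ready(kNopePhase2) && !sync.phase_ready(kRope);
}
```

Separate counters let the receiver start tilizing the first 256 QNoPE columns without waiting for the second half or for QRoPE, which a single shared semaphore could not express.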
@@ -471,22 +487,27 @@ void kernel_main() {

     // Qrope CTArgs type alias
     using QRopeCTArgs = deepseek_b1_ops::Rope::
-        ComputeCTArgs<get_named_compile_time_arg_val("Wt"), get_named_compile_time_arg_val("Ht")>;
+        ComputeCTArgs<get_named_compile_time_arg_val("qrope_Wt"), get_named_compile_time_arg_val("qrope_Ht")>;

     // Qrope compute args (from compile-time args)
     deepseek_b1_ops::Rope::ComputeArgs qrope_args{
-        get_named_compile_time_arg_val("in_cb"),  // Input from matmul2 output
-        get_named_compile_time_arg_val("cos_cb"),
-        get_named_compile_time_arg_val("sin_cb"),
-        get_named_compile_time_arg_val("trans_mat_cb"),
-        get_named_compile_time_arg_val("rotated_in_interm_cb"),
-        get_named_compile_time_arg_val("cos_interm_cb"),
-        get_named_compile_time_arg_val("sin_interm_cb"),
-        get_named_compile_time_arg_val("out_cb"),
+        get_named_compile_time_arg_val("qrope_in_cb"),  // Input from matmul2 output
+        get_named_compile_time_arg_val("qrope_cos_cb"),
+        get_named_compile_time_arg_val("qrope_sin_cb"),
+        get_named_compile_time_arg_val("qrope_trans_mat_cb"),
+        get_named_compile_time_arg_val("qrope_rotated_in_interm_cb"),
+        get_named_compile_time_arg_val("qrope_cos_interm_cb"),
+        get_named_compile_time_arg_val("qrope_sin_interm_cb"),
+        get_named_compile_time_arg_val("qrope_output_cb"),
     };

-    // Gather heads compute args (no-op for TRISC)
-    deepseek_b1_ops::GatherHeads::ComputeArgs gather_heads_args{};
+    // CreateQHeads compute args (tilization on SDPA input cores)
+    deepseek_b1_ops::CreateQHeads::ComputeArgs create_q_heads_args{
+        get_named_compile_time_arg_val("cqh_receiver_in_cb"),
+        get_named_compile_time_arg_val("cqh_out_cb"),
+        get_named_compile_time_arg_val("cqh_nope_tiles"),
+        get_named_compile_time_arg_val("cqh_rope_tiles"),
+    };

     // DKV Matmul compute args
     using DKV_MatmulCTArgs =
@@ -525,7 +546,7 @@ void kernel_main() {
     constexpr uint32_t krope_input_cb = get_named_compile_time_arg_val("krope_in_cb");
     constexpr uint32_t krope_cos_cb = get_named_compile_time_arg_val("krope_cos_cb");
     constexpr uint32_t krope_sin_cb = get_named_compile_time_arg_val("krope_sin_cb");
-    constexpr uint32_t trans_mat_cb = get_named_compile_time_arg_val("trans_mat_cb");
+    constexpr uint32_t krope_trans_mat_cb = get_named_compile_time_arg_val("krope_trans_mat_cb");
     constexpr uint32_t krope_rotated_in_interm_cb = get_named_compile_time_arg_val("krope_rotated_in_interm_cb");
     constexpr uint32_t krope_cos_interm_cb = get_named_compile_time_arg_val("krope_cos_interm_cb");
     constexpr uint32_t krope_sin_interm_cb = get_named_compile_time_arg_val("krope_sin_interm_cb");
@@ -536,7 +557,7 @@ void kernel_main() {
         .in_cb = krope_input_cb,
         .cos_cb = krope_cos_cb,
         .sin_cb = krope_sin_cb,
-        .trans_mat_cb = trans_mat_cb,
+        .trans_mat_cb = krope_trans_mat_cb,
         .rotated_in_interm_cb = krope_rotated_in_interm_cb,
         .cos_interm_cb = krope_cos_interm_cb,
         .sin_interm_cb = krope_sin_interm_cb,
@@ -587,10 +608,10 @@ void kernel_main() {

     if constexpr (Core::is_qrope_core) {
         // Qrope CB indices and parameters from named compile-time args
-        constexpr uint32_t qrope_cos_cb = get_named_compile_time_arg_val("cos_cb");
-        constexpr uint32_t qrope_sin_cb = get_named_compile_time_arg_val("sin_cb");
-        constexpr uint32_t qrope_trans_mat_cb = get_named_compile_time_arg_val("trans_mat_cb");
-        constexpr uint32_t Wt = get_named_compile_time_arg_val("Wt");
+        constexpr uint32_t qrope_cos_cb = get_named_compile_time_arg_val("qrope_cos_cb");
+        constexpr uint32_t qrope_sin_cb = get_named_compile_time_arg_val("qrope_sin_cb");
+        constexpr uint32_t qrope_trans_mat_cb = get_named_compile_time_arg_val("qrope_trans_mat_cb");
+        constexpr uint32_t Wt = get_named_compile_time_arg_val("qrope_Wt");

         // NOTE: Do NOT setup qrope input CB (matmul2_output_cb) as sharded buffer!
         // The input to RoPE comes from matmul2 compute output, NOT from a sharded tensor.
@@ -601,14 +622,6 @@ void kernel_main() {
         unified_kernels::setup_sharded_buffer(qrope_trans_mat_cb, 1);  // trans_mat is 1 tile (32x32)
     }

-    // NCRISC: Receiver args for SDPA input cores
-    deepseek_b1_ops::GatherHeads::ReceiverArgs gather_heads_args{
-        get_named_compile_time_arg_val("receiver_semaphore_id"),
-        get_named_compile_time_arg_val("num_senders"),
-        get_named_compile_time_arg_val("receive_cb"),  // Output CB
-        get_named_compile_time_arg_val("dst_num_pages"),
-    };
-
     if constexpr (Core::is_dkv_matmul_core) {
         // Matmul weights (in1)
         constexpr uint32_t dkv_matmul_in1 = get_named_compile_time_arg_val("dkv_matmul_in1");
@@ -792,24 +805,23 @@ void kernel_main() {
         }

         // ========================================================================
-        // GatherHeads: QNOPE/QROPE -> SDPA interleaved transfer
-        // QNOPE cores (cols 0-7): send [1, 512] to SDPA at offset = head_idx * 576
-        // QROPE cores (cols 8-11): send 2x [1, 64] to SDPA at offsets:
-        // - head_idx * 576 + 512
-        // - (head_idx + 1) * 576 + 512
+        // CreateQHeads: 3-phase QNOPE/QROPE -> SDPA transfer with tilization
+        // Phase 1: QNOPE first 256 elements → [8, 256] row-major → 8 tiles
+        // Phase 2: QNOPE second 256 elements → [8, 256] row-major → 8 tiles
+        // Phase 3: QROPE 64 elements per head → [8, 64] row-major → 2 tiles
+        // Senders write to intermediate CB, TRISC tilizes to output CB
+        // NCRISC sends from qnope/qrope cores, BRISC receives on sdpa input cores, TRISC no-op
         // ========================================================================
         {
-            DeviceZoneScopedN("GATHER_HEADS");
-            // GatherHeads Op configuration:
+            DeviceZoneScopedN("CREATE_Q_HEADS");
+            // CreateQHeads Op configuration:
             // - IsSenderCore: is_qnope_core || is_qrope_core
             // - IsReceiverCore: is_sdpa_input_core
-            // - setup_sharded_input: false (data already in CB from previous compute)
             // - pop_src: true (pop source CB after sending)
-            // - use_cb_output: true (receiver uses cb_reserve_back/cb_push_back, writes directly to output tensor)
-            constexpr bool is_gather_heads_sender = Core::is_qnope_core || Core::is_qrope_core;
-            deepseek_b1_ops::GatherHeads::Op<is_gather_heads_sender, Core::is_sdpa_input_core, false, true, true>
-                gather_heads;
-            gather_heads(gather_heads_args);
+            constexpr bool is_create_q_heads_sender = Core::is_qnope_core || Core::is_qrope_core;
+            deepseek_b1_ops::CreateQHeads::Op<is_create_q_heads_sender, Core::is_sdpa_input_core, false, true>
+                create_q_heads;
+            create_q_heads(create_q_heads_args);
         }
     }
     {
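
The CreateQHeads comment block in the last hunk describes each phase as an [8, W] row-major block that is tilized into W/32 tiles of shape 8x32. A simplified host-side sketch of that reordering is shown below. It assumes tiles are laid out left to right and that elements inside a tile stay row-major; it deliberately ignores any face sub-layout the hardware tilize may apply, and `tiled_index` and `tilize_8xW` are hypothetical helper names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr uint32_t kTileH = 8;   // rows per tile
constexpr uint32_t kTileW = 32;  // columns per tile

// Flat index in the tiled buffer for element (row, col) of an [8, width] block.
// Tile t covers columns [32t, 32t + 31]; inside a tile, elements stay row-major.
constexpr uint32_t tiled_index(uint32_t row, uint32_t col) {
    return (col / kTileW) * (kTileH * kTileW) + row * kTileW + (col % kTileW);
}

// Reorders an [8, width] row-major buffer into consecutive 8x32 tiles.
// width is assumed to be a multiple of 32 (256 or 64 in the kernel comments).
inline std::vector<uint16_t> tilize_8xW(const std::vector<uint16_t>& row_major, uint32_t width) {
    std::vector<uint16_t> tiled(row_major.size());
    for (uint32_t r = 0; r < kTileH; ++r) {
        for (uint32_t c = 0; c < width; ++c) {
            tiled[tiled_index(r, c)] = row_major[r * width + c];
        }
    }
    return tiled;
}
```

In this model, row 1 of the source lands 32 elements into tile 0 rather than `width` elements into the buffer, which is why sender write offsets that were correct for a row-major destination would need to change when the destination is consumed by a tilize step.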
