66// Each RISC has its own CTArgs struct with different compile-time arg layout
77//
88// Implements: CCL Broadcast + RMSNorm + Mcast + Matmul + Gather + RMSNorm2 + Mcast2 + Matmul2 + Matmul3 + RoPE +
9- // GatherHeads
9+ // CreateQHeads
1010// - NCRISC: CCL Broadcast Reader + RMSNorm reader + Mcast receiver (on matmul cores), Matmul reader + Gather sender (on
1111// matmul cores),
1212// RMSNorm2 reader + Mcast2 receiver (on matmul2 cores), Matmul2 reader (on matmul2 cores),
13- // Matmul3 reader (on qnope cores), RoPE reader (on qrope cores), GatherHeads sender (on qnope/qrope cores)
13+ // Matmul3 reader (on qnope cores), RoPE reader (on qrope cores), CreateQHeads sender (on qnope/qrope cores)
1414// - BRISC: CCL Broadcast Writer + RMSNorm writer + Mcast sender (on input core), Matmul writer (on matmul cores),
1515// Gather receiver (on
1616// input core), Mcast2 sender (on input core), Matmul2 writer (on matmul2 cores),
17- // GatherHeads receiver (on sdpa input cores)
17+ // CreateQHeads receiver (on sdpa input cores) - matching gather pattern: NCRISC sender, BRISC receiver
1818// - TRISC: RMSNorm compute (on input core), Matmul compute (on matmul cores), RMSNorm2 compute (on input core),
1919// Matmul2 compute (on matmul2 cores), Matmul3 compute (on qnope cores), RoPE compute (on qrope cores)
2020//
3333#include " ../../../unified_kernels/gather.hpp"
3434#include " ../../../unified_kernels/gather_reduce.hpp"
3535#include " ../../../unified_kernels/kn_sliced_matmul.hpp"
36- #include " ../../../unified_kernels/gather_heads .hpp"
36+ #include " ../../../unified_kernels/create_q_heads .hpp"
3737#include " ../../../unified_kernels/rope.hpp"
3838#include " ../../../unified_kernels/broadcast.hpp"
3939
@@ -48,7 +48,7 @@ struct Core {
4848 // Qrope cores: 32 cores (4x8 grid), each handles 2 heads of 64 elements
4949 static constexpr bool is_qnope_core = get_named_compile_time_arg_val(" is_qnope_core" ) == 1 ;
5050 static constexpr bool is_qrope_core = get_named_compile_time_arg_val(" is_qrope_core" ) == 1 ;
51- // SDPA Input core: receives interleaved QNOPE/QROPE gather heads (4×2 grid = 8 cores)
51+ // SDPA Input core: receives interleaved QNOPE/QROPE, runs create q heads (4×2 grid = 8 cores)
5252 static constexpr bool is_sdpa_input_core = get_named_compile_time_arg_val(" is_sdpa_input_core" ) == 1 ;
5353
5454 // DKV Matmul core: 9x2 grid, each core handles 1 head of 32 dim
@@ -144,12 +144,43 @@ void kernel_main() {
144144 deepseek_b1_ops::Matmul::ReaderArgs matmul3_args{};
145145
146146 // Qrope CTArgs type alias (NCRISC uses ReaderCTArgs)
147- using QRopeCTArgs =
148- deepseek_b1_ops::Rope:: ReaderCTArgs<get_named_compile_time_arg_val (" Wt " ), get_named_compile_time_arg_val (" Ht " )>;
147+ using QRopeCTArgs = deepseek_b1_ops::Rope::
148+ ReaderCTArgs<get_named_compile_time_arg_val (" qrope_Wt " ), get_named_compile_time_arg_val (" qrope_Ht " )>;
149149
150150 // Qrope reader args (NCRISC is no-op)
151151 deepseek_b1_ops::Rope::ReaderArgs qrope_args{};
152152
153+ // NCRISC: Sender args for QNOPE/QROPE cores
154+ // Senders write to intermediate CB, then compute tilizes to output CB
155+ // 3-phase synchronization: nope_phase1, nope_phase2, rope semaphores
156+ constexpr uint32_t cqh_receiver_in_cb = get_named_compile_time_arg_val (" cqh_receiver_in_cb" );
157+ deepseek_b1_ops::CreateQHeads::SenderArgs create_q_heads_args{
158+ 0 , // sender_grid_start_x (logical 0)
159+ 0 , // sender_grid_start_y (logical 0)
160+ get_named_compile_time_arg_val (" cqh_qnope_data_size_bytes" ),
161+ get_named_compile_time_arg_val (" cqh_qrope_head_size_bytes" ),
162+ get_named_compile_time_arg_val (" cqh_head_stride_bytes" ),
163+ get_named_compile_time_arg_val (" cqh_qnope_cols" ),
164+ get_named_compile_time_arg_val (" cqh_qnope_src_cb" ),
165+ get_named_compile_time_arg_val (" cqh_qrope_src_cb" ),
166+ Core::is_qnope_core ? get_named_compile_time_arg_val (" cqh_qnope_src_num_pages" )
167+ : get_named_compile_time_arg_val (" cqh_qrope_src_num_pages" ),
168+ get_named_compile_time_arg_val (" cqh_nope_phase1_semaphore_id" ),
169+ get_named_compile_time_arg_val (" cqh_nope_phase2_semaphore_id" ),
170+ get_named_compile_time_arg_val (" cqh_rope_semaphore_id" ),
171+ {
172+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row0" ),
173+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row1" ),
174+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row2" ),
175+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row3" ),
176+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row4" ),
177+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row5" ),
178+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row6" ),
179+ get_named_compile_time_arg_val (" cqh_target_noc_coords_row7" ),
180+ },
181+ get_write_ptr (cqh_receiver_in_cb),
182+ };
183+
153184 // Matmul CTArgs type alias (NCRISC uses ReaderCTArgs)
154185 using DKV_MatmulCTArgs = deepseek_b1_ops::Matmul::ReaderCTArgs;
155186
@@ -288,6 +319,19 @@ void kernel_main() {
288319 get_named_compile_time_arg_val (" gather_reduce_dst_num_tiles" ),
289320 };
290321
322+ // BRISC: Receiver args for SDPA input cores
323+ deepseek_b1_ops::CreateQHeads::ReceiverArgs create_q_heads_args{
324+ get_named_compile_time_arg_val (" cqh_nope_phase1_semaphore_id" ),
325+ get_named_compile_time_arg_val (" cqh_nope_phase2_semaphore_id" ),
326+ get_named_compile_time_arg_val (" cqh_rope_semaphore_id" ),
327+ get_named_compile_time_arg_val (" cqh_num_nope_senders" ),
328+ get_named_compile_time_arg_val (" cqh_num_rope_senders" ),
329+ get_named_compile_time_arg_val (" cqh_receiver_in_cb" ),
330+ get_named_compile_time_arg_val (" cqh_out_cb" ),
331+ get_named_compile_time_arg_val (" cqh_nope_tiles" ),
332+ get_named_compile_time_arg_val (" cqh_rope_tiles" ),
333+ };
334+
291335 // Matmul2 writer args (BRISC is no-op)
292336 deepseek_b1_ops::Matmul::WriterArgs matmul2_args{};
293337
@@ -321,34 +365,6 @@ void kernel_main() {
321365 get_write_ptr (matmul2_in0), // Write to matmul2_in0 (loopback)
322366 };
323367
324- // BRISC: Sender args for QNOPE/QROPE cores
325- // Senders write directly to output CB (allocated on sender+receiver cores)
326- constexpr uint32_t receive_cb = get_named_compile_time_arg_val (" receive_cb" );
327- deepseek_b1_ops::GatherHeads::SenderArgs gather_heads_args{
328- 0 , // sender_grid_start_x (logical 0)
329- 0 , // sender_grid_start_y (logical 0)
330- get_named_compile_time_arg_val (" qnope_data_size_bytes" ),
331- get_named_compile_time_arg_val (" qrope_data_size_bytes" ),
332- get_named_compile_time_arg_val (" head_stride_bytes" ),
333- get_named_compile_time_arg_val (" qnope_grid_cols" ),
334- get_named_compile_time_arg_val (" qnope_src_cb" ),
335- get_named_compile_time_arg_val (" qrope_src_cb" ),
336- Core::is_qnope_core ? get_named_compile_time_arg_val (" qnope_src_num_pages" )
337- : get_named_compile_time_arg_val (" qrope_src_num_pages" ),
338- get_named_compile_time_arg_val (" receiver_semaphore_id" ),
339- {
340- get_named_compile_time_arg_val (" target_noc_coords_row0" ),
341- get_named_compile_time_arg_val (" target_noc_coords_row1" ),
342- get_named_compile_time_arg_val (" target_noc_coords_row2" ),
343- get_named_compile_time_arg_val (" target_noc_coords_row3" ),
344- get_named_compile_time_arg_val (" target_noc_coords_row4" ),
345- get_named_compile_time_arg_val (" target_noc_coords_row5" ),
346- get_named_compile_time_arg_val (" target_noc_coords_row6" ),
347- get_named_compile_time_arg_val (" target_noc_coords_row7" ),
348- },
349- get_write_ptr (receive_cb), // Write directly to output CB
350- };
351-
352368 // Matmul writer args (BRISC is no-op)
353369 using DKV_MatmulCTArgs = deepseek_b1_ops::Matmul::WriterCTArgs;
354370 deepseek_b1_ops::Matmul::WriterArgs dkv_matmul_args{};
@@ -471,22 +487,27 @@ void kernel_main() {
471487
472488 // Qrope CTArgs type alias
473489 using QRopeCTArgs = deepseek_b1_ops::Rope::
474- ComputeCTArgs<get_named_compile_time_arg_val (" Wt " ), get_named_compile_time_arg_val (" Ht " )>;
490+ ComputeCTArgs<get_named_compile_time_arg_val (" qrope_Wt " ), get_named_compile_time_arg_val (" qrope_Ht " )>;
475491
476492 // Qrope compute args (from compile-time args)
477493 deepseek_b1_ops::Rope::ComputeArgs qrope_args{
478- get_named_compile_time_arg_val (" in_cb " ), // Input from matmul2 output
479- get_named_compile_time_arg_val (" cos_cb " ),
480- get_named_compile_time_arg_val (" sin_cb " ),
481- get_named_compile_time_arg_val (" trans_mat_cb " ),
482- get_named_compile_time_arg_val (" rotated_in_interm_cb " ),
483- get_named_compile_time_arg_val (" cos_interm_cb " ),
484- get_named_compile_time_arg_val (" sin_interm_cb " ),
485- get_named_compile_time_arg_val (" out_cb " ),
494+ get_named_compile_time_arg_val (" qrope_in_cb " ), // Input from matmul2 output
495+ get_named_compile_time_arg_val (" qrope_cos_cb " ),
496+ get_named_compile_time_arg_val (" qrope_sin_cb " ),
497+ get_named_compile_time_arg_val (" qrope_trans_mat_cb " ),
498+ get_named_compile_time_arg_val (" qrope_rotated_in_interm_cb " ),
499+ get_named_compile_time_arg_val (" qrope_cos_interm_cb " ),
500+ get_named_compile_time_arg_val (" qrope_sin_interm_cb " ),
501+ get_named_compile_time_arg_val (" qrope_output_cb " ),
486502 };
487503
488- // Gather heads compute args (no-op for TRISC)
489- deepseek_b1_ops::GatherHeads::ComputeArgs gather_heads_args{};
504+ // CreateQHeads compute args (tilization on SDPA input cores)
505+ deepseek_b1_ops::CreateQHeads::ComputeArgs create_q_heads_args{
506+ get_named_compile_time_arg_val (" cqh_receiver_in_cb" ),
507+ get_named_compile_time_arg_val (" cqh_out_cb" ),
508+ get_named_compile_time_arg_val (" cqh_nope_tiles" ),
509+ get_named_compile_time_arg_val (" cqh_rope_tiles" ),
510+ };
490511
491512 // DKV Matmul compute args
492513 using DKV_MatmulCTArgs =
@@ -525,7 +546,7 @@ void kernel_main() {
525546 constexpr uint32_t krope_input_cb = get_named_compile_time_arg_val (" krope_in_cb" );
526547 constexpr uint32_t krope_cos_cb = get_named_compile_time_arg_val (" krope_cos_cb" );
527548 constexpr uint32_t krope_sin_cb = get_named_compile_time_arg_val (" krope_sin_cb" );
528- constexpr uint32_t trans_mat_cb = get_named_compile_time_arg_val (" trans_mat_cb " );
549+ constexpr uint32_t krope_trans_mat_cb = get_named_compile_time_arg_val (" krope_trans_mat_cb " );
529550 constexpr uint32_t krope_rotated_in_interm_cb = get_named_compile_time_arg_val (" krope_rotated_in_interm_cb" );
530551 constexpr uint32_t krope_cos_interm_cb = get_named_compile_time_arg_val (" krope_cos_interm_cb" );
531552 constexpr uint32_t krope_sin_interm_cb = get_named_compile_time_arg_val (" krope_sin_interm_cb" );
@@ -536,7 +557,7 @@ void kernel_main() {
536557 .in_cb = krope_input_cb,
537558 .cos_cb = krope_cos_cb,
538559 .sin_cb = krope_sin_cb,
539- .trans_mat_cb = trans_mat_cb ,
560+ .trans_mat_cb = krope_trans_mat_cb ,
540561 .rotated_in_interm_cb = krope_rotated_in_interm_cb,
541562 .cos_interm_cb = krope_cos_interm_cb,
542563 .sin_interm_cb = krope_sin_interm_cb,
@@ -587,10 +608,10 @@ void kernel_main() {
587608
588609 if constexpr (Core::is_qrope_core) {
589610 // Qrope CB indices and parameters from named compile-time args
590- constexpr uint32_t qrope_cos_cb = get_named_compile_time_arg_val (" cos_cb " );
591- constexpr uint32_t qrope_sin_cb = get_named_compile_time_arg_val (" sin_cb " );
592- constexpr uint32_t qrope_trans_mat_cb = get_named_compile_time_arg_val (" trans_mat_cb " );
593- constexpr uint32_t Wt = get_named_compile_time_arg_val (" Wt " );
611+ constexpr uint32_t qrope_cos_cb = get_named_compile_time_arg_val (" qrope_cos_cb " );
612+ constexpr uint32_t qrope_sin_cb = get_named_compile_time_arg_val (" qrope_sin_cb " );
613+ constexpr uint32_t qrope_trans_mat_cb = get_named_compile_time_arg_val (" qrope_trans_mat_cb " );
614+ constexpr uint32_t Wt = get_named_compile_time_arg_val (" qrope_Wt " );
594615
595616 // NOTE: Do NOT setup qrope input CB (matmul2_output_cb) as sharded buffer!
596617 // The input to RoPE comes from matmul2 compute output, NOT from a sharded tensor.
@@ -601,14 +622,6 @@ void kernel_main() {
601622 unified_kernels::setup_sharded_buffer (qrope_trans_mat_cb, 1 ); // trans_mat is 1 tile (32x32)
602623 }
603624
604- // NCRISC: Receiver args for SDPA input cores
605- deepseek_b1_ops::GatherHeads::ReceiverArgs gather_heads_args{
606- get_named_compile_time_arg_val (" receiver_semaphore_id" ),
607- get_named_compile_time_arg_val (" num_senders" ),
608- get_named_compile_time_arg_val (" receive_cb" ), // Output CB
609- get_named_compile_time_arg_val (" dst_num_pages" ),
610- };
611-
612625 if constexpr (Core::is_dkv_matmul_core) {
613626 // Matmul weights (in1)
614627 constexpr uint32_t dkv_matmul_in1 = get_named_compile_time_arg_val (" dkv_matmul_in1" );
@@ -792,24 +805,23 @@ void kernel_main() {
792805 }
793806
794807 // ========================================================================
795- // GatherHeads: QNOPE/QROPE -> SDPA interleaved transfer
796- // QNOPE cores (cols 0-7): send [1, 512] to SDPA at offset = head_idx * 576
797- // QROPE cores (cols 8-11): send 2x [1, 64] to SDPA at offsets:
798- // - head_idx * 576 + 512
799- // - (head_idx + 1) * 576 + 512
808+ // CreateQHeads: 3-phase QNOPE/QROPE -> SDPA transfer with tilization
809+ // Phase 1: QNOPE first 256 elements → [8, 256] row-major → 8 tiles
810+ // Phase 2: QNOPE second 256 elements → [8, 256] row-major → 8 tiles
811+ // Phase 3: QROPE 64 elements per head → [8, 64] row-major → 2 tiles
812+ // Senders write to intermediate CB, TRISC tilizes to output CB
813+ // NCRISC sends from qnope/qrope cores, BRISC receives on sdpa input cores, TRISC no-op
800814 // ========================================================================
801815 {
802- DeviceZoneScopedN (" GATHER_HEADS " );
803- // GatherHeads Op configuration:
816+ DeviceZoneScopedN (" CREATE_Q_HEADS " );
817+ // CreateQHeads Op configuration:
804818 // - IsSenderCore: is_qnope_core || is_qrope_core
805819 // - IsReceiverCore: is_sdpa_input_core
806- // - setup_sharded_input: false (data already in CB from previous compute)
807820 // - pop_src: true (pop source CB after sending)
808- // - use_cb_output: true (receiver uses cb_reserve_back/cb_push_back, writes directly to output tensor)
809- constexpr bool is_gather_heads_sender = Core::is_qnope_core || Core::is_qrope_core;
810- deepseek_b1_ops::GatherHeads::Op<is_gather_heads_sender, Core::is_sdpa_input_core, false , true , true >
811- gather_heads;
812- gather_heads (gather_heads_args);
821+ constexpr bool is_create_q_heads_sender = Core::is_qnope_core || Core::is_qrope_core;
822+ deepseek_b1_ops::CreateQHeads::Op<is_create_q_heads_sender, Core::is_sdpa_input_core, false , true >
823+ create_q_heads;
824+ create_q_heads (create_q_heads_args);
813825 }
814826 }
815827 {
0 commit comments