Skip to content

Commit 352e6f7

Browse files
authored
Update rotary_nki_kernels.py
specify indicies for other dims when storing tiles
1 parent d12ea47 commit 352e6f7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/nki_samples/tutorials/rotary/rotary_nki_kernels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,12 @@ def nki_apply_rotary_embedding(q, k, cos, sin):
312312

313313
nl.store(
314314
output_q_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f],
315-
output_tile[0],
315+
output_tile[0, :, :],
316316
mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len),
317317
)
318318
nl.store(
319319
output_k_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f],
320-
output_tile[1],
320+
output_tile[1, :, :],
321321
mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len),
322322
)
323323

0 commit comments

Comments
 (0)