Skip to content

Commit cb565d7

Browse files
authored
Update rotary_nki_kernels.py
1 parent 4ff60f0 commit cb565d7

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

src/nki_samples/tutorials/rotary/rotary_nki_kernels.py

Lines changed: 30 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -169,23 +169,32 @@ def _nki_apply_rotary_embedding_core(q_tile, k_tile, cos_tile, sin_tile, output_
169169
The function applies rotary position embedding to query and key tensors
170170
using the provided cosine and sine embeddings.
171171
"""
172+
173+
assert q_tile.shape[-1] % 2 == 0, "Sequence length for q_tile must be even!"
174+
assert k_tile.shape[-1] % 2 == 0, "Sequence length for k_tile must be even!"
175+
assert (
176+
q_tile.shape[-1] == k_tile.shape[-1]
177+
), "q_tile and k_tile must have the same sequence length"
178+
179+
seq_len = q_tile.shape[-1]
180+
172181
# Rotate Q
173182
output_tile[0, :, :] = q_tile * cos_tile
174-
output_tile[0, :, : q_tile.shape[-1] // 2] = output_tile[
175-
0, :, : q_tile.shape[-1] // 2
176-
] + (-1 * q_tile[:, q_tile.shape[-1] // 2 :] * sin_tile[:, : q_tile.shape[-1] // 2])
177-
output_tile[0, :, q_tile.shape[-1] // 2 :] = output_tile[
178-
0, :, q_tile.shape[-1] // 2 :
179-
] + (q_tile[:, : q_tile.shape[-1] // 2] * sin_tile[:, q_tile.shape[-1] // 2 :])
183+
output_tile[0, :, : seq_len // 2] = output_tile[0, :, : seq_len // 2] + (
184+
-1 * q_tile[:, seq_len // 2 :] * sin_tile[:, : seq_len // 2]
185+
)
186+
output_tile[0, :, seq_len // 2 :] = output_tile[0, :, seq_len // 2 :] + (
187+
q_tile[:, : seq_len // 2] * sin_tile[:, seq_len // 2 :]
188+
)
180189

181190
# Rotate K
182191
output_tile[1, :, :] = k_tile * cos_tile
183-
output_tile[1, :, : k_tile.shape[-1] // 2] = output_tile[
184-
1, :, : k_tile.shape[-1] // 2
185-
] + (-1 * k_tile[:, k_tile.shape[-1] // 2 :] * sin_tile[:, : k_tile.shape[-1] // 2])
186-
output_tile[1, :, k_tile.shape[-1] // 2 :] = output_tile[
187-
1, :, k_tile.shape[-1] // 2 :
188-
] + (k_tile[:, : k_tile.shape[-1] // 2] * sin_tile[:, k_tile.shape[-1] // 2 :])
192+
output_tile[1, :, : seq_len // 2] = output_tile[1, :, : seq_len // 2] + (
193+
-1 * k_tile[:, seq_len // 2 :] * sin_tile[:, : seq_len // 2]
194+
)
195+
output_tile[1, :, seq_len // 2 :] = output_tile[1, :, seq_len // 2 :] + (
196+
k_tile[:, : seq_len // 2] * sin_tile[:, seq_len // 2 :]
197+
)
189198

190199

191200
def div_ceil(n: int, d: int) -> int:
@@ -258,15 +267,15 @@ def nki_apply_rotary_embedding(q, k, cos, sin):
258267
AssertionError
259268
If input tensor shapes don't match or head dimension > 128
260269
"""
261-
assert q.shape == k.shape, (
262-
f"Shape of Q Tensor: {q.shape} doesn't match shape of K Tensor: {k.shape}"
263-
)
264-
assert cos.shape == sin.shape, (
265-
f"Shape of cos Tensor: {cos.shape} doesn't match shape of sin Tensor: {sin.shape}"
266-
)
267-
assert q.shape[-1] <= 128, (
268-
f"Shape of head dim (last dim) is more than 128: {q.shape}"
269-
)
270+
assert (
271+
q.shape == k.shape
272+
), f"Shape of Q Tensor: {q.shape} doesn't match shape of K Tensor: {k.shape}"
273+
assert (
274+
cos.shape == sin.shape
275+
), f"Shape of cos Tensor: {cos.shape} doesn't match shape of sin Tensor: {sin.shape}"
276+
assert (
277+
q.shape[-1] <= 128
278+
), f"Shape of head dim (last dim) is more than 128: {q.shape}"
270279

271280
batch_id = nl.program_id(axis=0)
272281
head_id = nl.program_id(axis=1)

0 commit comments

Comments
 (0)