@@ -169,23 +169,32 @@ def _nki_apply_rotary_embedding_core(q_tile, k_tile, cos_tile, sin_tile, output_
169
169
The function applies rotary position embedding to query and key tensors
170
170
using the provided cosine and sine embeddings.
171
171
"""
172
+
173
+ assert q_tile .shape [- 1 ] % 2 == 0 , "Sequence length for q_tile must be even!"
174
+ assert k_tile .shape [- 1 ] % 2 == 0 , "Sequence length for k_tile must be even!"
175
+ assert (
176
+ q_tile .shape [- 1 ] == k_tile .shape [- 1 ]
177
+ ), "q_tile and k_tile must have the same sequence length"
178
+
179
+ seq_len = q_tile .shape [- 1 ]
180
+
172
181
# Rotate Q
173
182
output_tile [0 , :, :] = q_tile * cos_tile
174
- output_tile [0 , :, : q_tile . shape [ - 1 ] // 2 ] = output_tile [
175
- 0 , :, : q_tile . shape [ - 1 ] // 2
176
- ] + ( - 1 * q_tile [:, q_tile . shape [ - 1 ] // 2 :] * sin_tile [:, : q_tile . shape [ - 1 ] // 2 ] )
177
- output_tile [0 , :, q_tile . shape [ - 1 ] // 2 :] = output_tile [
178
- 0 , :, q_tile . shape [ - 1 ] // 2 :
179
- ] + ( q_tile [:, : q_tile . shape [ - 1 ] // 2 ] * sin_tile [:, q_tile . shape [ - 1 ] // 2 :] )
183
+ output_tile [0 , :, : seq_len // 2 ] = output_tile [0 , :, : seq_len // 2 ] + (
184
+ - 1 * q_tile [ :, seq_len // 2 :] * sin_tile [:, : seq_len // 2 ]
185
+ )
186
+ output_tile [0 , :, seq_len // 2 :] = output_tile [0 , :, seq_len // 2 :] + (
187
+ q_tile [: , : seq_len // 2 ] * sin_tile [:, seq_len // 2 :]
188
+ )
180
189
181
190
# Rotate K
182
191
output_tile [1 , :, :] = k_tile * cos_tile
183
- output_tile [1 , :, : k_tile . shape [ - 1 ] // 2 ] = output_tile [
184
- 1 , :, : k_tile . shape [ - 1 ] // 2
185
- ] + ( - 1 * k_tile [:, k_tile . shape [ - 1 ] // 2 :] * sin_tile [:, : k_tile . shape [ - 1 ] // 2 ] )
186
- output_tile [1 , :, k_tile . shape [ - 1 ] // 2 :] = output_tile [
187
- 1 , :, k_tile . shape [ - 1 ] // 2 :
188
- ] + ( k_tile [:, : k_tile . shape [ - 1 ] // 2 ] * sin_tile [:, k_tile . shape [ - 1 ] // 2 :] )
192
+ output_tile [1 , :, : seq_len // 2 ] = output_tile [1 , :, : seq_len // 2 ] + (
193
+ - 1 * k_tile [ :, seq_len // 2 :] * sin_tile [:, : seq_len // 2 ]
194
+ )
195
+ output_tile [1 , :, seq_len // 2 :] = output_tile [1 , :, seq_len // 2 :] + (
196
+ k_tile [: , : seq_len // 2 ] * sin_tile [:, seq_len // 2 :]
197
+ )
189
198
190
199
191
200
def div_ceil (n : int , d : int ) -> int :
@@ -258,15 +267,15 @@ def nki_apply_rotary_embedding(q, k, cos, sin):
258
267
AssertionError
259
268
If input tensor shapes don't match or head dimension > 128
260
269
"""
261
- assert q . shape == k . shape , (
262
- f"Shape of Q Tensor: { q .shape } doesn't match shape of K Tensor: { k .shape } "
263
- )
264
- assert cos . shape == sin . shape , (
265
- f"Shape of cos Tensor: { cos .shape } doesn't match shape of sin Tensor: { sin .shape } "
266
- )
267
- assert q . shape [ - 1 ] <= 128 , (
268
- f"Shape of head dim (last dim) is more than 128: { q .shape } "
269
- )
270
+ assert (
271
+ q .shape == k .shape
272
+ ), f"Shape of Q Tensor: { q . shape } doesn't match shape of K Tensor: { k . shape } "
273
+ assert (
274
+ cos .shape == sin .shape
275
+ ), f"Shape of cos Tensor: { cos . shape } doesn't match shape of sin Tensor: { sin . shape } "
276
+ assert (
277
+ q .shape [ - 1 ] <= 128
278
+ ), f"Shape of head dim (last dim) is more than 128: { q . shape } "
270
279
271
280
batch_id = nl .program_id (axis = 0 )
272
281
head_id = nl .program_id (axis = 1 )
0 commit comments