@@ -38,6 +38,7 @@ def sliding_window_attention_naive_prefill(
3838 QUERY_LEN : tl .constexpr , # 256(prefill)
3939 WINDOW_SIZE : tl .constexpr ,
4040 NUM_BATCH : tl .constexpr ,
41+ NUM_BLOCK : tl .constexpr ,
4142 DIM_BLOCK_TABLE : tl .constexpr ,
4243):
4344 tl .static_assert (NUM_BATCH == 1 )
@@ -109,9 +110,9 @@ def sliding_window_attention_naive_prefill(
109110 )
110111 k_cache_base_ptr = tl .make_block_ptr (
111112 base = kv_cache_base ,
112- shape = (2 , NUM_BATCH , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
113+ shape = (2 , NUM_BLOCK , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
113114 strides = (
114- NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
115+ NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
115116 NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
116117 WINDOW_SIZE * HEAD_DIM ,
117118 WINDOW_SIZE * HEAD_DIM ,
@@ -124,9 +125,9 @@ def sliding_window_attention_naive_prefill(
124125 )
125126 v_cache_base_ptr = tl .make_block_ptr (
126127 base = kv_cache_base ,
127- shape = (2 , NUM_BATCH , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
128+ shape = (2 , NUM_BLOCK , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
128129 strides = (
129- NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
130+ NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
130131 NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
131132 WINDOW_SIZE * HEAD_DIM ,
132133 WINDOW_SIZE * HEAD_DIM ,
@@ -232,6 +233,7 @@ def sliding_window_attention_naive_decode(
232233 QUERY_LEN : tl .constexpr , # 1(decode)
233234 WINDOW_SIZE : tl .constexpr ,
234235 NUM_BATCH : tl .constexpr ,
236+ NUM_BLOCK : tl .constexpr ,
235237 DIM_BLOCK_TABLE : tl .constexpr ,
236238):
237239 tl .static_assert (QUERY_LEN == 1 )
@@ -306,9 +308,9 @@ def sliding_window_attention_naive_decode(
306308 )
307309 k_cache_base_ptr = tl .make_block_ptr (
308310 base = kv_cache_base ,
309- shape = (2 , NUM_BATCH , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
311+ shape = (2 , NUM_BLOCK , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
310312 strides = (
311- NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
313+ NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
312314 NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
313315 WINDOW_SIZE * HEAD_DIM ,
314316 WINDOW_SIZE * HEAD_DIM ,
@@ -321,9 +323,9 @@ def sliding_window_attention_naive_decode(
321323 )
322324 v_cache_base_ptr = tl .make_block_ptr (
323325 base = kv_cache_base ,
324- shape = (2 , NUM_BATCH , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
326+ shape = (2 , NUM_BLOCK , NUM_HEAD , 1 , WINDOW_SIZE , HEAD_DIM ),
325327 strides = (
326- NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
328+ NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
327329 NUM_HEAD * WINDOW_SIZE * HEAD_DIM ,
328330 WINDOW_SIZE * HEAD_DIM ,
329331 WINDOW_SIZE * HEAD_DIM ,
@@ -457,6 +459,7 @@ def _(
457459 QUERY_LEN = query .shape [- 2 ]
458460 WINDOW_SIZE = kv_cache .shape [- 2 ]
459461 NUM_BATCH = query .shape [0 ]
462+ NUM_BLOCK = kv_cache .shape [1 ]
460463 DIM_BLOCK_TABLE = block_table .dim ()
461464
462465 params = [
@@ -476,6 +479,7 @@ def _(
476479 QUERY_LEN ,
477480 WINDOW_SIZE ,
478481 NUM_BATCH ,
482+ NUM_BLOCK ,
479483 DIM_BLOCK_TABLE ,
480484 ]
481485
@@ -527,6 +531,7 @@ def _(
527531 QUERY_LEN = query .shape [- 2 ]
528532 WINDOW_SIZE = kv_cache .shape [- 2 ]
529533 NUM_BATCH = query .shape [0 ]
534+ NUM_BLOCK = kv_cache .shape [1 ]
530535 DIM_BLOCK_TABLE = block_table .dim ()
531536
532537 params = [
@@ -546,6 +551,7 @@ def _(
546551 QUERY_LEN ,
547552 WINDOW_SIZE ,
548553 NUM_BATCH ,
554+ NUM_BLOCK ,
549555 DIM_BLOCK_TABLE ,
550556 ]
551557
0 commit comments