Skip to content

Commit 19650e4

Browse files
fixed num_block of kvcache in sliding window (#301)
Signed-off-by: Jinseok Lee <jindol21@rebellions.ai>
1 parent 62b8fa2 commit 19650e4

1 file changed

Lines changed: 14 additions & 8 deletions

File tree

vllm_rbln/triton_kernels/sliding_window_attention.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ def sliding_window_attention_naive_prefill(
3838
QUERY_LEN: tl.constexpr, # 256(prefill)
3939
WINDOW_SIZE: tl.constexpr,
4040
NUM_BATCH: tl.constexpr,
41+
NUM_BLOCK: tl.constexpr,
4142
DIM_BLOCK_TABLE: tl.constexpr,
4243
):
4344
tl.static_assert(NUM_BATCH == 1)
@@ -109,9 +110,9 @@ def sliding_window_attention_naive_prefill(
109110
)
110111
k_cache_base_ptr = tl.make_block_ptr(
111112
base=kv_cache_base,
112-
shape=(2, NUM_BATCH, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
113+
shape=(2, NUM_BLOCK, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
113114
strides=(
114-
NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
115+
NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
115116
NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
116117
WINDOW_SIZE * HEAD_DIM,
117118
WINDOW_SIZE * HEAD_DIM,
@@ -124,9 +125,9 @@ def sliding_window_attention_naive_prefill(
124125
)
125126
v_cache_base_ptr = tl.make_block_ptr(
126127
base=kv_cache_base,
127-
shape=(2, NUM_BATCH, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
128+
shape=(2, NUM_BLOCK, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
128129
strides=(
129-
NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
130+
NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
130131
NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
131132
WINDOW_SIZE * HEAD_DIM,
132133
WINDOW_SIZE * HEAD_DIM,
@@ -232,6 +233,7 @@ def sliding_window_attention_naive_decode(
232233
QUERY_LEN: tl.constexpr, # 1(decode)
233234
WINDOW_SIZE: tl.constexpr,
234235
NUM_BATCH: tl.constexpr,
236+
NUM_BLOCK: tl.constexpr,
235237
DIM_BLOCK_TABLE: tl.constexpr,
236238
):
237239
tl.static_assert(QUERY_LEN == 1)
@@ -306,9 +308,9 @@ def sliding_window_attention_naive_decode(
306308
)
307309
k_cache_base_ptr = tl.make_block_ptr(
308310
base=kv_cache_base,
309-
shape=(2, NUM_BATCH, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
311+
shape=(2, NUM_BLOCK, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
310312
strides=(
311-
NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
313+
NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
312314
NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
313315
WINDOW_SIZE * HEAD_DIM,
314316
WINDOW_SIZE * HEAD_DIM,
@@ -321,9 +323,9 @@ def sliding_window_attention_naive_decode(
321323
)
322324
v_cache_base_ptr = tl.make_block_ptr(
323325
base=kv_cache_base,
324-
shape=(2, NUM_BATCH, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
326+
shape=(2, NUM_BLOCK, NUM_HEAD, 1, WINDOW_SIZE, HEAD_DIM),
325327
strides=(
326-
NUM_BATCH * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
328+
NUM_BLOCK * NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
327329
NUM_HEAD * WINDOW_SIZE * HEAD_DIM,
328330
WINDOW_SIZE * HEAD_DIM,
329331
WINDOW_SIZE * HEAD_DIM,
@@ -457,6 +459,7 @@ def _(
457459
QUERY_LEN = query.shape[-2]
458460
WINDOW_SIZE = kv_cache.shape[-2]
459461
NUM_BATCH = query.shape[0]
462+
NUM_BLOCK = kv_cache.shape[1]
460463
DIM_BLOCK_TABLE = block_table.dim()
461464

462465
params = [
@@ -476,6 +479,7 @@ def _(
476479
QUERY_LEN,
477480
WINDOW_SIZE,
478481
NUM_BATCH,
482+
NUM_BLOCK,
479483
DIM_BLOCK_TABLE,
480484
]
481485

@@ -527,6 +531,7 @@ def _(
527531
QUERY_LEN = query.shape[-2]
528532
WINDOW_SIZE = kv_cache.shape[-2]
529533
NUM_BATCH = query.shape[0]
534+
NUM_BLOCK = kv_cache.shape[1]
530535
DIM_BLOCK_TABLE = block_table.dim()
531536

532537
params = [
@@ -546,6 +551,7 @@ def _(
546551
QUERY_LEN,
547552
WINDOW_SIZE,
548553
NUM_BATCH,
554+
NUM_BLOCK,
549555
DIM_BLOCK_TABLE,
550556
]
551557

0 commit comments

Comments (0)