
Commit 3ee92f4

Fix type and formatting issues for int4 paged KV
1 parent 4ac20c4 commit 3ee92f4

3 files changed: 53 additions & 25 deletions


flashinfer/prefill.py

Lines changed: 20 additions & 11 deletions
@@ -18,7 +18,7 @@
 import logging
 import math
 from types import SimpleNamespace
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast, overload
 
 import torch
 
@@ -1285,19 +1285,26 @@ def single_prefill_with_kv_cache(
     if return_lse:
         lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device)
 
+    k_tensor = cast(torch.Tensor, k)
+    v_tensor = cast(torch.Tensor, v)
+
     if is_float8(q):
         # FP8 quant enabled, do sanity check:
         # 1. unsupported feature
         # 2. dtype check
         assert window_left == -1
-        assert q.dtype == k.dtype == v.dtype
-        assert q.shape[-1] == k.shape[-1] == v.shape[-1]
+        assert q.dtype == k_tensor.dtype == v_tensor.dtype
+        assert q.shape[-1] == k_tensor.shape[-1] == v_tensor.shape[-1]
         if scale_q is None:
             scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device)
         if scale_k is None:
-            scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device)
+            scale_k = torch.ones(
+                k_tensor.shape[1], dtype=torch.float32, device=q.device
+            )
         if scale_v is None:
-            scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device)
+            scale_v = torch.ones(
+                v_tensor.shape[1], dtype=torch.float32, device=q.device
+            )
     else:
         if scale_q is not None:
             sm_scale *= scale_q
@@ -1318,21 +1325,23 @@ def single_prefill_with_kv_cache(
             use_fp16_qk_reduction,
             packed_custom_mask is not None,  # use_custom_mask
             q.dtype,
-            k.dtype,
+            k_tensor.dtype,
         )
 
     # o_dtype should be provided for FP8 attention
     if o_dtype is None:
         o_dtype = q.dtype
-    out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device)
+    out = torch.empty(
+        q.shape[:-1] + v_tensor.shape[-1:], dtype=o_dtype, device=q.device
+    )
 
     module = get_single_prefill_module(
         backend,
         q.dtype,
-        k.dtype,
+        k_tensor.dtype,
         out.dtype,
         q.shape[-1],  # head_dim_qk
-        v.shape[-1],  # head_dim_vo
+        v_tensor.shape[-1],  # head_dim_vo
         PosEncodingMode[pos_encoding_mode].value,
         window_left >= 0,  # use_sliding_window
         logits_soft_cap > 0,  # use_logits_soft_cap
@@ -1341,8 +1350,8 @@ def single_prefill_with_kv_cache(
 
     module.run(
         q,
-        k,
-        v,
+        k_tensor,
+        v_tensor,
         tmp,
         out,
         lse,
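
Note: the new `cast(torch.Tensor, k)` / `cast(torch.Tensor, v)` calls only narrow the declared `Union` type for the checker; `typing.cast` is an identity function at runtime. A minimal sketch of the pattern, with an illustrative signature rather than the real `single_prefill_with_kv_cache` one:

from typing import Tuple, Union, cast

import torch

def head_dim_vo(v: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> int:
    # A Tuple has no .shape, so `v.shape[-1]` fails type checking on the Union.
    # cast() tells the checker that v is a plain tensor at this point; it
    # performs no conversion and adds no runtime cost.
    v_tensor = cast(torch.Tensor, v)
    return v_tensor.shape[-1]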

flashinfer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -869,7 +869,7 @@ def is_int4_paged_kv_cache(
         INT4Tensor,
         Tuple[torch.Tensor, torch.Tensor],
         Tuple[INT4Tensor, INT4Tensor],
-    ]
+    ],
 ) -> bool:
     if isinstance(paged_kv_cache, INT4Tensor):
         return True
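
The only change here is the trailing comma after the `Union[...]` annotation's closing bracket, which keeps the one-parameter-per-line layout stable under a formatter such as Black. For reference, a usage sketch of the predicate; only the `INT4Tensor` branch is visible in this hunk, so the tuple behavior shown below is an assumption:

k_int4 = flashinfer.int4_quantize(k)  # INT4Tensor (see the tests below)
v_int4 = flashinfer.int4_quantize(v)

assert is_int4_paged_kv_cache(k_int4)            # single INT4Tensor -> True
assert is_int4_paged_kv_cache((k_int4, v_int4))  # assumed: int4 K/V pair -> True
assert not is_int4_paged_kv_cache((k, v))        # assumed: plain-tensor pair -> False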

tests/attention/test_int4_paged_kv.py

Lines changed: 32 additions & 13 deletions
@@ -149,8 +149,12 @@ def test_append_paged_kv_cache_int4_matches_quantized_layout(
     torch.testing.assert_close(
         v_cache.scale[page_indices, :, page_positions, :], expected_v.scale
     )
-    gathered_k = flashinfer.int4_dequantize(k_cache)[page_indices, :, page_positions]
-    gathered_v = flashinfer.int4_dequantize(v_cache)[page_indices, :, page_positions]
+    gathered_k = flashinfer.int4_dequantize(k_cache)[
+        page_indices, :, page_positions
+    ]
+    gathered_v = flashinfer.int4_dequantize(v_cache)[
+        page_indices, :, page_positions
+    ]
 
     torch.testing.assert_close(
         gathered_k,
@@ -177,11 +181,19 @@ def test_single_decode_with_kv_cache_int4(kv_layout, head_dim, use_tensor_cores)
 
     q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device=device)
     if kv_layout == "NHD":
-        k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device)
-        v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device)
+        k = torch.randn(
+            kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device
+        )
+        v = torch.randn(
+            kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device
+        )
     else:
-        k = torch.randn(num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device)
-        v = torch.randn(num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device)
+        k = torch.randn(
+            num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device
+        )
+        v = torch.randn(
+            num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device
+        )
 
     k_int4 = flashinfer.int4_quantize(k)
     v_int4 = flashinfer.int4_quantize(v)
@@ -232,11 +244,19 @@ def test_single_prefill_with_kv_cache_int4(kv_layout, head_dim):
 
     q = torch.randn(qo_len, num_qo_heads, head_dim, dtype=torch.float16, device=device)
     if kv_layout == "NHD":
-        k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device)
-        v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device)
+        k = torch.randn(
+            kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device
+        )
+        v = torch.randn(
+            kv_len, num_kv_heads, head_dim, dtype=torch.float16, device=device
+        )
     else:
-        k = torch.randn(num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device)
-        v = torch.randn(num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device)
+        k = torch.randn(
+            num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device
+        )
+        v = torch.randn(
+            num_kv_heads, kv_len, head_dim, dtype=torch.float16, device=device
+        )
 
     k_int4 = flashinfer.int4_quantize(k)
     v_int4 = flashinfer.int4_quantize(v)
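
These tests build their references by round-tripping K/V through the int4 helpers. A standalone sketch of that step, assuming the same `int4_quantize` / `int4_dequantize` pair used above; the shapes are illustrative, and 4-bit storage is lossy, so the reconstruction is close but not exact:

import torch
import flashinfer

k = torch.randn(2048, 4, 128, dtype=torch.float16, device="cuda:0")
k_int4 = flashinfer.int4_quantize(k)        # packed int4 values plus a .scale field
k_ref = flashinfer.int4_dequantize(k_int4)  # lossy float16 reconstruction
print((k - k_ref).abs().max())              # small but nonzero: int4 is lossy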
@@ -666,9 +686,8 @@ def test_int4_paged_kv_cache_cuda_graph_unsupported():
     head_dim = 128
     device = "cuda:0"
 
-    kv_indptr = (
-        torch.arange(0, batch_size + 1, device=device, dtype=torch.int32)
-        * ((kv_len + page_size - 1) // page_size)
+    kv_indptr = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * (
+        (kv_len + page_size - 1) // page_size
     )
     kv_indices = torch.arange(kv_indptr[-1].item(), device=device, dtype=torch.int32)
     kv_last_page_len = torch.full(
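
The `kv_indptr` rewrite is formatting only; entry `i` is still `i * ceil(kv_len / page_size)`, i.e. every sequence occupies the same number of pages. A quick worked check with illustrative sizes:

import torch

batch_size, kv_len, page_size = 4, 100, 16             # illustrative sizes
pages_per_seq = (kv_len + page_size - 1) // page_size  # ceil(100 / 16) = 7
kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32) * pages_per_seq
print(kv_indptr)  # tensor([ 0,  7, 14, 21, 28], dtype=torch.int32)
# kv_indices then enumerates all pages: torch.arange(kv_indptr[-1]) -> 28 page ids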
