2 changes: 1 addition & 1 deletion csrc/page.cu
@@ -90,7 +90,7 @@ void append_paged_kv_cache(TensorView append_key, TensorView append_value, Tenso

ffi::CUDADeviceGuard device_guard(append_key.device().device_id);
const cudaStream_t stream = get_stream(append_key.device());
bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE(paged_k_cache.dtype(), c_type, [&] {
bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV(paged_k_cache.dtype(), c_type, [&] {
paged_kv_t<c_type, int32_t> paged_kv(
num_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_type*>(paged_k_cache.data_ptr()),
23 changes: 23 additions & 0 deletions csrc/tvm_ffi_utils.h
@@ -53,6 +53,7 @@ constexpr int64_t float16_code = encode_dlpack_dtype(dl_float16);
constexpr int64_t bfloat16_code = encode_dlpack_dtype(dl_bfloat16);
constexpr int64_t float32_code = encode_dlpack_dtype(dl_float32);
constexpr int64_t uint8_code = encode_dlpack_dtype(dl_uint8);
constexpr int64_t int8_code = encode_dlpack_dtype(dl_int8);
constexpr int64_t int32_code = encode_dlpack_dtype(dl_int32);
constexpr int64_t int64_code = encode_dlpack_dtype(dl_int64);
constexpr int64_t float8_e4m3fn_code = encode_dlpack_dtype(dl_float8_e4m3fn);
@@ -116,6 +117,12 @@ constexpr DLDevice cpu = DLDevice{kDLCPU, 0};
return __VA_ARGS__(); \
}

#define _DISPATCH_CASE_I8(c_type, ...) \
case int8_code: { \
using c_type = int8_t; \
return __VA_ARGS__(); \
}

#define _DISPATCH_CASE_I64(c_type, ...) \
case int64_code: { \
using c_type = int64_t; \
@@ -252,6 +259,22 @@ constexpr DLDevice cpu = DLDevice{kDLCPU, 0};
} \
}()

#define DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV(dlpack_dtype, c_type, ...) \
[&]() -> bool { \
switch (encode_dlpack_dtype(dlpack_dtype)) { \
_DISPATCH_CASE_I8(c_type, __VA_ARGS__) \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP4_E2M1(c_type, __VA_ARGS__) \
default: \
TVM_FFI_ICHECK(false) << __PRETTY_FUNCTION__ << " failed to dispatch data type " \
<< (dlpack_dtype).code << " " << (dlpack_dtype).bits; \
return false; \
} \
}()

#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
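For orientation, here is a minimal Python sketch of the dispatch behavior the new DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV macro adds (and that the csrc/page.cu call site above now uses): switch on the encoded dtype, run the callback for the supported QKV set, and fail loudly otherwise. The torch dtype names below are illustrative stand-ins for the DLPack codes; fp4_e2m1 is omitted because it has no plain torch dtype.

import torch

# Dtypes accepted by the QKV dispatcher, mirroring the _DISPATCH_CASE_* list
# above; int8 is the newly added case.
_QKV_DISPATCH_DTYPES = {
    torch.int8,
    torch.float16,
    torch.bfloat16,
    torch.float8_e4m3fn,
    torch.float8_e5m2,
}

def dispatch_qkv_dtype(dtype, fn):
    # Mirrors the C++ switch: invoke the callback for a supported dtype,
    # raise (the TVM_FFI_ICHECK analogue) otherwise.
    if dtype not in _QKV_DISPATCH_DTYPES:
        raise TypeError(f"failed to dispatch data type {dtype}")
    return fn(dtype)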
3 changes: 3 additions & 0 deletions flashinfer/__init__.py
@@ -140,6 +140,8 @@
from .prefill import trtllm_fmha_v2_prefill as trtllm_fmha_v2_prefill
from .quantization import packbits as packbits
from .quantization import segment_packbits as segment_packbits
from .quantization import int4_dequantize as int4_dequantize
from .quantization import int4_quantize as int4_quantize
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
@@ -181,6 +183,7 @@
prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights,
)
from .utils import next_positive_power_of_2 as next_positive_power_of_2
from .utils import INT4Tensor as INT4Tensor
from .xqa import xqa as xqa
from .xqa import xqa_mla as xqa_mla
from . import mamba as mamba
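Based only on how these functions are used elsewhere in this diff (int4_quantize returns an INT4Tensor carrying packed .data and per-group .scale, and int4_dequantize reverses it), a hedged round-trip sketch of the new exports; the grouping and packing details are assumptions:

import torch
import flashinfer

x = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")
packed = flashinfer.int4_quantize(x)        # INT4Tensor: packed .data plus per-group .scale
x_hat = flashinfer.int4_dequantize(packed)  # lossy reconstruction of x
assert x_hat.shape == x.shape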
73 changes: 58 additions & 15 deletions flashinfer/decode.py
@@ -48,7 +48,9 @@
get_batch_prefill_module,
get_single_prefill_module,
)
from .quantization import int4_dequantize
from .utils import (
INT4Tensor,
log2e,
FP4Tensor,
MaskMode,
@@ -60,6 +62,7 @@
_check_pos_encoding_mode,
check_shape_dtype_device,
get_alibi_slopes,
_dequantize_int4_paged_kv_cache,
_get_cache_alibi_slopes_buf,
_get_range_buf,
_unpack_paged_kv_cache,
@@ -76,6 +79,8 @@
GPUArchitectureError,
SINGLE_KERNEL_TMP_SIZE,
prepare_jit_additional_args,
is_int4_dtype,
is_int4_tensor,
)


@@ -363,8 +368,8 @@ def get_batch_decode_mla_module(*args):
@overload
def single_decode_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k: Union[torch.Tensor, INT4Tensor],
v: Union[torch.Tensor, INT4Tensor],
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
use_tensor_cores: bool = False,
@@ -383,8 +388,8 @@ def single_decode_with_kv_cache(
@overload
def single_decode_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k: Union[torch.Tensor, INT4Tensor],
v: Union[torch.Tensor, INT4Tensor],
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
use_tensor_cores: bool = False,
@@ -403,8 +408,8 @@ def single_decode_with_kv_cache(
@flashinfer_api
def single_decode_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k: Union[torch.Tensor, INT4Tensor],
v: Union[torch.Tensor, INT4Tensor],
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
use_tensor_cores: bool = False,
@@ -498,6 +503,15 @@ def single_decode_with_kv_cache(
"""
_check_pos_encoding_mode(pos_encoding_mode)
_check_kv_layout(kv_layout)
if is_int4_tensor(k) or is_int4_tensor(v):
if not (is_int4_tensor(k) and is_int4_tensor(v)):
raise ValueError("k and v must both be INT4Tensor when using int4 KV.")
if k_scale is not None or v_scale is not None:
raise ValueError(
"k_scale and v_scale are not supported for INT4Tensor inputs."
)
k = int4_dequantize(k)
v = int4_dequantize(v)
tmp = torch.empty(SINGLE_KERNEL_TMP_SIZE, dtype=torch.uint8, device=q.device)
head_dim = q.shape[-1]
if logits_soft_cap is None:
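A hedged usage sketch for the int4 path added above: when both k and v arrive as INT4Tensor, they are dequantized up front and the regular decode kernel runs on the result. Shapes below assume NHD layout; per the new checks, k and v must both be INT4Tensor (or neither), and k_scale/v_scale must stay None for INT4Tensor inputs.

import torch
import flashinfer

q = torch.randn(32, 128, dtype=torch.float16, device="cuda")       # [num_qo_heads, head_dim]
k = torch.randn(2048, 8, 128, dtype=torch.float16, device="cuda")  # [kv_len, num_kv_heads, head_dim]
v = torch.randn_like(k)

o = flashinfer.single_decode_with_kv_cache(
    q,
    flashinfer.int4_quantize(k),  # dequantized internally before the kernel runs
    flashinfer.int4_quantize(v),
)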
@@ -977,13 +991,31 @@ def plan(
if kv_data_type is None:
kv_data_type = data_type

if is_int4_dtype(q_data_type) or is_int4_dtype(o_data_type):
raise ValueError("q_data_type and o_data_type do not support int4.")
q_data_type = canonicalize_torch_dtype(q_data_type)
if kv_data_type is None:
kv_data_type = q_data_type
kv_data_type = canonicalize_torch_dtype(kv_data_type)
self._int4_kv_enabled = is_int4_dtype(kv_data_type)
effective_kv_data_type = (
torch.float16
if self._int4_kv_enabled
else canonicalize_torch_dtype(kv_data_type)
)
if self._int4_kv_enabled:
if self._backend == "auto":
self._backend = "fa2"
elif self._backend != "fa2":
raise NotImplementedError(
"INT4 paged KV cache only supports the fa2/common decode path."
)
if o_data_type is None:
o_data_type = q_data_type
o_data_type = canonicalize_torch_dtype(o_data_type)
if self.is_cuda_graph_enabled and self._int4_kv_enabled:
raise NotImplementedError(
"INT4 paged KV cache is not supported with CUDA graph decode yet."
)

if fixed_split_size is not None and not self.use_tensor_cores:
raise ValueError(
@@ -993,7 +1025,8 @@
fixed_split_size = -1

self._cached_q_data_type = q_data_type
self._cached_kv_data_type = kv_data_type
self._cached_kv_data_type = effective_kv_data_type
self._external_kv_data_type = kv_data_type
self._cached_o_data_type = o_data_type
self._batch_size = batch_size
self._num_qo_heads = num_qo_heads
@@ -1038,7 +1071,7 @@ def plan(
block_id += num_blocks_needed
self._cached_module = get_trtllm_gen_decode_module(
q_data_type,
kv_data_type,
effective_kv_data_type,
o_data_type,
indptr.dtype,
head_dim,
@@ -1058,21 +1091,21 @@ def plan(
if {
torch.float8_e4m3fn,
torch.float8_e5m2,
} & {q_data_type, kv_data_type}:
} & {q_data_type, effective_kv_data_type}:
self._backend = determine_attention_backend(
self.device,
PosEncodingMode[pos_encoding_mode].value,
False, # use_fp16_qk_reductions
False, # use_custom_mask
q_data_type,
kv_data_type,
effective_kv_data_type,
)
else:
self._backend = "fa2"
self._cached_module = get_batch_prefill_module(
self._backend,
q_data_type,
kv_data_type,
effective_kv_data_type,
o_data_type,
indptr.dtype,
head_dim, # head_dim_qk
@@ -1114,7 +1147,7 @@ def plan(
else:
self._cached_module = get_batch_decode_module(
q_data_type,
kv_data_type,
effective_kv_data_type,
o_data_type,
indptr.dtype,
head_dim, # head_dim_qk
@@ -1138,7 +1171,7 @@ def plan(
head_dim,
head_dim,
torch.empty(0, dtype=q_data_type),
torch.empty(0, dtype=kv_data_type),
torch.empty(0, dtype=effective_kv_data_type),
)

self._pos_encoding_mode = pos_encoding_mode
@@ -1289,7 +1322,17 @@ def run(
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(q.device)
k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
if self._int4_kv_enabled:
if k_scale is not None or v_scale is not None:
raise ValueError(
"k_scale and v_scale are not supported for INT4 paged KV cache."
)
k_cache, v_cache = _dequantize_int4_paged_kv_cache(
paged_kv_cache,
self._kv_layout,
)
else:
k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)

if (
k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8
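To summarize the decode-side wiring above: plan() records the user-facing dtype in _external_kv_data_type but plans every kernel against effective_kv_data_type (float16), pins the backend to fa2, and rejects CUDA graph capture for int4 KV; run() rejects k_scale/v_scale and dequantizes the int4 pages into dense k/v views before dispatch. A hedged sketch follows, with the caveat that the exact marker kv_data_type accepts (whatever is_int4_dtype recognizes) is not shown in this diff, so "int4" below is an assumption:

import torch
import flashinfer

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Minimal page-table tensors for a single request occupying one page.
kv_indptr = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
kv_indices = torch.tensor([0], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([5], dtype=torch.int32, device="cuda")

wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads=32, num_kv_heads=8, head_dim=128, page_size=16,
    q_data_type=torch.float16,
    kv_data_type="int4",  # assumption: an int4 marker accepted by is_int4_dtype
)

q = torch.randn(1, 32, 128, dtype=torch.float16, device="cuda")
paged_kv_cache = ...  # (INT4Tensor, INT4Tensor) pair; allocation API not part of this diff
o = wrapper.run(q, paged_kv_cache)  # k_scale/v_scale must stay None on this path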
86 changes: 85 additions & 1 deletion flashinfer/page.py
@@ -21,10 +21,14 @@

from .api_logging import flashinfer_api
from .jit.page import gen_page_module
from .quantization import int4_quantize
from .utils import (
INT4Tensor,
TensorLayout,
_check_kv_layout,
_split_int4_paged_kv_cache_views,
check_shape_dtype_device,
is_int4_paged_kv_cache,
_unpack_paged_kv_cache,
register_custom_op,
register_fake_op,
@@ -120,6 +124,69 @@ def _fake_append_paged_kv_cache_kernel(
pass


def _append_paged_kv_cache_int4(
append_key: torch.Tensor,
append_value: torch.Tensor,
batch_indices: torch.Tensor,
positions: torch.Tensor,
paged_kv_cache: Union[INT4Tensor, Tuple[INT4Tensor, INT4Tensor]],
kv_indices: torch.Tensor,
kv_indptr: torch.Tensor,
kv_layout: str,
) -> None:
packed_key = int4_quantize(append_key)
packed_value = int4_quantize(append_value)

k_data, v_data, k_scale, v_scale = _split_int4_paged_kv_cache_views(
paged_kv_cache, kv_layout
)
if packed_key.data.shape[-1] != k_data.shape[-1]:
raise ValueError(
"The append key head dimension does not match the paged int4 cache."
)
if packed_value.data.shape[-1] != v_data.shape[-1]:
raise ValueError(
"The append value head dimension does not match the paged int4 cache."
)
if packed_key.scale.shape[-1] != k_scale.shape[-1]:
raise ValueError(
"The append key group count does not match the paged int4 cache."
)
if packed_value.scale.shape[-1] != v_scale.shape[-1]:
raise ValueError(
"The append value group count does not match the paged int4 cache."
)
cache_num_heads = k_data.shape[2] if kv_layout == "NHD" else k_data.shape[1]
if packed_key.data.shape[1] != cache_num_heads:
raise ValueError(
"The append key head count does not match the paged int4 cache."
)
if packed_value.data.shape[1] != cache_num_heads:
raise ValueError(
"The append value head count does not match the paged int4 cache."
)

page_size = k_data.shape[1] if kv_layout == "NHD" else k_data.shape[2]
batch_indices = batch_indices.to(torch.int64)
positions = positions.to(torch.int64)
kv_indices = kv_indices.to(torch.int64)
kv_indptr = kv_indptr.to(torch.int64)
page_offsets = torch.div(positions, page_size, rounding_mode="floor")
page_positions = torch.remainder(positions, page_size)
page_indices = kv_indices[kv_indptr[batch_indices] + page_offsets]

if kv_layout == "NHD":
k_data[page_indices, page_positions] = packed_key.data
v_data[page_indices, page_positions] = packed_value.data
k_scale[page_indices, page_positions] = packed_key.scale
v_scale[page_indices, page_positions] = packed_value.scale
else:
k_data[page_indices, :, page_positions, :] = packed_key.data
v_data[page_indices, :, page_positions, :] = packed_value.data
k_scale[page_indices, :, page_positions, :] = packed_key.scale
v_scale[page_indices, :, page_positions, :] = packed_value.scale
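A small worked example of the scatter indexing used above, with the arithmetic traced by hand (pure illustration, CPU tensors):

import torch

page_size = 4
positions = torch.tensor([0, 1, 5, 6])     # token positions within the request
batch_indices = torch.tensor([0, 0, 0, 0])
kv_indptr = torch.tensor([0, 2])           # request 0 owns kv_indices[0:2]
kv_indices = torch.tensor([7, 3])          # its logical pages map to physical pages 7 and 3

page_offsets = torch.div(positions, page_size, rounding_mode="floor")  # [0, 0, 1, 1]
page_positions = torch.remainder(positions, page_size)                 # [0, 1, 1, 2]
page_indices = kv_indices[kv_indptr[batch_indices] + page_offsets]     # [7, 7, 3, 3]
# e.g. the token at position 5 lands in physical page 3, slot 1.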


@flashinfer_api
def get_batch_indices_positions(
append_indptr: torch.Tensor,
Expand Down Expand Up @@ -278,7 +345,12 @@ def append_paged_kv_cache(
append_value: torch.Tensor,
batch_indices: torch.Tensor,
positions: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
paged_kv_cache: Union[
torch.Tensor,
INT4Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[INT4Tensor, INT4Tensor],
],
kv_indices: torch.Tensor,
kv_indptr: torch.Tensor,
kv_last_page_len: torch.Tensor,
@@ -389,6 +461,18 @@ def append_paged_kv_cache(
get_batch_indices_positions
"""
_check_kv_layout(kv_layout)
if is_int4_paged_kv_cache(paged_kv_cache):
_append_paged_kv_cache_int4(
append_key,
append_value,
batch_indices,
positions,
paged_kv_cache,
kv_indices,
kv_indptr,
kv_layout,
)
return
_append_paged_kv_cache_kernel(
append_key,
append_value,
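An end-to-end sketch of the new int4 append path: passing an (INT4Tensor, INT4Tensor) cache makes append_paged_kv_cache take the _append_paged_kv_cache_int4 branch, which quantizes the incoming keys/values on the fly and scatters the packed data and per-group scales into the cache views. How the INT4Tensor cache itself is allocated is not shown in this diff, so its construction is elided below (page_size >= 4 assumed):

import torch
import flashinfer

nnz, num_kv_heads, head_dim = 4, 8, 128
append_key = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
append_value = torch.randn_like(append_key)
batch_indices = torch.zeros(nnz, dtype=torch.int32, device="cuda")
positions = torch.arange(nnz, dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
kv_indices = torch.tensor([0], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([nnz], dtype=torch.int32, device="cuda")

paged_kv_cache = ...  # (INT4Tensor, INT4Tensor) pair, NHD layout; allocation not shown here

flashinfer.append_paged_kv_cache(
    append_key, append_value, batch_indices, positions,
    paged_kv_cache, kv_indices, kv_indptr, kv_last_page_len,
    kv_layout="NHD",
)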