Skip to content

Commit 194cbc7

Browse files
DwyaneShi and Haiyang Shi
authored
[Feature] KVCache support block-first layout (#1947)
- Support block-first kvcache layout used by FlashInfer etc. Signed-off-by: Haiyang Shi <haiyang.shi@bytedance.com> Co-authored-by: Haiyang Shi <haiyang.shi@bytedance.com>
1 parent cb5f1bf commit 194cbc7

File tree

4 files changed

+67
-45
lines changed

4 files changed

+67
-45
lines changed

python/aibrix_kvcache/aibrix_kvcache/_custom_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@ def reshape_and_cache_multi_layer(
3333
k_scales: list[torch.Tensor],
3434
v_scales: list[torch.Tensor],
3535
layout: str,
36+
kv_layout_blocks_first: bool = False,
3637
) -> None:
3738
torch.ops._aibrix_C_cache_ops.reshape_and_cache_multi_layer(
3839
offload_kv_cache_blocks,
3940
kv_caches,
41+
kv_layout_blocks_first,
4042
slot_mapping,
4143
block_size,
4244
kv_cache_dtype,
@@ -55,10 +57,12 @@ def reshape_and_offload_multi_layer(
5557
k_scales: list[torch.Tensor],
5658
v_scales: list[torch.Tensor],
5759
layout: str,
60+
kv_layout_blocks_first: bool = False,
5861
) -> None:
5962
torch.ops._aibrix_C_cache_ops.reshape_and_offload_multi_layer(
6063
offload_kv_cache_blocks,
6164
kv_caches,
65+
kv_layout_blocks_first,
6266
slot_mapping,
6367
block_size,
6468
kv_cache_dtype,

python/aibrix_kvcache/csrc/cache.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77

88
void reshape_and_cache_multi_layer(
99
const std::vector<torch::Tensor> &offload_kv_cache_blocks,
10-
const std::vector<torch::Tensor> &kv_caches, torch::Tensor &slot_mapping,
11-
const int64_t block_size, const std::string &kv_cache_dtype,
10+
const std::vector<torch::Tensor> &kv_caches, bool kv_layout_blocks_first,
11+
torch::Tensor &slot_mapping, const int64_t block_size,
12+
const std::string &kv_cache_dtype,
1213
const std::vector<torch::Tensor> &k_scales,
1314
const std::vector<torch::Tensor> &v_scales, const std::string &layout_str);
1415

1516
void reshape_and_offload_multi_layer(
1617
const std::vector<torch::Tensor> &offload_kv_cache_blocks,
17-
const std::vector<torch::Tensor> &kv_caches, torch::Tensor &slot_mapping,
18-
const int64_t block_size, const std::string &kv_cache_dtype,
18+
const std::vector<torch::Tensor> &kv_caches, bool kv_layout_blocks_first,
19+
torch::Tensor &slot_mapping, const int64_t block_size,
20+
const std::string &kv_cache_dtype,
1921
const std::vector<torch::Tensor> &k_scales,
2022
const std::vector<torch::Tensor> &v_scales, const std::string &layout_str);

python/aibrix_kvcache/csrc/cache_kernels.cu

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,23 @@ torch::Tensor get_device_ptrs(const std::vector<TTensor> &tensors) {
5252
return gpu_data_ptrs;
5353
}
5454

55-
__device__ __forceinline__ int64_t
56-
get_kv_cache_offset(const int64_t kv_type, const int64_t num_blocks,
57-
const int64_t block_size, const int64_t embed_dim,
58-
const int64_t slot_idx, const int64_t scalar_offset) {
55+
__device__ __forceinline__ int64_t get_kv_cache_offset(
56+
const int64_t kv_type, bool kv_layout_blocks_first,
57+
const int64_t num_blocks, const int64_t block_size, const int64_t embed_dim,
58+
const int64_t slot_idx, const int64_t scalar_offset) {
5959
const int64_t block_idx = slot_idx / block_size;
6060
const int64_t block_offset = slot_idx % block_size;
61-
return kv_type * num_blocks * block_size * embed_dim +
62-
block_idx * block_size * embed_dim + block_offset * embed_dim +
63-
scalar_offset;
61+
if (kv_layout_blocks_first) {
62+
// [num_blocks, kv_type, block_size, num_heads, head_size]
63+
return block_idx * 2 * block_size * embed_dim +
64+
kv_type * block_size * embed_dim + block_offset * embed_dim +
65+
scalar_offset;
66+
} else {
67+
// [kv_type, num_blocks, block_size, num_heads, head_size]
68+
return kv_type * num_blocks * block_size * embed_dim +
69+
block_idx * block_size * embed_dim + block_offset * embed_dim +
70+
scalar_offset;
71+
}
6472
}
6573

6674
__device__ __forceinline__ int64_t get_offload_kv_cache_offset_lcnd(
@@ -96,19 +104,20 @@ enum class KVCacheOffloadLayout {
96104
* - offload_kv_cache: Supports LCND and NCLD layouts.
97105
* LCND: [num_blocks, num_layers, 2, block_size, dim]
98106
* NCLD: [num_blocks, block_size, 2, num_layers, dim]
99-
* - kv_cache: Supports both [num_layers, 2, num_blocks, block_size, num_heads,
100-
* head_size] and [num_layers, 2, num_blocks, block_size * num_heads *
101-
* head_size]
107+
* - kv_cache: Supports [num_layers, 2, num_blocks, block_size, num_heads,
108+
* head_size], [num_layers, 2, num_blocks, block_size * num_heads *
109+
* head_size], [num_layers, num_blocks, 2, block_size, num_heads, head_size],
110+
* and [num_layers, num_blocks, 2, block_size * num_heads * head_size]
102111
* - slot_mapping: [num_tokens]
103112
*/
104113
template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt,
105114
bool TOnload, KVCacheOffloadLayout TLayout>
106115
__global__ void reshape_and_cache_multi_layer_kernel(
107116
scalar_t **__restrict__ offload_kv_cache,
108117
const int64_t offload_kv_cache_block_size, cache_t **__restrict__ kv_cache,
109-
const int64_t kv_cache_block_size, const int64_t kv_cache_num_blocks,
110-
const int64_t *__restrict__ slot_mapping, const int64_t num_layers,
111-
const int64_t embed_dim,
118+
bool kv_layout_blocks_first, const int64_t kv_cache_block_size,
119+
const int64_t kv_cache_num_blocks, const int64_t *__restrict__ slot_mapping,
120+
const int64_t num_layers, const int64_t embed_dim,
112121
const float **k_scales, // Scaling factor for keys
113122
const float **v_scales // Scaling factor for values
114123
) {
@@ -144,9 +153,9 @@ __global__ void reshape_and_cache_multi_layer_kernel(
144153
embed_dim, token_idx, i);
145154
}
146155

147-
int64_t kv_cache_offset =
148-
get_kv_cache_offset(kv_type, kv_cache_num_blocks, kv_cache_block_size,
149-
embed_dim, slot_idx, i);
156+
int64_t kv_cache_offset = get_kv_cache_offset(
157+
kv_type, kv_layout_blocks_first, kv_cache_num_blocks,
158+
kv_cache_block_size, embed_dim, slot_idx, i);
150159

151160
if (TOnload) { // true: offload_kv_cache to kv_cache
152161
kv_cache_layer[kv_cache_offset] =
@@ -170,18 +179,19 @@ __global__ void reshape_and_cache_multi_layer_kernel(
170179
* - offload_kv_cache: Supports LCND and NCLD layouts.
171180
* LCND: [num_blocks, num_layers, 2, block_size, dim]
172181
* NCLD: [num_blocks, block_size, 2, num_layers, dim]
173-
* - kv_cache: Supports both [num_layers, 2, num_blocks, block_size, num_heads,
174-
* head_size] and [num_layers, 2, num_blocks, block_size * num_heads *
175-
* head_size]
182+
* - kv_cache: Supports [num_layers, 2, num_blocks, block_size, num_heads,
183+
* head_size], [num_layers, 2, num_blocks, block_size * num_heads *
184+
* head_size], [num_layers, num_blocks, 2, block_size, num_heads, head_size],
185+
* and [num_layers, num_blocks, 2, block_size * num_heads * head_size]
176186
* - slot_mapping: [num_tokens]
177187
*/
178188
template <typename vec_t, bool TOnload, KVCacheOffloadLayout TLayout>
179189
__global__ void reshape_and_cache_multi_layer_vec_kernel(
180190
vec_t **__restrict__ offload_kv_cache,
181191
const int64_t offload_kv_cache_block_size, vec_t **__restrict__ kv_cache,
182-
const int64_t kv_cache_block_size, const int64_t kv_cache_num_blocks,
183-
const int64_t *__restrict__ slot_mapping, const int64_t num_layers,
184-
const int64_t num_vecs) {
192+
bool kv_layout_blocks_first, const int64_t kv_cache_block_size,
193+
const int64_t kv_cache_num_blocks, const int64_t *__restrict__ slot_mapping,
194+
const int64_t num_layers, const int64_t num_vecs) {
185195
const int64_t token_idx = blockIdx.x;
186196
const int64_t layer_idx = blockIdx.y;
187197
const int64_t kv_type = blockIdx.z;
@@ -211,9 +221,9 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
211221
token_idx, i);
212222
}
213223

214-
int64_t kv_cache_offset =
215-
get_kv_cache_offset(kv_type, kv_cache_num_blocks, kv_cache_block_size,
216-
num_vecs, slot_idx, i);
224+
int64_t kv_cache_offset = get_kv_cache_offset(
225+
kv_type, kv_layout_blocks_first, kv_cache_num_blocks,
226+
kv_cache_block_size, num_vecs, slot_idx, i);
217227

218228
if (TOnload) { // true: offload_kv_cache to kv_cache
219229
kv_cache_layer[kv_cache_offset] =
@@ -236,9 +246,9 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
236246
<<<grid, block, 0, stream>>>( \
237247
reinterpret_cast<KV_T **>(offload_kv_cache_ptrs.data_ptr()), \
238248
offload_kv_cache_block_size, \
239-
reinterpret_cast<CACHE_T **>(kv_cache_ptrs.data_ptr()), block_size, \
240-
kv_cache_num_blocks, slot_mapping.data_ptr<int64_t>(), num_layers, \
241-
embed_dim, \
249+
reinterpret_cast<CACHE_T **>(kv_cache_ptrs.data_ptr()), \
250+
kv_layout_blocks_first, block_size, kv_cache_num_blocks, \
251+
slot_mapping.data_ptr<int64_t>(), num_layers, embed_dim, \
242252
reinterpret_cast<const float **>(k_scale_ptrs.data_ptr()), \
243253
reinterpret_cast<const float **>(v_scale_ptrs.data_ptr()));
244254

@@ -247,9 +257,9 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
247257
<<<grid, block, 0, stream>>>( \
248258
reinterpret_cast<vec_t **>(offload_kv_cache_ptrs.data_ptr()), \
249259
offload_kv_cache_block_size, \
250-
reinterpret_cast<vec_t **>(kv_cache_ptrs.data_ptr()), block_size, \
251-
kv_cache_num_blocks, slot_mapping.data_ptr<int64_t>(), num_layers, \
252-
num_vecs);
260+
reinterpret_cast<vec_t **>(kv_cache_ptrs.data_ptr()), \
261+
kv_layout_blocks_first, block_size, kv_cache_num_blocks, \
262+
slot_mapping.data_ptr<int64_t>(), num_layers, num_vecs);
253263

254264
#define DISPATCH_RESHAPE_AND_CACHE_MULTI_LAYER_BY_KV_CACHE_DTYPE \
255265
DISPATCH_BY_KV_CACHE_DTYPE(kv_caches[0].dtype(), kv_cache_dtype, \
@@ -280,7 +290,7 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
280290
void reshape_and_cache_multi_layer_impl(
281291
const std::vector<torch::Tensor> &offload_kv_cache_blocks, // [num_blocks]
282292
const std::vector<torch::Tensor> &kv_caches, // [num_layers]
283-
torch::Tensor &slot_mapping, // [num_tokens]
293+
bool kv_layout_blocks_first, torch::Tensor &slot_mapping, // [num_tokens]
284294
const int64_t block_size, const std::string &kv_cache_dtype,
285295
const std::vector<torch::Tensor> &k_scales,
286296
const std::vector<torch::Tensor> &v_scales, bool onload,
@@ -291,12 +301,15 @@ void reshape_and_cache_multi_layer_impl(
291301
const int64_t num_tokens = slot_mapping.size(0);
292302
torch::IntArrayRef kv_cache_shape = kv_caches[0].sizes();
293303
int64_t embed_dim;
304+
294305
if (kv_cache_shape.size() == 3) {
295306
// [2, num_blocks, block_size * num_heads * head_size]
307+
// or [num_blocks, 2, block_size * num_heads * head_size]
296308
const int64_t block_dim = kv_caches[0].stride(1);
297309
embed_dim = block_dim / block_size;
298310
} else {
299311
// [2, num_blocks, block_size, num_heads, head_size]
312+
// or [num_blocks, 2, block_size, num_heads, head_size]
300313
embed_dim = kv_caches[0].stride(2);
301314
}
302315

@@ -322,7 +335,8 @@ void reshape_and_cache_multi_layer_impl(
322335
TORCH_CHECK(kv_cache_shape == kv_caches[i].sizes());
323336
}
324337

325-
const int64_t kv_cache_num_blocks = kv_cache_shape[1];
338+
const int64_t kv_cache_num_blocks =
339+
kv_layout_blocks_first ? kv_cache_shape[0] : kv_cache_shape[1];
326340

327341
torch::Tensor offload_kv_cache_ptrs =
328342
aibrix::get_device_ptrs(offload_kv_cache_blocks);
@@ -353,23 +367,23 @@ void reshape_and_cache_multi_layer_impl(
353367
void reshape_and_cache_multi_layer(
354368
const std::vector<torch::Tensor> &offload_kv_cache_blocks, // [num_blocks]
355369
const std::vector<torch::Tensor> &kv_caches, // [num_layers]
356-
torch::Tensor &slot_mapping, // [num_tokens]
370+
bool kv_layout_blocks_first, torch::Tensor &slot_mapping, // [num_tokens]
357371
const int64_t block_size, const std::string &kv_cache_dtype,
358372
const std::vector<torch::Tensor> &k_scales,
359373
const std::vector<torch::Tensor> &v_scales, const std::string &layout_str) {
360374
aibrix::reshape_and_cache_multi_layer_impl(
361-
offload_kv_cache_blocks, kv_caches, slot_mapping, block_size,
362-
kv_cache_dtype, k_scales, v_scales, true, layout_str);
375+
offload_kv_cache_blocks, kv_caches, kv_layout_blocks_first, slot_mapping,
376+
block_size, kv_cache_dtype, k_scales, v_scales, true, layout_str);
363377
}
364378

365379
void reshape_and_offload_multi_layer(
366380
const std::vector<torch::Tensor> &offload_kv_cache_blocks, // [num_blocks]
367381
const std::vector<torch::Tensor> &kv_caches, // [num_layers]
368-
torch::Tensor &slot_mapping, // [num_tokens]
382+
bool kv_layout_blocks_first, torch::Tensor &slot_mapping, // [num_tokens]
369383
const int64_t block_size, const std::string &kv_cache_dtype,
370384
const std::vector<torch::Tensor> &k_scales,
371385
const std::vector<torch::Tensor> &v_scales, const std::string &layout_str) {
372386
aibrix::reshape_and_cache_multi_layer_impl(
373-
offload_kv_cache_blocks, kv_caches, slot_mapping, block_size,
374-
kv_cache_dtype, k_scales, v_scales, false, layout_str);
387+
offload_kv_cache_blocks, kv_caches, kv_layout_blocks_first, slot_mapping,
388+
block_size, kv_cache_dtype, k_scales, v_scales, false, layout_str);
375389
}

python/aibrix_kvcache/csrc/torch_bindings.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
1919
// Cache ops
2020
cache_ops.def(
2121
"reshape_and_cache_multi_layer(Tensor[] offload_kv_cache_blocks,"
22-
" Tensor(b!)[] key_caches,"
22+
" Tensor(b!)[] kv_caches,"
23+
" bool kv_layout_blocks_first,"
2324
" Tensor slot_mapping,"
2425
" SymInt block_size,"
2526
" str kv_cache_dtype,"
@@ -30,7 +31,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
3031

3132
cache_ops.def(
3233
"reshape_and_offload_multi_layer(Tensor(a!)[] offload_kv_cache_blocks,"
33-
" Tensor[] key_caches,"
34+
" Tensor[] kv_caches,"
35+
" bool kv_layout_blocks_first,"
3436
" Tensor slot_mapping,"
3537
" SymInt block_size,"
3638
" str kv_cache_dtype,"

0 commit comments

Comments (0)