@@ -52,15 +52,23 @@ torch::Tensor get_device_ptrs(const std::vector<TTensor> &tensors) {
5252 return gpu_data_ptrs;
5353}
5454
55- __device__ __forceinline__ int64_t
56- get_kv_cache_offset ( const int64_t kv_type, const int64_t num_blocks ,
57- const int64_t block_size, const int64_t embed_dim,
58- const int64_t slot_idx, const int64_t scalar_offset) {
55+ __device__ __forceinline__ int64_t get_kv_cache_offset (
56+ const int64_t kv_type, bool kv_layout_blocks_first ,
57+ const int64_t num_blocks, const int64_t block_size, const int64_t embed_dim,
58+ const int64_t slot_idx, const int64_t scalar_offset) {
5959 const int64_t block_idx = slot_idx / block_size;
6060 const int64_t block_offset = slot_idx % block_size;
61- return kv_type * num_blocks * block_size * embed_dim +
62- block_idx * block_size * embed_dim + block_offset * embed_dim +
63- scalar_offset;
61+ if (kv_layout_blocks_first) {
62+ // [num_blocks, kv_type, block_size, num_heads, head_size]
63+ return block_idx * 2 * block_size * embed_dim +
64+ kv_type * block_size * embed_dim + block_offset * embed_dim +
65+ scalar_offset;
66+ } else {
67+ // [kv_type, num_blocks, block_size, num_heads, head_size]
68+ return kv_type * num_blocks * block_size * embed_dim +
69+ block_idx * block_size * embed_dim + block_offset * embed_dim +
70+ scalar_offset;
71+ }
6472}
6573
6674__device__ __forceinline__ int64_t get_offload_kv_cache_offset_lcnd (
@@ -96,19 +104,20 @@ enum class KVCacheOffloadLayout {
96104 * - offload_kv_cache: Supports LCND and NCLD layouts.
97105 * LCND: [num_blocks, num_layers, 2, block_size, dim]
98106 * NCLD: [num_blocks, block_size, 2, num_layers, dim]
99- * - kv_cache: Supports both [num_layers, 2, num_blocks, block_size, num_heads,
100- * head_size] and [num_layers, 2, num_blocks, block_size * num_heads *
101- * head_size]
107+ * - kv_cache: Supports [num_layers, 2, num_blocks, block_size, num_heads,
108+ * head_size], [num_layers, 2, num_blocks, block_size * num_heads *
109+ * head_size], [num_layers, num_blocks, 2, block_size, num_heads, head_size],
110+ * and [num_layers, num_blocks, 2, block_size * num_heads * head_size]
102111 * - slot_mapping: [num_tokens]
103112 */
104113template <typename scalar_t , typename cache_t , vllm::Fp8KVCacheDataType kv_dt,
105114 bool TOnload, KVCacheOffloadLayout TLayout>
106115__global__ void reshape_and_cache_multi_layer_kernel (
107116 scalar_t **__restrict__ offload_kv_cache,
108117 const int64_t offload_kv_cache_block_size, cache_t **__restrict__ kv_cache,
109- const int64_t kv_cache_block_size , const int64_t kv_cache_num_blocks ,
110- const int64_t * __restrict__ slot_mapping , const int64_t num_layers ,
111- const int64_t embed_dim,
118+ bool kv_layout_blocks_first , const int64_t kv_cache_block_size ,
119+ const int64_t kv_cache_num_blocks , const int64_t * __restrict__ slot_mapping ,
120+ const int64_t num_layers, const int64_t embed_dim,
112121 const float **k_scales, // Scaling factor for keys
113122 const float **v_scales // Scaling factor for values
114123) {
@@ -144,9 +153,9 @@ __global__ void reshape_and_cache_multi_layer_kernel(
144153 embed_dim, token_idx, i);
145154 }
146155
147- int64_t kv_cache_offset =
148- get_kv_cache_offset ( kv_type, kv_cache_num_blocks, kv_cache_block_size ,
149- embed_dim, slot_idx, i);
156+ int64_t kv_cache_offset = get_kv_cache_offset (
157+ kv_type, kv_layout_blocks_first, kv_cache_num_blocks ,
158+ kv_cache_block_size, embed_dim, slot_idx, i);
150159
151160 if (TOnload) { // true: offload_kv_cache to kv_cache
152161 kv_cache_layer[kv_cache_offset] =
@@ -170,18 +179,19 @@ __global__ void reshape_and_cache_multi_layer_kernel(
170179 * - offload_kv_cache: Supports LCND and NCLD layouts.
171180 * LCND: [num_blocks, num_layers, 2, block_size, dim]
172181 * NCLD: [num_blocks, block_size, 2, num_layers, dim]
173- * - kv_cache: Supports both [num_layers, 2, num_blocks, block_size, num_heads,
174- * head_size] and [num_layers, 2, num_blocks, block_size * num_heads *
175- * head_size]
182+ * - kv_cache: Supports [num_layers, 2, num_blocks, block_size, num_heads,
183+ * head_size], [num_layers, 2, num_blocks, block_size * num_heads *
184+ * head_size], [num_layers, num_blocks, 2, block_size, num_heads, head_size],
185+ * and [num_layers, num_blocks, 2, block_size * num_heads * head_size]
176186 * - slot_mapping: [num_tokens]
177187 */
178188template <typename vec_t , bool TOnload, KVCacheOffloadLayout TLayout>
179189__global__ void reshape_and_cache_multi_layer_vec_kernel (
180190 vec_t **__restrict__ offload_kv_cache,
181191 const int64_t offload_kv_cache_block_size, vec_t **__restrict__ kv_cache,
182- const int64_t kv_cache_block_size , const int64_t kv_cache_num_blocks ,
183- const int64_t * __restrict__ slot_mapping , const int64_t num_layers ,
184- const int64_t num_vecs) {
192+ bool kv_layout_blocks_first , const int64_t kv_cache_block_size ,
193+ const int64_t kv_cache_num_blocks , const int64_t * __restrict__ slot_mapping ,
194+ const int64_t num_layers, const int64_t num_vecs) {
185195 const int64_t token_idx = blockIdx .x ;
186196 const int64_t layer_idx = blockIdx .y ;
187197 const int64_t kv_type = blockIdx .z ;
@@ -211,9 +221,9 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
211221 token_idx, i);
212222 }
213223
214- int64_t kv_cache_offset =
215- get_kv_cache_offset ( kv_type, kv_cache_num_blocks, kv_cache_block_size ,
216- num_vecs, slot_idx, i);
224+ int64_t kv_cache_offset = get_kv_cache_offset (
225+ kv_type, kv_layout_blocks_first, kv_cache_num_blocks ,
226+ kv_cache_block_size, num_vecs, slot_idx, i);
217227
218228 if (TOnload) { // true: offload_kv_cache to kv_cache
219229 kv_cache_layer[kv_cache_offset] =
@@ -236,9 +246,9 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
236246 <<<grid, block, 0 , stream>>> ( \
237247 reinterpret_cast <KV_T **>(offload_kv_cache_ptrs.data_ptr()), \
238248 offload_kv_cache_block_size, \
239- reinterpret_cast <CACHE_T **>(kv_cache_ptrs.data_ptr()), block_size, \
240- kv_cache_num_blocks, slot_mapping.data_ptr< int64_t >(), num_layers, \
241- embed_dim, \
249+ reinterpret_cast <CACHE_T **>(kv_cache_ptrs.data_ptr()), \
250+ kv_layout_blocks_first, block_size, kv_cache_num_blocks, \
251+ slot_mapping.data_ptr< int64_t >(), num_layers, embed_dim, \
242252 reinterpret_cast <const float **>(k_scale_ptrs.data_ptr()), \
243253 reinterpret_cast <const float **>(v_scale_ptrs.data_ptr()));
244254
@@ -247,9 +257,9 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
247257 <<<grid, block, 0 , stream>>> ( \
248258 reinterpret_cast <vec_t **>(offload_kv_cache_ptrs.data_ptr()), \
249259 offload_kv_cache_block_size, \
250- reinterpret_cast <vec_t **>(kv_cache_ptrs.data_ptr()), block_size, \
251- kv_cache_num_blocks, slot_mapping.data_ptr< int64_t >(), num_layers, \
252- num_vecs);
260+ reinterpret_cast <vec_t **>(kv_cache_ptrs.data_ptr()), \
261+ kv_layout_blocks_first, block_size, kv_cache_num_blocks, \
262+ slot_mapping.data_ptr< int64_t >(), num_layers, num_vecs);
253263
254264#define DISPATCH_RESHAPE_AND_CACHE_MULTI_LAYER_BY_KV_CACHE_DTYPE \
255265 DISPATCH_BY_KV_CACHE_DTYPE (kv_caches[0 ].dtype(), kv_cache_dtype, \
@@ -280,7 +290,7 @@ __global__ void reshape_and_cache_multi_layer_vec_kernel(
280290void reshape_and_cache_multi_layer_impl (
281291 const std::vector<torch::Tensor> &offload_kv_cache_blocks, // [num_blocks]
282292 const std::vector<torch::Tensor> &kv_caches, // [num_layers]
283- torch::Tensor &slot_mapping, // [num_tokens]
293+ bool kv_layout_blocks_first, torch::Tensor &slot_mapping, // [num_tokens]
284294 const int64_t block_size, const std::string &kv_cache_dtype,
285295 const std::vector<torch::Tensor> &k_scales,
286296 const std::vector<torch::Tensor> &v_scales, bool onload,
@@ -291,12 +301,15 @@ void reshape_and_cache_multi_layer_impl(
291301 const int64_t num_tokens = slot_mapping.size (0 );
292302 torch::IntArrayRef kv_cache_shape = kv_caches[0 ].sizes ();
293303 int64_t embed_dim;
304+
294305 if (kv_cache_shape.size () == 3 ) {
295306 // [2, num_blocks, block_size * num_heads * head_size]
307+ // or [num_blocks, 2, block_size * num_heads * head_size]
296308 const int64_t block_dim = kv_caches[0 ].stride (1 );
297309 embed_dim = block_dim / block_size;
298310 } else {
299311 // [2, num_blocks, block_size, num_heads, head_size]
312+ // or [num_blocks, 2, block_size, num_heads, head_size]
300313 embed_dim = kv_caches[0 ].stride (2 );
301314 }
302315
@@ -322,7 +335,8 @@ void reshape_and_cache_multi_layer_impl(
322335 TORCH_CHECK (kv_cache_shape == kv_caches[i].sizes ());
323336 }
324337
325- const int64_t kv_cache_num_blocks = kv_cache_shape[1 ];
338+ const int64_t kv_cache_num_blocks =
339+ kv_layout_blocks_first ? kv_cache_shape[0 ] : kv_cache_shape[1 ];
326340
327341 torch::Tensor offload_kv_cache_ptrs =
328342 aibrix::get_device_ptrs (offload_kv_cache_blocks);
@@ -353,23 +367,23 @@ void reshape_and_cache_multi_layer_impl(
353367void reshape_and_cache_multi_layer (
354368 const std::vector<torch::Tensor> &offload_kv_cache_blocks, // [num_blocks]
355369 const std::vector<torch::Tensor> &kv_caches, // [num_layers]
356- torch::Tensor &slot_mapping, // [num_tokens]
370+ bool kv_layout_blocks_first, torch::Tensor &slot_mapping, // [num_tokens]
357371 const int64_t block_size, const std::string &kv_cache_dtype,
358372 const std::vector<torch::Tensor> &k_scales,
359373 const std::vector<torch::Tensor> &v_scales, const std::string &layout_str) {
360374 aibrix::reshape_and_cache_multi_layer_impl (
361- offload_kv_cache_blocks, kv_caches, slot_mapping, block_size ,
362- kv_cache_dtype, k_scales, v_scales, true , layout_str);
375+ offload_kv_cache_blocks, kv_caches, kv_layout_blocks_first, slot_mapping ,
376+ block_size, kv_cache_dtype, k_scales, v_scales, true , layout_str);
363377}
364378
365379void reshape_and_offload_multi_layer (
366380 const std::vector<torch::Tensor> &offload_kv_cache_blocks, // [num_blocks]
367381 const std::vector<torch::Tensor> &kv_caches, // [num_layers]
368- torch::Tensor &slot_mapping, // [num_tokens]
382+ bool kv_layout_blocks_first, torch::Tensor &slot_mapping, // [num_tokens]
369383 const int64_t block_size, const std::string &kv_cache_dtype,
370384 const std::vector<torch::Tensor> &k_scales,
371385 const std::vector<torch::Tensor> &v_scales, const std::string &layout_str) {
372386 aibrix::reshape_and_cache_multi_layer_impl (
373- offload_kv_cache_blocks, kv_caches, slot_mapping, block_size ,
374- kv_cache_dtype, k_scales, v_scales, false , layout_str);
387+ offload_kv_cache_blocks, kv_caches, kv_layout_blocks_first, slot_mapping ,
388+ block_size, kv_cache_dtype, k_scales, v_scales, false , layout_str);
375389}
0 commit comments