Skip to content

Commit f8f0a47

Browse files
cuda: reserve space for quantize kv-cache at startup (#23907)
* cuda: reserve space for quantize kv-cache at startup
* address review comments
* remove forward decl
  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* remove assert in ggml-cuda.cu
  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent 06938ac commit f8f0a47

4 files changed

Lines changed: 96 additions & 14 deletions

File tree

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)(
4444
typedef float (*vec_dot_KQ_t)(
4545
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
4646

47+
// Addresses of the FP16 scratch regions that are placed directly behind the output of a
// GGML_OP_FLASH_ATTN_EXT node. All members are raw addresses stored as integers; a value
// of 0 in K or V means that no scratch region was assigned for that tensor.
struct ggml_cuda_flash_attn_ext_f16_extra_data {
    uintptr_t K;   // start of the FP16 copy of K, or 0 if none was assigned
    uintptr_t V;   // start of the FP16 copy of V (may equal K when V is a view of K), or 0
    uintptr_t end; // first address past the output data and all assigned scratch regions
};
52+
53+
// Lays out the FP16 scratch regions for K and V behind the output of a FLASH_ATTN_EXT node.
//
// dst        - the GGML_OP_FLASH_ATTN_EXT node; K is read from src[1], V from src[2].
// need_f16_K - whether the selected kernel requires K as FP16.
// need_f16_V - whether the selected kernel requires V as FP16.
//
// Returns the addresses of the K/V scratch regions (0 when a tensor needs no conversion)
// and the end address of everything that was laid out. The layout starts at
// dst->data + ggml_nbytes(dst), so callers can derive the total size as end - dst->data.
static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data(
        const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) {
    GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT);

    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    GGML_ASSERT(K != nullptr);
    GGML_ASSERT(V != nullptr);

    // V counts as a view of K if it views K directly, or if both view the same source
    // tensor at the same offset.
    bool V_is_K_view = false;
    if (V->view_src) {
        const bool views_K_directly = V->view_src == K;
        const bool views_same_data  = V->view_src == K->view_src && V->view_offs == K->view_offs;
        V_is_K_view = views_K_directly || views_same_data;
    }

    const bool convert_K = need_f16_K && K->type != GGML_TYPE_F16;
    const bool convert_V = need_f16_V && V->type != GGML_TYPE_F16;

    ggml_cuda_flash_attn_ext_f16_extra_data layout = {};

    // Running end address; the scratch regions are appended behind the output data.
    uintptr_t cursor = (uintptr_t) dst->data + ggml_nbytes(dst);

    if (convert_K) {
        // Each region starts at an address padded with GGML_PAD(..., 128)
        // (presumably rounding up to a multiple of 128 bytes).
        cursor   = GGML_PAD(cursor, 128);
        layout.K = cursor;
        cursor  += ggml_nelements(K)*ggml_type_size(GGML_TYPE_F16);
    }

    if (convert_V && V_is_K_view) {
        // V aliases K: reuse the K region instead of reserving a second one.
        layout.V = layout.K;
    } else if (convert_V) {
        cursor   = GGML_PAD(cursor, 128);
        layout.V = cursor;
        cursor  += ggml_nelements(V)*ggml_type_size(GGML_TYPE_F16);
    }

    layout.end = cursor;
    return layout;
}
86+
4787
template <int D, int nthreads>
4888
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
4989
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
@@ -952,8 +992,9 @@ void launch_fattn(
952992
const int cc = ggml_cuda_info().devices[id].cc;
953993
const int nsm = ggml_cuda_info().devices[id].nsm;
954994

955-
ggml_cuda_pool_alloc<half> K_f16(pool);
956-
ggml_cuda_pool_alloc<half> V_f16(pool);
995+
const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra =
996+
ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V);
997+
957998
ggml_cuda_pool_alloc<int> KV_max(pool);
958999
ggml_cuda_pool_alloc<float> dst_tmp(pool);
9591000
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
@@ -972,10 +1013,11 @@ void launch_fattn(
9721013
const size_t bs = ggml_blck_size(K->type);
9731014
const size_t ts = ggml_type_size(K->type);
9741015

975-
K_f16.alloc(ggml_nelements(K));
1016+
GGML_ASSERT(f16_extra.K != 0);
1017+
half * K_f16 = (half *) f16_extra.K;
9761018
if (ggml_is_contiguously_allocated(K)) {
9771019
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
978-
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
1020+
to_fp16(K_data, K_f16, ggml_nelements(K), main_stream);
9791021

9801022
nb11 = nb11*bs*sizeof(half)/ts;
9811023
nb12 = nb12*bs*sizeof(half)/ts;
@@ -986,13 +1028,13 @@ void launch_fattn(
9861028
const int64_t s01 = nb11 / ts;
9871029
const int64_t s02 = nb12 / ts;
9881030
const int64_t s03 = nb13 / ts;
989-
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
1031+
to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
9901032

9911033
nb11 = K->ne[0] * sizeof(half);
9921034
nb12 = K->ne[1] * nb11;
9931035
nb13 = K->ne[2] * nb12;
9941036
}
995-
K_data = (char *) K_f16.ptr;
1037+
K_data = (char *) K_f16;
9961038
}
9971039

9981040
if (need_f16_V && V->type != GGML_TYPE_F16) {
@@ -1005,11 +1047,12 @@ void launch_fattn(
10051047
const size_t bs = ggml_blck_size(V->type);
10061048
const size_t ts = ggml_type_size(V->type);
10071049

1008-
V_f16.alloc(ggml_nelements(V));
1050+
GGML_ASSERT(f16_extra.V != 0);
1051+
half * V_f16 = (half *) f16_extra.V;
10091052
if (ggml_is_contiguously_allocated(V)) {
10101053
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
1011-
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
1012-
V_data = (char *) V_f16.ptr;
1054+
to_fp16(V_data, V_f16, ggml_nelements(V), main_stream);
1055+
V_data = (char *) V_f16;
10131056

10141057
nb21 = nb21*bs*sizeof(half)/ts;
10151058
nb22 = nb22*bs*sizeof(half)/ts;
@@ -1020,13 +1063,13 @@ void launch_fattn(
10201063
const int64_t s01 = nb21 / ts;
10211064
const int64_t s02 = nb22 / ts;
10221065
const int64_t s03 = nb23 / ts;
1023-
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
1066+
to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
10241067

10251068
nb21 = V->ne[0] * sizeof(half);
10261069
nb22 = V->ne[1] * nb21;
10271070
nb23 = V->ne[2] * nb22;
10281071
}
1029-
V_data = (char *) V_f16.ptr;
1072+
V_data = (char *) V_f16;
10301073
}
10311074
}
10321075

ggml/src/ggml-cuda/fattn.cu

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,41 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
537537
return BEST_FATTN_KERNEL_TILE;
538538
}
539539

540+
size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) {
541+
GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT);
542+
543+
const ggml_tensor * K = dst->src[1];
544+
const ggml_tensor * V = dst->src[2];
545+
546+
GGML_ASSERT(K != nullptr);
547+
GGML_ASSERT(V != nullptr);
548+
549+
const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst);
550+
551+
bool need_f16_K = false;
552+
bool need_f16_V = false;
553+
554+
switch (kernel) {
555+
case BEST_FATTN_KERNEL_TILE:
556+
case BEST_FATTN_KERNEL_WMMA_F16:
557+
case BEST_FATTN_KERNEL_MMA_F16:
558+
need_f16_K = true;
559+
need_f16_V = true;
560+
break;
561+
case BEST_FATTN_KERNEL_VEC:
562+
need_f16_K = K->type == GGML_TYPE_F32;
563+
need_f16_V = V->type == GGML_TYPE_F32;
564+
break;
565+
case BEST_FATTN_KERNEL_NONE:
566+
break;
567+
}
568+
569+
const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra =
570+
ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, need_f16_K, need_f16_V);
571+
572+
return f16_extra.end - (uintptr_t) dst->data;
573+
}
574+
540575
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
541576
ggml_cuda_set_device(ctx.device);
542577
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {

ggml/src/ggml-cuda/fattn.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
44

55
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);
6+
7+
size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty
801801
}
802802

803803
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
804-
size_t size = ggml_nbytes(tensor);
804+
ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context;
805+
806+
size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT
807+
? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor)
808+
: ggml_nbytes(tensor);
805809
int64_t ne0 = tensor->ne[0];
806810

807811
if (ggml_is_quantized(tensor->type)) {
@@ -812,8 +816,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
812816
}
813817

814818
return size;
815-
816-
GGML_UNUSED(buft);
817819
}
818820

819821
static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {

0 commit comments

Comments
 (0)