@@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)(
// Function pointer type for a K*Q dot product routine: takes a raw K pointer plus three
// Q-side pointers (Q_v, Q_q8, Q_ds) and returns a float.
typedef float (*vec_dot_KQ_t)(
    const char * __restrict__ K_c,
    const void * __restrict__ Q_v,
    const int  * __restrict__ Q_q8,
    const void * __restrict__ Q_ds);
4646
// Addresses (stored as integers) of the extra FP16 regions that are laid out behind the
// output data of a flash attention node. A value of 0 means "no region".
struct ggml_cuda_flash_attn_ext_f16_extra_data {
    uintptr_t K;   // start of the region for the FP16 copy of K, 0 if none
    uintptr_t V;   // start of the region for the FP16 copy of V, 0 if none
    uintptr_t end; // first address past the output data and all regions described above
};
52+
53+ static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data (
54+ const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) {
55+ GGML_ASSERT (dst->op == GGML_OP_FLASH_ATTN_EXT );
56+
57+ const ggml_tensor * K = dst->src [1 ];
58+ const ggml_tensor * V = dst->src [2 ];
59+
60+ GGML_ASSERT (K != nullptr );
61+ GGML_ASSERT (V != nullptr );
62+
63+ const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs ));
64+
65+ ggml_cuda_flash_attn_ext_f16_extra_data data = {};
66+ data.end = (uintptr_t ) dst->data + ggml_nbytes (dst);
67+
68+ if (need_f16_K && K->type != GGML_TYPE_F16 ) {
69+ data.end = GGML_PAD (data.end , 128 );
70+ data.K = data.end ;
71+ data.end += ggml_nelements (K)*ggml_type_size (GGML_TYPE_F16 );
72+ }
73+
74+ if (need_f16_V && V->type != GGML_TYPE_F16 ) {
75+ if (V_is_K_view) {
76+ data.V = data.K ;
77+ } else {
78+ data.end = GGML_PAD (data.end , 128 );
79+ data.V = data.end ;
80+ data.end += ggml_nelements (V)*ggml_type_size (GGML_TYPE_F16 );
81+ }
82+ }
83+
84+ return data;
85+ }
86+
4787template <int D, int nthreads>
4888static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16 (
4989 const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
@@ -952,8 +992,9 @@ void launch_fattn(
952992 const int cc = ggml_cuda_info ().devices [id].cc ;
953993 const int nsm = ggml_cuda_info ().devices [id].nsm ;
954994
955- ggml_cuda_pool_alloc<half> K_f16 (pool);
956- ggml_cuda_pool_alloc<half> V_f16 (pool);
995+ const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra =
996+ ggml_cuda_flash_attn_ext_get_f16_extra_data (KQV , need_f16_K, need_f16_V);
997+
957998 ggml_cuda_pool_alloc<int > KV_max (pool);
958999 ggml_cuda_pool_alloc<float > dst_tmp (pool);
9591000 ggml_cuda_pool_alloc<float2 > dst_tmp_meta (pool);
@@ -972,10 +1013,11 @@ void launch_fattn(
9721013 const size_t bs = ggml_blck_size (K->type );
9731014 const size_t ts = ggml_type_size (K->type );
9741015
975- K_f16.alloc (ggml_nelements (K));
1016+ GGML_ASSERT (f16_extra.K != 0 );
1017+ half * K_f16 = (half *) f16_extra.K ;
9761018 if (ggml_is_contiguously_allocated (K)) {
9771019 to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda (K->type );
978- to_fp16 (K_data, K_f16. ptr , ggml_nelements (K), main_stream);
1020+ to_fp16 (K_data, K_f16, ggml_nelements (K), main_stream);
9791021
9801022 nb11 = nb11*bs*sizeof (half)/ts;
9811023 nb12 = nb12*bs*sizeof (half)/ts;
@@ -986,13 +1028,13 @@ void launch_fattn(
9861028 const int64_t s01 = nb11 / ts;
9871029 const int64_t s02 = nb12 / ts;
9881030 const int64_t s03 = nb13 / ts;
989- to_fp16 (K_data, K_f16. ptr , K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ], s01, s02, s03, main_stream);
1031+ to_fp16 (K_data, K_f16, K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ], s01, s02, s03, main_stream);
9901032
9911033 nb11 = K->ne [0 ] * sizeof (half);
9921034 nb12 = K->ne [1 ] * nb11;
9931035 nb13 = K->ne [2 ] * nb12;
9941036 }
995- K_data = (char *) K_f16. ptr ;
1037+ K_data = (char *) K_f16;
9961038 }
9971039
9981040 if (need_f16_V && V->type != GGML_TYPE_F16 ) {
@@ -1005,11 +1047,12 @@ void launch_fattn(
10051047 const size_t bs = ggml_blck_size (V->type );
10061048 const size_t ts = ggml_type_size (V->type );
10071049
1008- V_f16.alloc (ggml_nelements (V));
1050+ GGML_ASSERT (f16_extra.V != 0 );
1051+ half * V_f16 = (half *) f16_extra.V ;
10091052 if (ggml_is_contiguously_allocated (V)) {
10101053 to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda (V->type );
1011- to_fp16 (V_data, V_f16. ptr , ggml_nelements (V), main_stream);
1012- V_data = (char *) V_f16. ptr ;
1054+ to_fp16 (V_data, V_f16, ggml_nelements (V), main_stream);
1055+ V_data = (char *) V_f16;
10131056
10141057 nb21 = nb21*bs*sizeof (half)/ts;
10151058 nb22 = nb22*bs*sizeof (half)/ts;
@@ -1020,13 +1063,13 @@ void launch_fattn(
10201063 const int64_t s01 = nb21 / ts;
10211064 const int64_t s02 = nb22 / ts;
10221065 const int64_t s03 = nb23 / ts;
1023- to_fp16 (V_data, V_f16. ptr , V->ne [0 ], V->ne [1 ], V->ne [2 ], V->ne [3 ], s01, s02, s03, main_stream);
1066+ to_fp16 (V_data, V_f16, V->ne [0 ], V->ne [1 ], V->ne [2 ], V->ne [3 ], s01, s02, s03, main_stream);
10241067
10251068 nb21 = V->ne [0 ] * sizeof (half);
10261069 nb22 = V->ne [1 ] * nb21;
10271070 nb23 = V->ne [2 ] * nb22;
10281071 }
1029- V_data = (char *) V_f16. ptr ;
1072+ V_data = (char *) V_f16;
10301073 }
10311074 }
10321075
0 commit comments