44
55#include "include/batch_headers/common.cl"
66
7- inline void FUNC (quantize_and_save_k )(__global const INPUT0_TYPE * in_data ,
7+ inline void FUNC (quantize_and_save )(__global const INPUT0_TYPE * in_data ,
88 const uint in_data_offset ,
99 __global OUTPUT_TYPE * out_data ,
1010 const uint out_data_offset ,
1111 const uint out_data_pitch ,
1212 const uint comp_offset ,
1313 const uint token_pos_in_block ,
14- const uint sglid ) {
15- INPUT0_TYPE input_data [K_HEAD_SIZE / SUBGROUP_SIZE ];
14+ const uint sglid ,
15+ const uint num_groups ,
16+ INPUT0_TYPE * input_data ) {
1617 INPUT0_TYPE grp_max = 0.001 ;
1718 INPUT0_TYPE max_value = INPUT0_VAL_MIN ;
1819 INPUT0_TYPE min_value = INPUT0_VAL_MAX ;
1920
20- unroll_for (uint i = 0 ; i < K_HEAD_SIZE / SUBGROUP_SIZE ; i ++ ) {
21+ unroll_for (uint i = 0 ; i < num_groups ; i ++ ) {
2122 input_data [i ] = BLOCK_READN (INPUT0_TYPE , 1 , in_data , in_data_offset + i * SUBGROUP_SIZE );
2223 max_value = fmax (max_value , input_data [i ]);
2324 min_value = fmin (min_value , input_data [i ]);
@@ -35,54 +36,7 @@ inline void FUNC(quantize_and_save_k)(__global const INPUT0_TYPE* in_data,
3536 INPUT0_TYPE zp = (INPUT1_TYPE )(zp_tmp );
3637 #undef ACCUMULATOR_TYPE
3738
38- unroll_for (uint i = 0 ; i < K_HEAD_SIZE / SUBGROUP_SIZE ; i ++ ) {
39- OUTPUT_TYPE res = convert_char_rte (input_data [i ] * scale + zp );
40-
41- uint offset = out_data_offset + (i * SUBGROUP_SIZE + sglid ) * out_data_pitch ;
42- out_data [offset ] = res ;
43- }
44-
45- INPUT0_TYPE * comp_ptr = out_data + comp_offset ;
46-
47- if (sglid == 0 ) {
48- comp_ptr [token_pos_in_block ] = 1.0 / scale ;
49- comp_ptr [PAGED_ATTENTION_BLOCK_SIZE + token_pos_in_block ] = zp ;
50- }
51- }
52-
53- inline void FUNC (quantize_and_save_v )(__global const INPUT0_TYPE * in_data ,
54- const uint in_data_offset ,
55- __global OUTPUT_TYPE * out_data ,
56- const uint out_data_offset ,
57- const uint out_data_pitch ,
58- const uint comp_offset ,
59- const uint token_pos_in_block ,
60- const uint sglid ) {
61- INPUT0_TYPE input_data [V_HEAD_SIZE / SUBGROUP_SIZE ];
62- INPUT0_TYPE grp_max = 0.001 ;
63- INPUT0_TYPE max_value = INPUT0_VAL_MIN ;
64- INPUT0_TYPE min_value = INPUT0_VAL_MAX ;
65-
66- unroll_for (uint i = 0 ; i < V_HEAD_SIZE / SUBGROUP_SIZE ; i ++ ) {
67- input_data [i ] = BLOCK_READN (INPUT0_TYPE , 1 , in_data , in_data_offset + i * SUBGROUP_SIZE );
68- max_value = fmax (max_value , input_data [i ]);
69- min_value = fmin (min_value , input_data [i ]);
70- }
71-
72- min_value = sub_group_reduce_min (min_value );
73- max_value = sub_group_reduce_max (max_value );
74-
75- // If the range of input data is zero, it is adjusted to the minimum value(0.001).
76- #define ACCUMULATOR_TYPE float
77- ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max ) : (max_value - min_value );
78- ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE )((CHAR_MAX - CHAR_MIN ) / diff_value );
79- ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE )(- min_value * scale_tmp ) + CHAR_MIN ;
80- INPUT0_TYPE scale = (INPUT1_TYPE )(scale_tmp );
81- INPUT0_TYPE zp = (INPUT1_TYPE )(zp_tmp );
82- #undef ACCUMULATOR_TYPE
83-
84-
85- unroll_for (uint i = 0 ; i < V_HEAD_SIZE / SUBGROUP_SIZE ; i ++ ) {
39+ unroll_for (uint i = 0 ; i < num_groups ; i ++ ) {
8640 OUTPUT_TYPE res = convert_char_rte (input_data [i ] * scale + zp );
8741
8842 uint offset = out_data_offset + (i * SUBGROUP_SIZE + sglid ) * out_data_pitch ;
@@ -178,11 +132,19 @@ KERNEL(pa_kv_cache_update)(
178132 }
179133
180134#else // IS_KV_COMPRESSED
181- // key processing
182- FUNC_CALL (quantize_and_save_k )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE , comp_k_offset , current_token_pos_in_block , sglid );
135+ {
136+ // key processing
137+ INPUT0_TYPE input_data [K_HEAD_SIZE / SUBGROUP_SIZE ];
138+ FUNC_CALL (quantize_and_save )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE , comp_k_offset ,
139+ current_token_pos_in_block , sglid , K_HEAD_SIZE / SUBGROUP_SIZE , & input_data [0 ]);
140+ }
183141
184- // value processing
185- FUNC_CALL (quantize_and_save_v )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 , comp_v_offset , current_token_pos_in_block , sglid );
142+ {
143+ // value processing
144+ INPUT0_TYPE input_data [V_HEAD_SIZE / SUBGROUP_SIZE ];
145+ FUNC_CALL (quantize_and_save )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 , comp_v_offset ,
146+ current_token_pos_in_block , sglid , V_HEAD_SIZE / SUBGROUP_SIZE , & input_data [0 ]);
147+ }
186148#endif // IS_KV_COMPRESSED
187149 } else {
188150 // 1st token
@@ -343,11 +305,19 @@ KERNEL(pa_kv_cache_update)(
343305 }
344306
345307#else // IS_KV_COMPRESSED
308+ {
346309 // key processing
347- FUNC_CALL (quantize_and_save_k )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE , comp_k_offset , token_num , sglid );
310+ INPUT0_TYPE input_data [K_HEAD_SIZE / SUBGROUP_SIZE ];
311+ FUNC_CALL (quantize_and_save )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE ,
312+ comp_k_offset , token_num , sglid , K_HEAD_SIZE / SUBGROUP_SIZE , & input_data [0 ]);
313+ }
348314
315+ {
349316 // value processing
350- FUNC_CALL (quantize_and_save_v )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 , comp_v_offset , token_num , sglid );
317+ INPUT0_TYPE input_data [V_HEAD_SIZE / SUBGROUP_SIZE ];
318+ FUNC_CALL (quantize_and_save )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 ,
319+ comp_v_offset , token_num , sglid , V_HEAD_SIZE / SUBGROUP_SIZE , & input_data [0 ]);
320+ }
351321#endif // IS_KV_COMPRESSED
352322
353323 key_in_offset += (KV_HEADS_NUM * K_HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM );
@@ -379,14 +349,22 @@ KERNEL(pa_kv_cache_update)(
379349 uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i ;
380350 value_cache_data [value_offset ] = input_data ;
381351 }
382- }
352+ }
383353
384354#else // IS_KV_COMPRESSED
385- // key processing
386- FUNC_CALL (quantize_and_save_k )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE , comp_k_offset , token_start_pos + token_num , sglid );
355+ {
356+ // key processing
357+ INPUT0_TYPE input_data [K_HEAD_SIZE / SUBGROUP_SIZE ];
358+ FUNC_CALL (quantize_and_save )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE ,
359+ comp_k_offset , token_start_pos + token_num , sglid , K_HEAD_SIZE / SUBGROUP_SIZE , & input_data [0 ]);
360+ }
387361
388- // value processing
389- FUNC_CALL (quantize_and_save_v )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 , comp_v_offset , token_start_pos + token_num , sglid );
362+ {
363+ // value processing
364+ INPUT0_TYPE input_data [V_HEAD_SIZE / SUBGROUP_SIZE ];
365+ FUNC_CALL (quantize_and_save )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 ,
366+ comp_v_offset , token_start_pos + token_num , sglid , V_HEAD_SIZE / SUBGROUP_SIZE , & input_data [0 ]);
367+ }
390368#endif // IS_KV_COMPRESSED
391369 key_in_offset += (KV_HEADS_NUM * K_HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM );
392370 value_in_offset += (KV_HEADS_NUM * V_HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM );
0 commit comments