@@ -3034,13 +3034,12 @@ struct MHAHelper {
30343034 }
30353035
30363036 // attn_w * V
3037- _output_bhl.resize <float >({static_cast <size_t >(_nthr), B, q_len, H, SV});
3038- // m_attn_w {B, H, q_len, kv_len}
3039- parallel_nt_static (_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
3040- memset (_output_bhl.ptr <float >(ithr, 0 , 0 , 0 , 0 ), 0 , _output_bhl.stride (0 ) * sizeof (float ));
3037+ _output_bhl.resize <float >({B, kv_len_in_blocks, H, q_len, SV});
3038+ parallel_for3d (B, kv_len_in_blocks, H, [&](size_t b, size_t pv_in_blocks, size_t h) {
3039+ memset (_output_bhl.ptr <float >(b, pv_in_blocks, h, 0 , 0 ), 0 , q_len * SV * sizeof (float ));
30413040 });
30423041
3043- auto loop_wk_static = [&](size_t ithr, size_t b, size_t pv_in_blocks, size_t hx) {
3042+ auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
30443043 auto context_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
30453044 auto pv = pv_in_blocks * _block_size;
30463045 size_t hk;
@@ -3055,7 +3054,7 @@ struct MHAHelper {
30553054 for (size_t h = hq_beg; h < hq_end; h++) {
30563055 if constexpr (one_of (VALUE_PREC, ov::element::u8 , ov::element::u4)) {
30573056 attn_acc_value_block_quantized<uint8_t , VALUE_PREC>(
3058- _output_bhl.ptr <float >(ithr, b, pq , h),
3057+ _output_bhl.ptr <float >(b, pv_in_blocks , h, pq ),
30593058 _weight_bhl.ptr <float >(b, h, pq) + pv,
30603059 value_cache.ptr <uint8_t , VALUE_PREC>(block_number, hk),
30613060 SV,
@@ -3066,45 +3065,7 @@ struct MHAHelper {
30663065 auto * v_ptr =
30673066 value_cache.ptr <typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
30683067 attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
3069- _output_bhl.ptr <float >(ithr, b, pq, h),
3070- _weight_bhl.ptr <float >(b, h, pq) + pv,
3071- v_ptr,
3072- SV,
3073- std::min (_block_size, context_len - pv),
3074- _value_group_size);
3075- }
3076- }
3077- }
3078- }
3079- };
3080-
3081- // TODO: align with loop_wk_static
3082- auto loop_wk_dynamic = [&](size_t b, size_t pv_in_blocks, size_t hx) {
3083- auto ithr = parallel_get_thread_num ();
3084- auto context_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
3085- auto pv = pv_in_blocks * _block_size;
3086- size_t hk, hq_beg, hq_end;
3087- get_h_params (loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
3088-
3089- // kv_len must be valid
3090- if (pv < context_len) {
3091- auto block_number = block_indices.ptr <int32_t >()[block_indices_begins.ptr <int32_t >()[b] + pv_in_blocks];
3092- for (size_t pq = 0 ; pq < q_len; pq++) {
3093- for (size_t h = hq_beg; h < hq_end; h++) {
3094- if constexpr (one_of (VALUE_PREC, ov::element::u8 , ov::element::u4)) {
3095- attn_acc_value_block_quantized<uint8_t , VALUE_PREC>(
3096- _output_bhl.ptr <float >(ithr, b, pq, h),
3097- _weight_bhl.ptr <float >(b, h, pq) + pv,
3098- value_cache.ptr <uint8_t , VALUE_PREC>(block_number, hk),
3099- SV,
3100- _quant_value_bychannel,
3101- std::min (_block_size, context_len - pv),
3102- _value_group_size);
3103- } else {
3104- auto * v_ptr =
3105- value_cache.ptr <typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
3106- attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
3107- _output_bhl.ptr <float >(ithr, b, pq, h),
3068+ _output_bhl.ptr <float >(b, pv_in_blocks, h, pq),
31083069 _weight_bhl.ptr <float >(b, h, pq) + pv,
31093070 v_ptr,
31103071 SV,
@@ -3117,16 +3078,16 @@ struct MHAHelper {
31173078 };
31183079
31193080 if (prefer_static_loop) {
3120- parallel_for3d (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_static );
3081+ parallel_for3d (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk );
31213082 } else {
3122- parallel_for3d_dynamic (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_dynamic );
3083+ parallel_for3d_dynamic (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk );
31233084 }
31243085
31253086 parallel_for3d (B, H, q_len, [&](size_t b, size_t h, size_t pq) {
3126- auto * temp = _output_bhl.ptr <float >(0 , b, pq , h);
3127- size_t temp_stride = _output_bhl.stride (0 );
3087+ auto * temp = _output_bhl.ptr <float >(b, 0 , h, pq );
3088+ size_t temp_stride = _output_bhl.stride (1 ); // split with pv_in_blocks steps
31283089 auto * dst = output_emb.ptr <DATA_TYPE>(b, pq, h * SV);
3129- attn_reduce (dst, temp, _nthr , SV, temp_stride);
3090+ attn_reduce (dst, temp, kv_len_in_blocks , SV, temp_stride);
31303091 });
31313092 }
31323093};
0 commit comments