@@ -2941,15 +2941,12 @@ struct MHAHelper {
29412941 });
29422942 }
29432943
2944- // attn_w * V
2945- _output_bhl.resize <float >({static_cast <size_t >(_nthr), B, q_len, H, SV});
2946- // m_attn_w {B, H, q_len, kv_len}
2947- parallel_nt_static (_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
2948- memset (_output_bhl.ptr <float >(ithr, 0 , 0 , 0 , 0 ), 0 , _output_bhl.stride (0 ) * sizeof (float ));
2944+ _output_bhl.resize <float >({B, kv_len_in_blocks, H, q_len, SV}); // partial attn_w * V sums, one slab per KV block; head dim must be H (query heads): accumulation and the final reduce index it with h in [0, H) even when loop_hk is set
2945+ parallel_for3d (B, kv_len_in_blocks, H, [&](size_t b, size_t pv_in_blocks, size_t h) { // zero every (b, block, head) slab so blocks that receive no accumulation contribute 0 to the reduce
2946+ memset (_output_bhl.ptr <float >(b, pv_in_blocks, h, 0 , 0 ), 0 , q_len * SV * sizeof (float ));
29492947 });
29502948
29512949 auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2952- auto ithr = parallel_get_thread_num ();
29532950 auto context_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
29542951 auto pv = pv_in_blocks * _block_size;
29552952 size_t hk, hq_beg, hq_end;
@@ -2962,7 +2959,7 @@ struct MHAHelper {
29622959 for (size_t h = hq_beg; h < hq_end; h++) {
29632960 if constexpr (one_of (VALUE_PREC, ov::element::u8 , ov::element::u4)) {
29642961 attn_acc_value_block_quantized<uint8_t , VALUE_PREC>(
2965- _output_bhl.ptr <float >(ithr, b, pq , h),
2962+ _output_bhl.ptr <float >(b, pv_in_blocks , h, pq ),
29662963 _weight_bhl.ptr <float >(b, h, pq) + pv,
29672964 value_cache.ptr <uint8_t , VALUE_PREC>(block_number, hk),
29682965 SV,
@@ -2973,7 +2970,7 @@ struct MHAHelper {
29732970 auto * v_ptr =
29742971 value_cache.ptr <typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
29752972 attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2976- _output_bhl.ptr <float >(ithr, b, pq , h),
2973+ _output_bhl.ptr <float >(b, pv_in_blocks , h, pq ),
29772974 _weight_bhl.ptr <float >(b, h, pq) + pv,
29782975 v_ptr,
29792976 SV,
@@ -2992,10 +2989,10 @@ struct MHAHelper {
29922989 }
29932990
29942991 parallel_for3d (B, H, q_len, [&](size_t b, size_t h, size_t pq) {
2995- auto * temp = _output_bhl.ptr <float >(0 , b, pq , h);
2996- size_t temp_stride = _output_bhl.stride (0 );
2992+ auto * temp = _output_bhl.ptr <float >(b, 0 , h, pq );
2993+ size_t temp_stride = _output_bhl.stride (1 ); // split with pv_in_blocks steps
29972994 auto * dst = output_emb.ptr <DATA_TYPE>(b, pq, h * SV);
2998- attn_reduce (dst, temp, _nthr , SV, temp_stride);
2995+ attn_reduce (dst, temp, kv_len_in_blocks , SV, temp_stride);
29992996 });
30002997 }
30012998};
0 commit comments