@@ -2941,12 +2941,55 @@ struct MHAHelper {
29412941 });
29422942 }
29432943
2944- _output_bhl.resize <float >({B, kv_len_in_blocks, loop_hk ? Hk : H, q_len, SV});
2945- parallel_for3d (B, kv_len_in_blocks, (loop_hk ? Hk : H), [&](size_t b, size_t pv_in_blocks, size_t h) {
2946- memset (_output_bhl.ptr <float >(b, pv_in_blocks, h, 0 , 0 ), 0 , q_len * SV * sizeof (float ));
2944+ // attn_w * V
2945+ _output_bhl.resize <float >({static_cast <size_t >(_nthr), B, q_len, H, SV});
2946+ // m_attn_w {B, H, q_len, kv_len}
2947+ parallel_nt_static (_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
2948+ memset (_output_bhl.ptr <float >(ithr, 0 , 0 , 0 , 0 ), 0 , _output_bhl.stride (0 ) * sizeof (float ));
29472949 });
29482950
2949- auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2951+ auto loop_wk_static =
2952+ [&](size_t ithr, size_t b, size_t pv_in_blocks, size_t hx) {
2953+ auto context_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
2954+ auto pv = pv_in_blocks * _block_size;
2955+ size_t hk, hq_beg, hq_end;
2956+ get_h_params (loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
2957+
2958+ // kv_len must be valid
2959+ if (pv < context_len) {
2960+ auto block_number =
2961+ block_indices.ptr <int32_t >()[block_indices_begins.ptr <int32_t >()[b] + pv_in_blocks];
2962+ for (size_t pq = 0 ; pq < q_len; pq++) {
2963+ for (size_t h = hq_beg; h < hq_end; h++) {
2964+ if constexpr (one_of (VALUE_PREC, ov::element::u8 , ov::element::u4)) {
2965+ attn_acc_value_block_quantized<uint8_t , VALUE_PREC>(
2966+ _output_bhl.ptr <float >(ithr, b, pq, h),
2967+ _weight_bhl.ptr <float >(b, h, pq) + pv,
2968+ value_cache.ptr <uint8_t , VALUE_PREC>(block_number, hk),
2969+ SV,
2970+ _quant_value_bychannel,
2971+ std::min (_block_size, context_len - pv),
2972+ _value_group_size);
2973+ } else {
2974+ auto * v_ptr =
2975+ value_cache.ptr <typename element_type_traits<VALUE_PREC>::value_type>(block_number,
2976+ hk);
2977+ attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2978+ _output_bhl.ptr <float >(ithr, b, pq, h),
2979+ _weight_bhl.ptr <float >(b, h, pq) + pv,
2980+ v_ptr,
2981+ SV,
2982+ std::min (_block_size, context_len - pv),
2983+ _value_group_size);
2984+ }
2985+ }
2986+ }
2987+ }
2988+ };
2989+
2990+ // TODO: align with loop_wk_static
2991+ auto loop_wk_dynamic = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2992+ auto ithr = parallel_get_thread_num ();
29502993 auto context_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
29512994 auto pv = pv_in_blocks * _block_size;
29522995 size_t hk, hq_beg, hq_end;
@@ -2959,7 +3002,7 @@ struct MHAHelper {
29593002 for (size_t h = hq_beg; h < hq_end; h++) {
29603003 if constexpr (one_of (VALUE_PREC, ov::element::u8 , ov::element::u4)) {
29613004 attn_acc_value_block_quantized<uint8_t , VALUE_PREC>(
2962- _output_bhl.ptr <float >(b, pv_in_blocks, h, pq ),
3005+ _output_bhl.ptr <float >(ithr, b, pq, h ),
29633006 _weight_bhl.ptr <float >(b, h, pq) + pv,
29643007 value_cache.ptr <uint8_t , VALUE_PREC>(block_number, hk),
29653008 SV,
@@ -2970,7 +3013,7 @@ struct MHAHelper {
29703013 auto * v_ptr =
29713014 value_cache.ptr <typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
29723015 attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2973- _output_bhl.ptr <float >(b, pv_in_blocks, h, pq ),
3016+ _output_bhl.ptr <float >(ithr, b, pq, h ),
29743017 _weight_bhl.ptr <float >(b, h, pq) + pv,
29753018 v_ptr,
29763019 SV,
@@ -2983,16 +3026,16 @@ struct MHAHelper {
29833026 };
29843027
29853028 if (prefer_static_loop) {
2986- parallel_for3d (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk );
3029+ parallel_for3d (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_static );
29873030 } else {
2988- parallel_for3d_dynamic (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk );
3031+ parallel_for3d_dynamic (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_dynamic );
29893032 }
29903033
29913034 parallel_for3d (B, H, q_len, [&](size_t b, size_t h, size_t pq) {
2992- auto * temp = _output_bhl.ptr <float >(b, 0 , h , pq);
2993- size_t temp_stride = _output_bhl.stride (1 ); // split with pv_in_blocks steps
3035+ auto * temp = _output_bhl.ptr <float >(0 , b , pq, h );
3036+ size_t temp_stride = _output_bhl.stride (0 );
29943037 auto * dst = output_emb.ptr <DATA_TYPE>(b, pq, h * SV);
2995- attn_reduce (dst, temp, kv_len_in_blocks , SV, temp_stride);
3038+ attn_reduce (dst, temp, _nthr , SV, temp_stride);
29963039 });
29973040 }
29983041};
0 commit comments