@@ -2941,12 +2941,14 @@ struct MHAHelper {
29412941 });
29422942 }
29432943
2944- _output_bhl.resize <float >({B, kv_len_in_blocks, loop_hk ? Hk : H, q_len, SV});
2945- parallel_for3d (B, kv_len_in_blocks, (loop_hk ? Hk : H), [&](size_t b, size_t pv_in_blocks, size_t h) {
2946- memset (_output_bhl.ptr <float >(b, pv_in_blocks, h, 0 , 0 ), 0 , q_len * SV * sizeof (float ));
2944+ // attn_w * V
2945+ _output_bhl.resize <float >({static_cast <size_t >(_nthr), B, q_len, H, SV});
2946+ // m_attn_w {B, H, q_len, kv_len}
2947+ parallel_nt_static (_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
2948+ memset (_output_bhl.ptr <float >(ithr, 0 , 0 , 0 , 0 ), 0 , _output_bhl.stride (0 ) * sizeof (float ));
29472949 });
29482950
2949- auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2951+ auto loop_wk_static = [&](size_t ithr, size_t b, size_t pv_in_blocks, size_t hx) {
29502952 auto context_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
29512953 auto pv = pv_in_blocks * _block_size;
29522954 size_t hk, hq_beg, hq_end;
@@ -2959,7 +2961,7 @@ struct MHAHelper {
29592961 for (size_t h = hq_beg; h < hq_end; h++) {
29602962 if constexpr (one_of (VALUE_PREC, ov::element::u8 , ov::element::u4)) {
29612963 attn_acc_value_block_quantized<uint8_t , VALUE_PREC>(
2962- _output_bhl.ptr <float >(b, pv_in_blocks, h, pq ),
2964+ _output_bhl.ptr <float >(ithr, b, pq, h ),
29632965 _weight_bhl.ptr <float >(b, h, pq) + pv,
29642966 value_cache.ptr <uint8_t , VALUE_PREC>(block_number, hk),
29652967 SV,
@@ -2970,7 +2972,45 @@ struct MHAHelper {
29702972 auto * v_ptr =
29712973 value_cache.ptr <typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
29722974 attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2973- _output_bhl.ptr <float >(b, pv_in_blocks, h, pq),
2975+ _output_bhl.ptr <float >(ithr, b, pq, h),
2976+ _weight_bhl.ptr <float >(b, h, pq) + pv,
2977+ v_ptr,
2978+ SV,
2979+ std::min (_block_size, context_len - pv),
2980+ _value_group_size);
2981+ }
2982+ }
2983+ }
2984+ }
2985+ };
2986+
2987+ // TODO: align with loop_wk_static
2988+ auto loop_wk_dynamic = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2989+ auto ithr = parallel_get_thread_num ();
2990+ auto context_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
2991+ auto pv = pv_in_blocks * _block_size;
2992+ size_t hk, hq_beg, hq_end;
2993+ get_h_params (loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
2994+
2995+ // kv_len must be valid
2996+ if (pv < context_len) {
2997+ auto block_number = block_indices.ptr <int32_t >()[block_indices_begins.ptr <int32_t >()[b] + pv_in_blocks];
2998+ for (size_t pq = 0 ; pq < q_len; pq++) {
2999+ for (size_t h = hq_beg; h < hq_end; h++) {
3000+ if constexpr (one_of (VALUE_PREC, ov::element::u8 , ov::element::u4)) {
3001+ attn_acc_value_block_quantized<uint8_t , VALUE_PREC>(
3002+ _output_bhl.ptr <float >(ithr, b, pq, h),
3003+ _weight_bhl.ptr <float >(b, h, pq) + pv,
3004+ value_cache.ptr <uint8_t , VALUE_PREC>(block_number, hk),
3005+ SV,
3006+ _quant_value_bychannel,
3007+ std::min (_block_size, context_len - pv),
3008+ _value_group_size);
3009+ } else {
3010+ auto * v_ptr =
3011+ value_cache.ptr <typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
3012+ attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
3013+ _output_bhl.ptr <float >(ithr, b, pq, h),
29743014 _weight_bhl.ptr <float >(b, h, pq) + pv,
29753015 v_ptr,
29763016 SV,
@@ -2983,16 +3023,16 @@ struct MHAHelper {
29833023 };
29843024
29853025 if (prefer_static_loop) {
2986- parallel_for3d (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk );
3026+ parallel_for3d (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_static );
29873027 } else {
2988- parallel_for3d_dynamic (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk );
3028+ parallel_for3d_dynamic (B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_dynamic );
29893029 }
29903030
29913031 parallel_for3d (B, H, q_len, [&](size_t b, size_t h, size_t pq) {
2992- auto * temp = _output_bhl.ptr <float >(b, 0 , h , pq);
2993- size_t temp_stride = _output_bhl.stride (1 ); // split with pv_in_blocks steps
3032+ auto * temp = _output_bhl.ptr <float >(0 , b , pq, h );
3033+ size_t temp_stride = _output_bhl.stride (0 );
29943034 auto * dst = output_emb.ptr <DATA_TYPE>(b, pq, h * SV);
2995- attn_reduce (dst, temp, kv_len_in_blocks , SV, temp_stride);
3035+ attn_reduce (dst, temp, _nthr , SV, temp_stride);
29963036 });
29973037 }
29983038};
0 commit comments