Skip to content

Commit c8e79ac

Browse files
committed
To avoid core dump cases, currently only fix Non_deterministic of loop_wk_static cases
1 parent 1073a3f commit c8e79ac

File tree

1 file changed

+51
-11
lines changed

1 file changed

+51
-11
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,12 +2941,14 @@ struct MHAHelper {
29412941
});
29422942
}
29432943

2944-
_output_bhl.resize<float>({B, kv_len_in_blocks, loop_hk ? Hk : H, q_len, SV});
2945-
parallel_for3d(B, kv_len_in_blocks, (loop_hk ? Hk : H), [&](size_t b, size_t pv_in_blocks, size_t h) {
2946-
memset(_output_bhl.ptr<float>(b, pv_in_blocks, h, 0, 0), 0, q_len * SV * sizeof(float));
2944+
// attn_w * V
2945+
_output_bhl.resize<float>({static_cast<size_t>(_nthr), B, q_len, H, SV});
2946+
// m_attn_w {B, H, q_len, kv_len}
2947+
parallel_nt_static(_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
2948+
memset(_output_bhl.ptr<float>(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float));
29472949
});
29482950

2949-
auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2951+
auto loop_wk_static = [&](size_t ithr, size_t b, size_t pv_in_blocks, size_t hx) {
29502952
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
29512953
auto pv = pv_in_blocks * _block_size;
29522954
size_t hk, hq_beg, hq_end;
@@ -2959,7 +2961,7 @@ struct MHAHelper {
29592961
for (size_t h = hq_beg; h < hq_end; h++) {
29602962
if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
29612963
attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
2962-
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
2964+
_output_bhl.ptr<float>(ithr, b, pq, h),
29632965
_weight_bhl.ptr<float>(b, h, pq) + pv,
29642966
value_cache.ptr<uint8_t, VALUE_PREC>(block_number, hk),
29652967
SV,
@@ -2970,7 +2972,45 @@ struct MHAHelper {
29702972
auto* v_ptr =
29712973
value_cache.ptr<typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
29722974
attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2973-
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
2975+
_output_bhl.ptr<float>(ithr, b, pq, h),
2976+
_weight_bhl.ptr<float>(b, h, pq) + pv,
2977+
v_ptr,
2978+
SV,
2979+
std::min(_block_size, context_len - pv),
2980+
_value_group_size);
2981+
}
2982+
}
2983+
}
2984+
}
2985+
};
2986+
2987+
// TODO: align with loop_wk_static
2988+
auto loop_wk_dynamic = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2989+
auto ithr = parallel_get_thread_num();
2990+
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
2991+
auto pv = pv_in_blocks * _block_size;
2992+
size_t hk, hq_beg, hq_end;
2993+
get_h_params(loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
2994+
2995+
// kv_len must be valid
2996+
if (pv < context_len) {
2997+
auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
2998+
for (size_t pq = 0; pq < q_len; pq++) {
2999+
for (size_t h = hq_beg; h < hq_end; h++) {
3000+
if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
3001+
attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
3002+
_output_bhl.ptr<float>(ithr, b, pq, h),
3003+
_weight_bhl.ptr<float>(b, h, pq) + pv,
3004+
value_cache.ptr<uint8_t, VALUE_PREC>(block_number, hk),
3005+
SV,
3006+
_quant_value_bychannel,
3007+
std::min(_block_size, context_len - pv),
3008+
_value_group_size);
3009+
} else {
3010+
auto* v_ptr =
3011+
value_cache.ptr<typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
3012+
attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
3013+
_output_bhl.ptr<float>(ithr, b, pq, h),
29743014
_weight_bhl.ptr<float>(b, h, pq) + pv,
29753015
v_ptr,
29763016
SV,
@@ -2983,16 +3023,16 @@ struct MHAHelper {
29833023
};
29843024

29853025
if (prefer_static_loop) {
2986-
parallel_for3d(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk);
3026+
parallel_for3d(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_static);
29873027
} else {
2988-
parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk);
3028+
parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_dynamic);
29893029
}
29903030

29913031
parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) {
2992-
auto* temp = _output_bhl.ptr<float>(b, 0, h, pq);
2993-
size_t temp_stride = _output_bhl.stride(1); // split with pv_in_blocks steps
3032+
auto* temp = _output_bhl.ptr<float>(0, b, pq, h);
3033+
size_t temp_stride = _output_bhl.stride(0);
29943034
auto* dst = output_emb.ptr<DATA_TYPE>(b, pq, h * SV);
2995-
attn_reduce(dst, temp, kv_len_in_blocks, SV, temp_stride);
3035+
attn_reduce(dst, temp, _nthr, SV, temp_stride);
29963036
});
29973037
}
29983038
};

0 commit comments

Comments
 (0)