Skip to content

Commit 63f73d3

Browse files
committed
To avoid core dump cases, currently only fix Non_deterministic of loop_wk_static cases
1 parent 1073a3f commit 63f73d3

File tree

1 file changed

+54
-11
lines changed

1 file changed

+54
-11
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,12 +2941,55 @@ struct MHAHelper {
29412941
});
29422942
}
29432943

2944-
_output_bhl.resize<float>({B, kv_len_in_blocks, loop_hk ? Hk : H, q_len, SV});
2945-
parallel_for3d(B, kv_len_in_blocks, (loop_hk ? Hk : H), [&](size_t b, size_t pv_in_blocks, size_t h) {
2946-
memset(_output_bhl.ptr<float>(b, pv_in_blocks, h, 0, 0), 0, q_len * SV * sizeof(float));
2944+
// attn_w * V
2945+
_output_bhl.resize<float>({static_cast<size_t>(_nthr), B, q_len, H, SV});
2946+
// m_attn_w {B, H, q_len, kv_len}
2947+
parallel_nt_static(_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
2948+
memset(_output_bhl.ptr<float>(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float));
29472949
});
29482950

2949-
auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2951+
auto loop_wk_static =
2952+
[&](size_t ithr, size_t b, size_t pv_in_blocks, size_t hx) {
2953+
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
2954+
auto pv = pv_in_blocks * _block_size;
2955+
size_t hk, hq_beg, hq_end;
2956+
get_h_params(loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
2957+
2958+
// kv_len must be valid
2959+
if (pv < context_len) {
2960+
auto block_number =
2961+
block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
2962+
for (size_t pq = 0; pq < q_len; pq++) {
2963+
for (size_t h = hq_beg; h < hq_end; h++) {
2964+
if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
2965+
attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
2966+
_output_bhl.ptr<float>(ithr, b, pq, h),
2967+
_weight_bhl.ptr<float>(b, h, pq) + pv,
2968+
value_cache.ptr<uint8_t, VALUE_PREC>(block_number, hk),
2969+
SV,
2970+
_quant_value_bychannel,
2971+
std::min(_block_size, context_len - pv),
2972+
_value_group_size);
2973+
} else {
2974+
auto* v_ptr =
2975+
value_cache.ptr<typename element_type_traits<VALUE_PREC>::value_type>(block_number,
2976+
hk);
2977+
attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2978+
_output_bhl.ptr<float>(ithr, b, pq, h),
2979+
_weight_bhl.ptr<float>(b, h, pq) + pv,
2980+
v_ptr,
2981+
SV,
2982+
std::min(_block_size, context_len - pv),
2983+
_value_group_size);
2984+
}
2985+
}
2986+
}
2987+
}
2988+
};
2989+
2990+
// TODO: align with loop_wk_static
2991+
auto loop_wk_dynamic = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2992+
auto ithr = parallel_get_thread_num();
29502993
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
29512994
auto pv = pv_in_blocks * _block_size;
29522995
size_t hk, hq_beg, hq_end;
@@ -2959,7 +3002,7 @@ struct MHAHelper {
29593002
for (size_t h = hq_beg; h < hq_end; h++) {
29603003
if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
29613004
attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
2962-
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
3005+
_output_bhl.ptr<float>(ithr, b, pq, h),
29633006
_weight_bhl.ptr<float>(b, h, pq) + pv,
29643007
value_cache.ptr<uint8_t, VALUE_PREC>(block_number, hk),
29653008
SV,
@@ -2970,7 +3013,7 @@ struct MHAHelper {
29703013
auto* v_ptr =
29713014
value_cache.ptr<typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
29723015
attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2973-
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
3016+
_output_bhl.ptr<float>(ithr, b, pq, h),
29743017
_weight_bhl.ptr<float>(b, h, pq) + pv,
29753018
v_ptr,
29763019
SV,
@@ -2983,16 +3026,16 @@ struct MHAHelper {
29833026
};
29843027

29853028
if (prefer_static_loop) {
2986-
parallel_for3d(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk);
3029+
parallel_for3d(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_static);
29873030
} else {
2988-
parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk);
3031+
parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_dynamic);
29893032
}
29903033

29913034
parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) {
2992-
auto* temp = _output_bhl.ptr<float>(b, 0, h, pq);
2993-
size_t temp_stride = _output_bhl.stride(1); // split with pv_in_blocks steps
3035+
auto* temp = _output_bhl.ptr<float>(0, b, pq, h);
3036+
size_t temp_stride = _output_bhl.stride(0);
29943037
auto* dst = output_emb.ptr<DATA_TYPE>(b, pq, h * SV);
2995-
attn_reduce(dst, temp, kv_len_in_blocks, SV, temp_stride);
3038+
attn_reduce(dst, temp, _nthr, SV, temp_stride);
29963039
});
29973040
}
29983041
};

0 commit comments

Comments
 (0)