Skip to content

Commit 1073a3f

Browse files
committed
fix LLMPipeline Non deterministic CPU inference issue
1 parent feffa0c commit 1073a3f

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,15 +2941,12 @@ struct MHAHelper {
29412941
});
29422942
}
29432943

2944-
// attn_w * V
2945-
_output_bhl.resize<float>({static_cast<size_t>(_nthr), B, q_len, H, SV});
2946-
// m_attn_w {B, H, q_len, kv_len}
2947-
parallel_nt_static(_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
2948-
memset(_output_bhl.ptr<float>(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float));
2944+
_output_bhl.resize<float>({B, kv_len_in_blocks, loop_hk ? Hk : H, q_len, SV});
2945+
parallel_for3d(B, kv_len_in_blocks, (loop_hk ? Hk : H), [&](size_t b, size_t pv_in_blocks, size_t h) {
2946+
memset(_output_bhl.ptr<float>(b, pv_in_blocks, h, 0, 0), 0, q_len * SV * sizeof(float));
29492947
});
29502948

29512949
auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
2952-
auto ithr = parallel_get_thread_num();
29532950
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
29542951
auto pv = pv_in_blocks * _block_size;
29552952
size_t hk, hq_beg, hq_end;
@@ -2962,7 +2959,7 @@ struct MHAHelper {
29622959
for (size_t h = hq_beg; h < hq_end; h++) {
29632960
if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
29642961
attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
2965-
_output_bhl.ptr<float>(ithr, b, pq, h),
2962+
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
29662963
_weight_bhl.ptr<float>(b, h, pq) + pv,
29672964
value_cache.ptr<uint8_t, VALUE_PREC>(block_number, hk),
29682965
SV,
@@ -2973,7 +2970,7 @@ struct MHAHelper {
29732970
auto* v_ptr =
29742971
value_cache.ptr<typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
29752972
attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
2976-
_output_bhl.ptr<float>(ithr, b, pq, h),
2973+
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
29772974
_weight_bhl.ptr<float>(b, h, pq) + pv,
29782975
v_ptr,
29792976
SV,
@@ -2992,10 +2989,10 @@ struct MHAHelper {
29922989
}
29932990

29942991
parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) {
2995-
auto* temp = _output_bhl.ptr<float>(0, b, pq, h);
2996-
size_t temp_stride = _output_bhl.stride(0);
2992+
auto* temp = _output_bhl.ptr<float>(b, 0, h, pq);
2993+
size_t temp_stride = _output_bhl.stride(1); // split with pv_in_blocks steps
29972994
auto* dst = output_emb.ptr<DATA_TYPE>(b, pq, h * SV);
2998-
attn_reduce(dst, temp, _nthr, SV, temp_stride);
2995+
attn_reduce(dst, temp, kv_len_in_blocks, SV, temp_stride);
29992996
});
30002997
}
30012998
};

0 commit comments

Comments
 (0)