Skip to content

Commit 7acba52

Browse files
committed
Change both the static and dynamic parallel split methods to cover all of these kinds of non-deterministic cases
1 parent f4e7e26 commit 7acba52

File tree

1 file changed

+11
-50
lines changed

1 file changed

+11
-50
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 11 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -3034,13 +3034,12 @@ struct MHAHelper {
30343034
}
30353035

30363036
// attn_w * V
3037-
_output_bhl.resize<float>({static_cast<size_t>(_nthr), B, q_len, H, SV});
3038-
// m_attn_w {B, H, q_len, kv_len}
3039-
parallel_nt_static(_nthr, [&](const size_t ithr, [[maybe_unused]] const size_t nthr) {
3040-
memset(_output_bhl.ptr<float>(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float));
3037+
_output_bhl.resize<float>({B, kv_len_in_blocks, H, q_len, SV});
3038+
parallel_for3d(B, kv_len_in_blocks, H, [&](size_t b, size_t pv_in_blocks, size_t h) {
3039+
memset(_output_bhl.ptr<float>(b, pv_in_blocks, h, 0, 0), 0, q_len * SV * sizeof(float));
30413040
});
30423041

3043-
auto loop_wk_static = [&](size_t ithr, size_t b, size_t pv_in_blocks, size_t hx) {
3042+
auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
30443043
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
30453044
auto pv = pv_in_blocks * _block_size;
30463045
size_t hk;
@@ -3055,7 +3054,7 @@ struct MHAHelper {
30553054
for (size_t h = hq_beg; h < hq_end; h++) {
30563055
if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
30573056
attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
3058-
_output_bhl.ptr<float>(ithr, b, pq, h),
3057+
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
30593058
_weight_bhl.ptr<float>(b, h, pq) + pv,
30603059
value_cache.ptr<uint8_t, VALUE_PREC>(block_number, hk),
30613060
SV,
@@ -3066,45 +3065,7 @@ struct MHAHelper {
30663065
auto* v_ptr =
30673066
value_cache.ptr<typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
30683067
attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
3069-
_output_bhl.ptr<float>(ithr, b, pq, h),
3070-
_weight_bhl.ptr<float>(b, h, pq) + pv,
3071-
v_ptr,
3072-
SV,
3073-
std::min(_block_size, context_len - pv),
3074-
_value_group_size);
3075-
}
3076-
}
3077-
}
3078-
}
3079-
};
3080-
3081-
// TODO: align with loop_wk_static
3082-
auto loop_wk_dynamic = [&](size_t b, size_t pv_in_blocks, size_t hx) {
3083-
auto ithr = parallel_get_thread_num();
3084-
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
3085-
auto pv = pv_in_blocks * _block_size;
3086-
size_t hk, hq_beg, hq_end;
3087-
get_h_params(loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
3088-
3089-
// kv_len must be valid
3090-
if (pv < context_len) {
3091-
auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
3092-
for (size_t pq = 0; pq < q_len; pq++) {
3093-
for (size_t h = hq_beg; h < hq_end; h++) {
3094-
if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
3095-
attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
3096-
_output_bhl.ptr<float>(ithr, b, pq, h),
3097-
_weight_bhl.ptr<float>(b, h, pq) + pv,
3098-
value_cache.ptr<uint8_t, VALUE_PREC>(block_number, hk),
3099-
SV,
3100-
_quant_value_bychannel,
3101-
std::min(_block_size, context_len - pv),
3102-
_value_group_size);
3103-
} else {
3104-
auto* v_ptr =
3105-
value_cache.ptr<typename element_type_traits<VALUE_PREC>::value_type>(block_number, hk);
3106-
attn_acc_value_block<typename element_type_traits<VALUE_PREC>::value_type, VALUE_PREC>(
3107-
_output_bhl.ptr<float>(ithr, b, pq, h),
3068+
_output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
31083069
_weight_bhl.ptr<float>(b, h, pq) + pv,
31093070
v_ptr,
31103071
SV,
@@ -3117,16 +3078,16 @@ struct MHAHelper {
31173078
};
31183079

31193080
if (prefer_static_loop) {
3120-
parallel_for3d(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_static);
3081+
parallel_for3d(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk);
31213082
} else {
3122-
parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk_dynamic);
3083+
parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? Hk : H, loop_wk);
31233084
}
31243085

31253086
parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) {
3126-
auto* temp = _output_bhl.ptr<float>(0, b, pq, h);
3127-
size_t temp_stride = _output_bhl.stride(0);
3087+
auto* temp = _output_bhl.ptr<float>(b, 0, h, pq);
3088+
size_t temp_stride = _output_bhl.stride(1); // split with pv_in_blocks steps
31283089
auto* dst = output_emb.ptr<DATA_TYPE>(b, pq, h * SV);
3129-
attn_reduce(dst, temp, _nthr, SV, temp_stride);
3090+
attn_reduce(dst, temp, kv_len_in_blocks, SV, temp_stride);
31303091
});
31313092
}
31323093
};

0 commit comments

Comments
 (0)