Skip to content

Commit 0a1bc22

Browse files
committed
Merge conflict resolved, refactored code
1 parent 45874d2 commit 0a1bc22

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

src/plugins/intel_cpu/src/nodes/kernels/aarch64/sve_utils.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,33 @@ size_t sve_vlen() {
5353
return svcntd();
5454
}
5555
}
56+
57+
template <typename TA, typename TB>
58+
static void cvt_copy(TA* dst, TB* src, size_t n) {
59+
size_t i = 0;
60+
if constexpr (std::is_same<TA, TB>::value) {
61+
auto pg_dst = sve_predicate<sizeof(TA)>();
62+
auto vlen = sve_vlen<sizeof(TA)>();
63+
for (; i + vlen <= n; i += vlen) {
64+
auto vb = svld1(pg_dst, src + i);
65+
svst1(pg_dst, dst + i, vb);
66+
}
67+
auto pgt = sve_predicate<TA, sizeof(TA)>(i, n);
68+
auto vb = svld1(pg_dst, src + i);
69+
svst1(pg_dst, dst + i, vb);
70+
return;
71+
} else if constexpr (std::is_same<TA, float>::value && std::is_same<TB, ov::float16>::value) {
72+
auto src_ptr = reinterpret_cast<float16_t*>(src);
73+
auto pg_vl2 = svwhilelt_b16(svcnth() / 2, svcnth());
74+
auto vlen = svcnth() / 2;
75+
auto pg_dst = svptrue_b32();
76+
for (; i + vlen <= n; i += vlen) {
77+
auto load_src = svld1_f16(pg_vl2, src_ptr + i);
78+
auto src_interleave = svzip1_f16(load_src, load_src);
79+
auto cvt_dst = svcvt_f32_f16_z(pg_dst, src_interleave);
80+
svst1(pg_dst, dst + i, cvt_dst);
81+
}
82+
}
83+
}
84+
5685
} // namespace ov::intel_cpu::sve_utils

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
# include "nodes/kernels/aarch64/brgemm_kernel.hpp"
3333
# include "nodes/kernels/aarch64/sve_utils.hpp"
3434
# include "nodes/kernels/kai/kleidi_kernel.hpp"
35-
using namespace ov::intel_cpu::sve_utils;
3635
#endif
3736

3837
namespace ov::Extensions::Cpu::XARCH {
@@ -2593,9 +2592,9 @@ struct MHAHelper {
25932592
PlainTensor f32_cvt;
25942593
if (q_is_xf16) {
25952594
f32_cvt.resize<float>({size_t{rnd_up(cur_kv_len, _block_size)}});
2596-
cvt_copy(f32_cvt.ptr<float>(0),
2597-
reinterpret_cast<DATA_TYPE*>(score),
2598-
rnd_up(cur_kv_len, _block_size));
2595+
sve_utils::cvt_copy(f32_cvt.ptr<float>(0),
2596+
reinterpret_cast<DATA_TYPE*>(score),
2597+
rnd_up(cur_kv_len, _block_size));
25992598
soft_in = f32_cvt.ptr<float>(0);
26002599
}
26012600
if (_sliding_window) {
@@ -2641,9 +2640,9 @@ struct MHAHelper {
26412640
alibi_slope);
26422641
}
26432642
if (score_output) {
2644-
cvt_copy(score_output + h * rnd_up(cur_kv_len, 16),
2645-
reinterpret_cast<DATA_TYPE*>(score),
2646-
cur_kv_len);
2643+
sve_utils::cvt_copy(score_output + h * rnd_up(cur_kv_len, 16),
2644+
reinterpret_cast<DATA_TYPE*>(score),
2645+
cur_kv_len);
26472646
}
26482647
}
26492648

@@ -3164,7 +3163,8 @@ struct MHA {
31643163
v_ptr,
31653164
_helper._block_size,
31663165
_helper.SV,
3167-
_helper._value_group_size);
3166+
_helper._value_group_size,
3167+
_helper._quant_value_bychannel);
31683168
# else
31693169
pack_32NxK<DATA_TYPE, VALUE_PREC>(
31703170
_helper._wv_scratch_b.template ptr<DATA_TYPE>(batch_in_reorder, kv_block, hk),
@@ -3176,6 +3176,7 @@ struct MHA {
31763176
_helper.SV,
31773177
_helper._value_group_size,
31783178
_helper._quant_value_bychannel);
3179+
# endif
31793180
} else {
31803181
// need to decompress
31813182
if (!q_cache_is_same) {
@@ -3929,9 +3930,11 @@ std::shared_ptr<PagedAttentionExecutor> make_pa_executor(ov::element::Type data_
39293930
}
39303931
if (data_type == ov::element::f16) {
39313932
if (key_cache_type == ov::element::u8 && value_cache_type == ov::element::u8) {
3932-
executor = std::make_shared<AttentionExecutor<ov::float16, uint8_t, ov::element::u8>>(key_group_size,
3933-
value_group_size,
3934-
quant_key_bychannel);
3933+
executor = std::make_shared<AttentionExecutor<ov::float16, ov::element::u8, ov::element::u8>>(
3934+
key_group_size,
3935+
value_group_size,
3936+
quant_key_bychannel,
3937+
quant_value_bychannel);
39353938
} else {
39363939
OPENVINO_THROW("make_pa_executor: key_cache_type and value_cache_type of u8 is only support");
39373940
}

0 commit comments

Comments
 (0)