Skip to content

Commit 32fa83a

Browse files
committed
[CI Failure fixes] Refactor Code
1 parent 05bf41f commit 32fa83a

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void find_minmax(const T* src, size_t n, float& min, float& max) {
159159
if constexpr (std::is_same_v<T, ov::float16>) {
160160
auto v_max = vdupq_n_f16(max);
161161
auto v_min = vdupq_n_f16(min);
162-
for (; i + 8 < n; i += 8) {
162+
for (; i + 8 <= n; i += 8) {
163163
auto va = vld1q_f16(reinterpret_cast<const float16_t*>(src) + i);
164164
v_max = vmaxq_f16(v_max, va);
165165
v_min = vminq_f16(v_min, va);

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ attn_acc_value(ov::float16* out, ov::float16 weight, uint8_t* v, size_t S, float
242242
auto v_group_zp = svdup_n_f16(group_zp);
243243
auto v_weighed_group_scale = svdup_n_f16(weight * group_scale);
244244
auto pg_b08 = svptrue_b8();
245-
for (; i + svcntb() < group_size; i += svcntb()) {
245+
for (; i + svcntb() <= group_size; i += svcntb()) {
246246
auto v_a = svld1_u8(pg_b08, v + offset + i);
247247
auto v_a_low = svunpklo_u16(v_a);
248248
auto v_a_high = svunpkhi_u16(v_a);
@@ -264,7 +264,7 @@ attn_acc_value(ov::float16* out, ov::float16 weight, uint8_t* v, size_t S, float
264264
float16x8_t v_group_zp = vdupq_n_f16(group_zp);
265265
float16x8_t v_weighed_group_scale = vdupq_n_f16(weight * group_scale);
266266

267-
for (; i + 16 < group_size; i += 16) {
267+
for (; i + 16 <= group_size; i += 16) {
268268
uint8x16_t v_u8 = vld1q_u8(v + offset + i);
269269

270270
uint16x8_t v_u16_lo = vmovl_u8(vget_low_u8(v_u8));
@@ -898,8 +898,8 @@ dot_product_fp16(ov::float16* a, void* b, size_t n, float* scale, float* zp, flo
898898
svfloat16_t v_group_zp = svdup_n_f16(group_zp);
899899
auto v_group_sum = svdup_n_f16(0);
900900

901-
for (; i + svcntb() < group_size; i += svcntb()) {
902-
svfloat16_t a0 = svld1_f16(pg_b16, _a + i);
901+
for (; i + svcntb() <= group_size; i += svcntb()) {
902+
svfloat16_t a0 = svld1_f16(pg_b16, _a + i + offset);
903903
svfloat16_t a1 = svld1_f16(pg_b16, _a + i + offset + svcnth());
904904

905905
svuint8_t v_b8 = svld1(pg_b08, _b + i + offset);
@@ -989,7 +989,7 @@ dot_product_fp16(ov::float16* a, void* b, size_t n, float* scale, float* zp, flo
989989

990990
auto v_group_zp = vdupq_n_f16(group_zp);
991991

992-
for (; i + 16 < group_size; i += 16) {
992+
for (; i + 16 <= group_size; i += 16) {
993993
float16x8_t v_a_lo = vld1q_f16(_a + offset + i);
994994
float16x8_t v_a_hi = vld1q_f16(_a + offset + i + 8);
995995

@@ -1082,10 +1082,10 @@ dot_product_fp16(ov::float16* a, void* b, size_t n, float* scale, float* zp, flo
10821082
if constexpr (KEY_PREC == ov::element::u8) {
10831083
size_t group_id = 0;
10841084
auto _b = reinterpret_cast<uint8_t*>(b);
1085-
size_t offset = group_id * group_size;
1086-
float16_t group_scale = *(scale + group_id * 2);
1087-
float16_t group_zp = *(zp + group_id * 2);
10881085
while (group_id < n / group_size) {
1086+
size_t offset = group_id * group_size;
1087+
float16_t group_scale = *(scale + group_id * 2);
1088+
float16_t group_zp = *(zp + group_id * 2);
10891089
float16_t group_sum = 0.0f;
10901090
i = 0;
10911091
for (; i < group_size; i++) {
@@ -1488,7 +1488,7 @@ static float dot_product(TA* a, uint8_t* b, size_t n, float* scale, float* zp, f
14881488

14891489
float32x4_t v_group_zp = vdupq_n_f32(group_zp);
14901490

1491-
for (; i + 16 < group_size; i += 16) {
1491+
for (; i + 16 <= group_size; i += 16) {
14921492
uint8x16_t v_u8 = vld1q_u8(b + i + offset);
14931493

14941494
uint16x8_t v_u16_lo = vmovl_u8(vget_low_u8(v_u8));
@@ -1756,11 +1756,10 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
17561756
for (size_t iwork = start; iwork < end; ++iwork) {
17571757
auto* p = past_k_scale_zp.ptr<float>(pk, 0, h_group);
17581758
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1759-
if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16> &&
1760-
ov::intel_cpu::any_of<T2, ov::float16, uint8_t>) {
1761-
auto p_k = present_key.ptr<T2>(0, h_group, pk);
1762-
prefetch_bytes(S, _MM_HINT_T0, 4096, p_k);
1759+
if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16>) {
17631760
if constexpr (std::is_same_v<T2, uint8_t>) {
1761+
auto p_k = present_key.ptr<T2>(0, h_group, pk);
1762+
prefetch_bytes(S, _MM_HINT_T0, 4096, p_k);
17641763
auto _qk = dot_product_fp16<ov::element::u8>(query.ptr<ov::float16>(0, h_group),
17651764
p_k,
17661765
S,
@@ -1771,7 +1770,9 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
17711770
buf_attn_w.ptr<T3>(0, h_group, 0)[pk] = _qk;
17721771
parallel_it_step(pk, kv_len, b, B, h_group, h_group_num);
17731772
continue;
1774-
} else {
1773+
} else if constexpr (std::is_same_v<T2, ov::float16>) {
1774+
auto p_k = present_key.ptr<T2>(0, h_group, pk);
1775+
prefetch_bytes(S, _MM_HINT_T0, 4096, p_k);
17751776
auto _qk = dot_product_fp16<ov::element::f16>(query.ptr<ov::float16>(0, h_group),
17761777
p_k,
17771778
S,
@@ -1810,8 +1811,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18101811
auto b_kv = beams ? beams.ptr<int32_t>(b)[pk] : b;
18111812
auto* p = past_k_scale_zp.ptr<float>(pk, b_kv, h_group);
18121813
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1813-
if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16> &&
1814-
ov::intel_cpu::any_of<T2, ov::float16, uint8_t>) {
1814+
if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16>) {
18151815
if constexpr (std::is_same_v<T2, uint8_t>) {
18161816
auto _qk = dot_product_fp16<ov::element::u8>(query.ptr<ov::float16>(b, h_group),
18171817
present_key.ptr<T2>(b_kv, h_group, pk),
@@ -1823,7 +1823,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18231823
buf_attn_w.ptr<T3>(b, h_group, 0)[pk] = _qk;
18241824
parallel_it_step(pk, kv_len, b, B, h_group, h_group_num);
18251825
continue;
1826-
} else {
1826+
} else if constexpr (std::is_same_v<T2, ov::float16>) {
18271827
auto _qk = dot_product_fp16<ov::element::f16>(query.ptr<ov::float16>(b, h_group),
18281828
present_key.ptr<T2>(b_kv, h_group, pk),
18291829
S,
@@ -1863,8 +1863,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18631863
auto* p = past_k_scale_zp.ptr<float>(pk, b_kv, h_group);
18641864
for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
18651865
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1866-
if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16> &&
1867-
ov::intel_cpu::any_of<T2, ov::float16, uint8_t>) {
1866+
if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16>) {
18681867
if constexpr (std::is_same_v<T2, uint8_t>) {
18691868
auto _qk = dot_product_fp16<ov::element::u8>(query.ptr<ov::float16>(b, h, pq),
18701869
present_key.ptr<T2>(b_kv, h_group, pk),
@@ -1875,7 +1874,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18751874
key_group_size);
18761875
buf_attn_w.ptr<T3>(b, h, pq)[pk] = _qk;
18771876
continue;
1878-
} else {
1877+
} else if constexpr (std::is_same_v<T2, ov::float16>) {
18791878
auto _qk =
18801879
dot_product_fp16<ov::element::f16>(query.ptr<ov::float16>(b, h, pq),
18811880
present_key.ptr<T2>(b_kv, h_group, pk),
@@ -2126,7 +2125,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
21262125
value_group_size,
21272126
quant_key_by_channel,
21282127
sink_input);
2129-
} else if (present_key.get_precision() == ov::element::u8) {
2128+
} else if (present_key.get_precision() == ov::element::u8 && !quant_key_by_channel) {
21302129
mha_single_token_kernel<ov::float16, uint8_t, ov::float16>(query,
21312130
present_key,
21322131
present_value,

0 commit comments

Comments
 (0)