@@ -242,7 +242,7 @@ attn_acc_value(ov::float16* out, ov::float16 weight, uint8_t* v, size_t S, float
242242 auto v_group_zp = svdup_n_f16 (group_zp);
243243 auto v_weighed_group_scale = svdup_n_f16 (weight * group_scale);
244244 auto pg_b08 = svptrue_b8 ();
245- for (; i + svcntb () < group_size; i += svcntb ()) {
245+ for (; i + svcntb () <= group_size; i += svcntb ()) {
246246 auto v_a = svld1_u8 (pg_b08, v + offset + i);
247247 auto v_a_low = svunpklo_u16 (v_a);
248248 auto v_a_high = svunpkhi_u16 (v_a);
@@ -264,7 +264,7 @@ attn_acc_value(ov::float16* out, ov::float16 weight, uint8_t* v, size_t S, float
264264 float16x8_t v_group_zp = vdupq_n_f16 (group_zp);
265265 float16x8_t v_weighed_group_scale = vdupq_n_f16 (weight * group_scale);
266266
267- for (; i + 16 < group_size; i += 16 ) {
267+ for (; i + 16 <= group_size; i += 16 ) {
268268 uint8x16_t v_u8 = vld1q_u8 (v + offset + i);
269269
270270 uint16x8_t v_u16_lo = vmovl_u8 (vget_low_u8 (v_u8));
@@ -898,8 +898,8 @@ dot_product_fp16(ov::float16* a, void* b, size_t n, float* scale, float* zp, flo
898898 svfloat16_t v_group_zp = svdup_n_f16 (group_zp);
899899 auto v_group_sum = svdup_n_f16 (0 );
900900
901- for (; i + svcntb () < group_size; i += svcntb ()) {
902- svfloat16_t a0 = svld1_f16 (pg_b16, _a + i);
901+ for (; i + svcntb () <= group_size; i += svcntb ()) {
902+ svfloat16_t a0 = svld1_f16 (pg_b16, _a + i + offset );
903903 svfloat16_t a1 = svld1_f16 (pg_b16, _a + i + offset + svcnth ());
904904
905905 svuint8_t v_b8 = svld1 (pg_b08, _b + i + offset);
@@ -989,7 +989,7 @@ dot_product_fp16(ov::float16* a, void* b, size_t n, float* scale, float* zp, flo
989989
990990 auto v_group_zp = vdupq_n_f16 (group_zp);
991991
992- for (; i + 16 < group_size; i += 16 ) {
992+ for (; i + 16 <= group_size; i += 16 ) {
993993 float16x8_t v_a_lo = vld1q_f16 (_a + offset + i);
994994 float16x8_t v_a_hi = vld1q_f16 (_a + offset + i + 8 );
995995
@@ -1082,10 +1082,10 @@ dot_product_fp16(ov::float16* a, void* b, size_t n, float* scale, float* zp, flo
10821082 if constexpr (KEY_PREC == ov::element::u8 ) {
10831083 size_t group_id = 0 ;
10841084 auto _b = reinterpret_cast <uint8_t *>(b);
1085- size_t offset = group_id * group_size;
1086- float16_t group_scale = *(scale + group_id * 2 );
1087- float16_t group_zp = *(zp + group_id * 2 );
10881085 while (group_id < n / group_size) {
1086+ size_t offset = group_id * group_size;
1087+ float16_t group_scale = *(scale + group_id * 2 );
1088+ float16_t group_zp = *(zp + group_id * 2 );
10891089 float16_t group_sum = 0 .0f ;
10901090 i = 0 ;
10911091 for (; i < group_size; i++) {
@@ -1488,7 +1488,7 @@ static float dot_product(TA* a, uint8_t* b, size_t n, float* scale, float* zp, f
14881488
14891489 float32x4_t v_group_zp = vdupq_n_f32 (group_zp);
14901490
1491- for (; i + 16 < group_size; i += 16 ) {
1491+ for (; i + 16 <= group_size; i += 16 ) {
14921492 uint8x16_t v_u8 = vld1q_u8 (b + i + offset);
14931493
14941494 uint16x8_t v_u16_lo = vmovl_u8 (vget_low_u8 (v_u8));
@@ -1756,11 +1756,10 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
17561756 for (size_t iwork = start; iwork < end; ++iwork) {
17571757 auto * p = past_k_scale_zp.ptr <float >(pk, 0 , h_group);
17581758#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1759- if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16> &&
1760- ov::intel_cpu::any_of<T2, ov::float16, uint8_t >) {
1761- auto p_k = present_key.ptr <T2>(0 , h_group, pk);
1762- prefetch_bytes (S, _MM_HINT_T0, 4096 , p_k);
1759+ if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16>) {
17631760 if constexpr (std::is_same_v<T2, uint8_t >) {
1761+ auto p_k = present_key.ptr <T2>(0 , h_group, pk);
1762+ prefetch_bytes (S, _MM_HINT_T0, 4096 , p_k);
17641763 auto _qk = dot_product_fp16<ov::element::u8 >(query.ptr <ov::float16>(0 , h_group),
17651764 p_k,
17661765 S,
@@ -1771,7 +1770,9 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
17711770 buf_attn_w.ptr <T3>(0 , h_group, 0 )[pk] = _qk;
17721771 parallel_it_step (pk, kv_len, b, B, h_group, h_group_num);
17731772 continue ;
1774- } else {
1773+ } else if constexpr (std::is_same_v<T2, ov::float16>) {
1774+ auto p_k = present_key.ptr <T2>(0 , h_group, pk);
1775+ prefetch_bytes (S, _MM_HINT_T0, 4096 , p_k);
17751776 auto _qk = dot_product_fp16<ov::element::f16 >(query.ptr <ov::float16>(0 , h_group),
17761777 p_k,
17771778 S,
@@ -1810,8 +1811,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18101811 auto b_kv = beams ? beams.ptr <int32_t >(b)[pk] : b;
18111812 auto * p = past_k_scale_zp.ptr <float >(pk, b_kv, h_group);
18121813#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1813- if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16> &&
1814- ov::intel_cpu::any_of<T2, ov::float16, uint8_t >) {
1814+ if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16>) {
18151815 if constexpr (std::is_same_v<T2, uint8_t >) {
18161816 auto _qk = dot_product_fp16<ov::element::u8 >(query.ptr <ov::float16>(b, h_group),
18171817 present_key.ptr <T2>(b_kv, h_group, pk),
@@ -1823,7 +1823,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18231823 buf_attn_w.ptr <T3>(b, h_group, 0 )[pk] = _qk;
18241824 parallel_it_step (pk, kv_len, b, B, h_group, h_group_num);
18251825 continue ;
1826- } else {
1826+ } else if constexpr (std::is_same_v<T2, ov::float16>) {
18271827 auto _qk = dot_product_fp16<ov::element::f16 >(query.ptr <ov::float16>(b, h_group),
18281828 present_key.ptr <T2>(b_kv, h_group, pk),
18291829 S,
@@ -1863,8 +1863,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18631863 auto * p = past_k_scale_zp.ptr <float >(pk, b_kv, h_group);
18641864 for (size_t h = h_group * h_each_group_len; h < (h_group + 1 ) * h_each_group_len; h++) {
18651865#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1866- if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16> &&
1867- ov::intel_cpu::any_of<T2, ov::float16, uint8_t >) {
1866+ if (std::is_same_v<T3, ov::float16> && std::is_same_v<T, ov::float16>) {
18681867 if constexpr (std::is_same_v<T2, uint8_t >) {
18691868 auto _qk = dot_product_fp16<ov::element::u8 >(query.ptr <ov::float16>(b, h, pq),
18701869 present_key.ptr <T2>(b_kv, h_group, pk),
@@ -1875,7 +1874,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
18751874 key_group_size);
18761875 buf_attn_w.ptr <T3>(b, h, pq)[pk] = _qk;
18771876 continue ;
1878- } else {
1877+ } else if constexpr (std::is_same_v<T2, ov::float16>) {
18791878 auto _qk =
18801879 dot_product_fp16<ov::element::f16 >(query.ptr <ov::float16>(b, h, pq),
18811880 present_key.ptr <T2>(b_kv, h_group, pk),
@@ -2126,7 +2125,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
21262125 value_group_size,
21272126 quant_key_by_channel,
21282127 sink_input);
2129- } else if (present_key.get_precision () == ov::element::u8 ) {
2128+ } else if (present_key.get_precision () == ov::element::u8 && !quant_key_by_channel ) {
21302129 mha_single_token_kernel<ov::float16, uint8_t , ov::float16>(query,
21312130 present_key,
21322131 present_value,
0 commit comments