diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
index a128ff0c577b52..76c870e24123bc 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
@@ -760,111 +760,116 @@ dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float* head_sum, [[
     sum = _mm256_cvtss_f32(vsum0);
 #elif defined(OPENVINO_ARCH_ARM64)
+    static_assert(!std::is_same_v<TA, ov::bfloat16> && !std::is_same_v<TB, ov::bfloat16>,
+                  "bfloat16 is not supported on ARM64 platform.");
 #    if defined(HAVE_SVE)
-    svbool_t pg = svptrue_b32();
-    svfloat32_t sum0 = svdup_n_f32(0.0f);
-    svfloat32_t sum1 = svdup_n_f32(0.0f);
-    svfloat32_t sum2 = svdup_n_f32(0.0f);
-    svfloat32_t sum3 = svdup_n_f32(0.0f);
-    auto vec_len = vec_len_f32_sve();
-
-    auto _a = reinterpret_cast<float32_t*>(a);
-    auto _b = reinterpret_cast<float32_t*>(b);
-
-    for (; i + 4 * vec_len <= n; i += 4 * vec_len) {
-        svfloat32_t a0 = svld1_f32(pg, _a + i);
-        svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len);
-        svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len * 2);
-        svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len * 3);
-
-        svfloat32_t b0 = svld1_f32(pg, _b + i);
-        svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len);
-        svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len * 2);
-        svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len * 3);
-
-        sum0 = svmla_f32_z(pg, sum0, a0, b0);
-        sum1 = svmla_f32_z(pg, sum1, a1, b1);
-        sum2 = svmla_f32_z(pg, sum2, a2, b2);
-        sum3 = svmla_f32_z(pg, sum3, a3, b3);
-    }
-    if (i + 2 * vec_len <= n) {
-        svfloat32_t a0 = svld1_f32(pg, _a + i);
-        svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len);
-
-        svfloat32_t b0 = svld1_f32(pg, _b + i);
-        svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len);
-
-        sum0 = svmla_f32_z(pg, sum0, a0, b0);
-        sum1 = svmla_f32_z(pg, sum1, a1, b1);
-        i += 2 * vec_len;
-    }
-    if (i + vec_len <= n) {
-        svfloat32_t a0 = svld1_f32(pg, _a + i);
-        svfloat32_t b0 = svld1_f32(pg, _b + i);
-        sum0 = svmla_f32_z(pg, sum0, a0, b0);
-        i += vec_len;
-    }
-    // Process the tail elements parallely as well (if any)
-    if (i != n) {
-        svbool_t pg_rem = svwhilelt_b32(0, static_cast<int>(n - i));
-        svfloat32_t a0 = svld1_f32(pg_rem, _a + i);
-        svfloat32_t b0 = svld1_f32(pg_rem, _b + i);
-        sum0 = svmla_f32_m(pg_rem, sum0, a0, b0);
-        i = n;
-    }
-    float32_t sum_0 = svaddv_f32(pg, sum0);
-    float32_t sum_1 = svaddv_f32(pg, sum1);
-    float32_t sum_2 = svaddv_f32(pg, sum2);
-    float32_t sum_3 = svaddv_f32(pg, sum3);
-    sum = static_cast<float>(sum_0 + sum_1 + sum_2 + sum_3);
-#    else
-    float32x4_t vsum0 = vdupq_n_f32(0.0f);
-    float32x4_t vsum1 = vdupq_n_f32(0.0f);
-    float32x4_t vsum2 = vdupq_n_f32(0.0f);
-    float32x4_t vsum3 = vdupq_n_f32(0.0f);
-
-    for (; i + 4 * vec_len_f32_neon <= n; i += vec_len_f32_neon * 4) {
-        float32x4_t va0 = __vld1q_f32(a + i);
-        float32x4_t va1 = __vld1q_f32(a + i + vec_len_f32_neon);
-        float32x4_t va2 = __vld1q_f32(a + i + vec_len_f32_neon * 2);
-        float32x4_t va3 = __vld1q_f32(a + i + vec_len_f32_neon * 3);
-
-        float32x4_t vb0 = __vld1q_f32(b + i);
-        float32x4_t vb1 = __vld1q_f32(b + i + vec_len_f32_neon);
-        float32x4_t vb2 = __vld1q_f32(b + i + vec_len_f32_neon * 2);
-        float32x4_t vb3 = __vld1q_f32(b + i + vec_len_f32_neon * 3);
-
-        vsum0 = vmlaq_f32(vsum0, va0, vb0);
-        vsum1 = vmlaq_f32(vsum1, va1, vb1);
-        vsum2 = vmlaq_f32(vsum2, va2, vb2);
-        vsum3 = vmlaq_f32(vsum3, va3, vb3);
-    }
-    if (i + 2 * vec_len_f32_neon <= n) {
-        float32x4_t va0 = __vld1q_f32(a + i);
-        float32x4_t va1 = __vld1q_f32(a + i + vec_len_f32_neon);
-
-        float32x4_t vb0 = __vld1q_f32(b + i);
-        float32x4_t vb1 = __vld1q_f32(b + i + vec_len_f32_neon);
-
-        vsum0 = vmlaq_f32(vsum0, va0, vb0);
-        vsum1 = vmlaq_f32(vsum1, va1, vb1);
-        i += 2 * vec_len_f32_neon;
-    }
-    if (i + vec_len_f32_neon <= n) {
-        float32x4_t va0 = __vld1q_f32(a + i);
-        float32x4_t vb0 = __vld1q_f32(b + i);
-        vsum0 = vmlaq_f32(vsum0, va0, vb0);
-        i += vec_len_f32_neon;
-    }
-
-    vsum0 = vaddq_f32(vsum0, vsum1);
-    vsum2 = vaddq_f32(vsum2, vsum3);
-    vsum0 = vaddq_f32(vsum0, vsum2);
-
-    float32x2_t temp_sum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0));
-    temp_sum = vpadd_f32(temp_sum, temp_sum);
-    sum = vget_lane_f32(temp_sum, 0);
+    if constexpr (std::is_same_v<TA, float> && std::is_same_v<TB, float>) {
+        svbool_t pg = svptrue_b32();
+        svfloat32_t sum0 = svdup_n_f32(0.0f);
+        svfloat32_t sum1 = svdup_n_f32(0.0f);
+        svfloat32_t sum2 = svdup_n_f32(0.0f);
+        svfloat32_t sum3 = svdup_n_f32(0.0f);
+        auto vec_len = vec_len_f32_sve();
+
+        auto _a = reinterpret_cast<float32_t*>(a);
+        auto _b = reinterpret_cast<float32_t*>(b);
+
+        for (; i + 4 * vec_len <= n; i += 4 * vec_len) {
+            svfloat32_t a0 = svld1_f32(pg, _a + i);
+            svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len);
+            svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len * 2);
+            svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len * 3);
+
+            svfloat32_t b0 = svld1_f32(pg, _b + i);
+            svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len);
+            svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len * 2);
+            svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len * 3);
+
+            sum0 = svmla_f32_z(pg, sum0, a0, b0);
+            sum1 = svmla_f32_z(pg, sum1, a1, b1);
+            sum2 = svmla_f32_z(pg, sum2, a2, b2);
+            sum3 = svmla_f32_z(pg, sum3, a3, b3);
+        }
+        if (i + 2 * vec_len <= n) {
+            svfloat32_t a0 = svld1_f32(pg, _a + i);
+            svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len);
+
+            svfloat32_t b0 = svld1_f32(pg, _b + i);
+            svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len);
+
+            sum0 = svmla_f32_z(pg, sum0, a0, b0);
+            sum1 = svmla_f32_z(pg, sum1, a1, b1);
+            i += 2 * vec_len;
+        }
+        if (i + vec_len <= n) {
+            svfloat32_t a0 = svld1_f32(pg, _a + i);
+            svfloat32_t b0 = svld1_f32(pg, _b + i);
+            sum0 = svmla_f32_z(pg, sum0, a0, b0);
+            i += vec_len;
+        }
+        // Process the tail elements parallely as well (if any)
+        if (i != n) {
+            svbool_t pg_rem = svwhilelt_b32(0, static_cast<int>(n - i));
+            svfloat32_t a0 = svld1_f32(pg_rem, _a + i);
+            svfloat32_t b0 = svld1_f32(pg_rem, _b + i);
+            sum0 = svmla_f32_m(pg_rem, sum0, a0, b0);
+            i = n;
+        }
+        float32_t sum_0 = svaddv_f32(pg, sum0);
+        float32_t sum_1 = svaddv_f32(pg, sum1);
+        float32_t sum_2 = svaddv_f32(pg, sum2);
+        float32_t sum_3 = svaddv_f32(pg, sum3);
+        sum = static_cast<float>(sum_0 + sum_1 + sum_2 + sum_3);
+    } else
 #    endif
+    {
+        float32x4_t vsum0 = vdupq_n_f32(0.0f);
+        float32x4_t vsum1 = vdupq_n_f32(0.0f);
+        float32x4_t vsum2 = vdupq_n_f32(0.0f);
+        float32x4_t vsum3 = vdupq_n_f32(0.0f);
+
+        for (; i + 4 * vec_len_f32_neon <= n; i += vec_len_f32_neon * 4) {
+            float32x4_t va0 = __vld1q_f32(a + i);
+            float32x4_t va1 = __vld1q_f32(a + i + vec_len_f32_neon);
+            float32x4_t va2 = __vld1q_f32(a + i + vec_len_f32_neon * 2);
+            float32x4_t va3 = __vld1q_f32(a + i + vec_len_f32_neon * 3);
+
+            float32x4_t vb0 = __vld1q_f32(b + i);
+            float32x4_t vb1 = __vld1q_f32(b + i + vec_len_f32_neon);
+            float32x4_t vb2 = __vld1q_f32(b + i + vec_len_f32_neon * 2);
+            float32x4_t vb3 = __vld1q_f32(b + i + vec_len_f32_neon * 3);
+
+            vsum0 = vmlaq_f32(vsum0, va0, vb0);
+            vsum1 = vmlaq_f32(vsum1, va1, vb1);
+            vsum2 = vmlaq_f32(vsum2, va2, vb2);
+            vsum3 = vmlaq_f32(vsum3, va3, vb3);
+        }
+        if (i + 2 * vec_len_f32_neon <= n) {
+            float32x4_t va0 = __vld1q_f32(a + i);
+            float32x4_t va1 = __vld1q_f32(a + i + vec_len_f32_neon);
+
+            float32x4_t vb0 = __vld1q_f32(b + i);
+            float32x4_t vb1 = __vld1q_f32(b + i + vec_len_f32_neon);
+
+            vsum0 = vmlaq_f32(vsum0, va0, vb0);
+            vsum1 = vmlaq_f32(vsum1, va1, vb1);
+            i += 2 * vec_len_f32_neon;
+        }
+        if (i + vec_len_f32_neon <= n) {
+            float32x4_t va0 = __vld1q_f32(a + i);
+            float32x4_t vb0 = __vld1q_f32(b + i);
+            vsum0 = vmlaq_f32(vsum0, va0, vb0);
+            i += vec_len_f32_neon;
+        }
+
+        vsum0 = vaddq_f32(vsum0, vsum1);
+        vsum2 = vaddq_f32(vsum2, vsum3);
+        vsum0 = vaddq_f32(vsum0, vsum2);
+
+        float32x2_t temp_sum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0));
+        temp_sum = vpadd_f32(temp_sum, temp_sum);
+        sum = vget_lane_f32(temp_sum, 0);
+    }
 #endif
     for (; i < n; i++) {
         sum += a[i] * b[i];
     }
@@ -1479,6 +1484,8 @@ static float dot_product(TA* a, uint8_t* b, size_t n, float* scale, float* zp, f
     }
     return sum;
 #elif defined(OPENVINO_ARCH_ARM64)
+    static_assert(std::is_same_v<TA, ov::float16> || std::is_same_v<TA, float>,
+                  "Only support float16 and float32 for ARM64 dot product.");
     while (group_id < n / group_size) {
         size_t i = 0;
         float group_scale = *(scale + group_id * 2);
@@ -2063,6 +2070,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
                       bool quant_key_by_channel,
                       const ov::intel_cpu::PlainTensor& sink_input,
                       const ov::intel_cpu::CpuParallelPtr& cpu_parallel) {
+#if !defined(OPENVINO_ARCH_ARM64)
     if (query.get_precision() == ov::element::bf16) {
         if (present_key.get_precision() == ov::element::u8) {
             mha_single_token_kernel<ov::bfloat16, uint8_t, float>(query,
@@ -2107,7 +2115,9 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
                                                                   sink_input,
                                                                   cpu_parallel);
         }
-    } else if (query.get_precision() == ov::element::f16) {
+    }
+#endif
+    if (query.get_precision() == ov::element::f16) {
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
         if (present_key.get_precision() == ov::element::f16) {
             mha_single_token_kernel<ov::float16, ov::float16, ov::float16>(query,