Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ cross_compiled_file(${TARGET_NAME}
NAMESPACE ov::Extensions::Cpu::XARCH
)
cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 SVE ANY
ARCH AVX512F AVX2 SVE NEON_FP16 ANY
src/nodes/kernels/scaled_attn/attn_quant.cpp
API src/nodes/kernels/scaled_attn/attn_quant.hpp
NAME attn_quantkv paged_attn_quantkv attn_quant_u8 attn_dequant_u8 attn_quant_by_channel_u8 attn_dequant_by_channel_u8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,58 @@ void find_minmax(const T* src, size_t n, float& min, float& max) {
hmin(v0_min);
max = _mm256_cvtss_f32(v0_max);
min = _mm256_cvtss_f32(v0_min);
#elif defined(OPENVINO_ARCH_ARM64)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question to @maxnick: do we need to add a comment noting that the ARM behavior differs from x86? The ARM path uses an fp16 accumulator, while the x86 path uses fp32.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment would definitely be helpful.

// ARM uses FP16 accumulator for FP16 Inference
# if defined(HAVE_SVE)
if constexpr (std::is_same_v<T, float>) {
auto v_max = svdup_f32(max);
auto v_min = svdup_f32(min);
for (; i < n; i += svcntw()) {
svbool_t pg = svwhilelt_b32(i, n);
auto va = svld1_f32(pg, src + i);
v_max = svmax_f32_m(pg, v_max, va);
v_min = svmin_f32_m(pg, v_min, va);
}
max = svmaxv(svptrue_b32(), v_max);
min = svminv(svptrue_b32(), v_min);
} else if constexpr (std::is_same_v<T, ov::float16>) {
auto v_max = svdup_f16(max);
auto v_min = svdup_f16(min);
for (; i < n; i += svcnth()) {
svbool_t pg = svwhilelt_b16(i, n);
auto va = svld1_f16(pg, reinterpret_cast<const float16_t*>(src + i));
v_max = svmax_f16_m(pg, v_max, va);
v_min = svmin_f16_m(pg, v_min, va);
}
max = svmaxv(svptrue_b16(), v_max);
min = svminv(svptrue_b16(), v_min);
}
# else
if constexpr (std::is_same_v<T, float>) {
auto v_max = vdupq_n_f32(max);
auto v_min = vdupq_n_f32(min);
for (; i + 4 <= n; i += 4) {
auto va = vld1q_f32(src + i);
v_max = vmaxq_f32(v_max, va);
v_min = vminq_f32(v_min, va);
}
max = vmaxvq_f32(v_max);
min = vminvq_f32(v_min);
}
# if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
if constexpr (std::is_same_v<T, ov::float16>) {
auto v_max = vdupq_n_f16(max);
auto v_min = vdupq_n_f16(min);
for (; i + 8 <= n; i += 8) {
auto va = vld1q_f16(reinterpret_cast<const float16_t*>(src) + i);
v_max = vmaxq_f16(v_max, va);
v_min = vminq_f16(v_min, va);
}
max = vmaxvq_f16(v_max);
min = vminvq_f16(v_min);
}
# endif
# endif
#endif
for (; i < n; i++) {
float tmp = src[i];
Expand All @@ -135,6 +187,12 @@ void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) {
scale = (max - min) / 255;
if (scale == 0) {
scale = 0.0001f;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// For FP16 in ARM we use FP16 accumulator
if constexpr (std::is_same_v<T, ov::float16>) {
scale = std::max(0.05f, std::abs(min) / 65504.0f);
}
#endif
}
zp = -min / scale;
#if defined(HAVE_AVX512F)
Expand Down
Loading
Loading