Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ cross_compiled_file(${TARGET_NAME}
NAMESPACE ov::Extensions::Cpu::XARCH
)
cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 SVE ANY
ARCH AVX512F AVX2 SVE NEON_FP16 ANY
src/nodes/kernels/scaled_attn/attn_quant.cpp
API src/nodes/kernels/scaled_attn/attn_quant.hpp
NAME attn_quantkv paged_attn_quantkv attn_quant_u8 attn_dequant_u8 attn_quant_by_channel_u8 attn_dequant_by_channel_u8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,58 @@ void find_minmax(const T* src, size_t n, float& min, float& max) {
hmin(v0_min);
max = _mm256_cvtss_f32(v0_max);
min = _mm256_cvtss_f32(v0_min);
#elif defined(OPENVINO_ARCH_ARM64)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question to @maxnick: do we need to add a comment noting that the ARM behavior differs from x86? The ARM path uses an FP16 accumulator, while the x86 path uses FP32.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment would definitely be helpful.

// ARM uses FP16 accumulator for FP16 Inference
# if defined(HAVE_SVE)
if constexpr (std::is_same_v<T, float>) {
auto v_max = svdup_f32(max);
auto v_min = svdup_f32(min);
for (; i < n; i += svcntw()) {
svbool_t pg = svwhilelt_b32(i, n);
auto va = svld1_f32(pg, src + i);
v_max = svmax_f32_m(pg, v_max, va);
v_min = svmin_f32_m(pg, v_min, va);
}
max = svmaxv(svptrue_b32(), v_max);
min = svminv(svptrue_b32(), v_min);
} else if constexpr (std::is_same_v<T, ov::float16>) {
auto v_max = svdup_f16(max);
auto v_min = svdup_f16(min);
for (; i < n; i += svcnth()) {
svbool_t pg = svwhilelt_b16(i, n);
auto va = svld1_f16(pg, reinterpret_cast<const float16_t*>(src + i));
v_max = svmax_f16_m(pg, v_max, va);
v_min = svmin_f16_m(pg, v_min, va);
}
max = svmaxv(svptrue_b16(), v_max);
min = svminv(svptrue_b16(), v_min);
}
# else
if constexpr (std::is_same_v<T, float>) {
auto v_max = vdupq_n_f32(max);
auto v_min = vdupq_n_f32(min);
for (; i + 4 <= n; i += 4) {
auto va = vld1q_f32(src + i);
v_max = vmaxq_f32(v_max, va);
v_min = vminq_f32(v_min, va);
}
max = vmaxvq_f32(v_max);
min = vminvq_f32(v_min);
}
# if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
if constexpr (std::is_same_v<T, ov::float16>) {
auto v_max = vdupq_n_f16(max);
auto v_min = vdupq_n_f16(min);
for (; i + 8 <= n; i += 8) {
auto va = vld1q_f16(reinterpret_cast<const float16_t*>(src) + i);
v_max = vmaxq_f16(v_max, va);
v_min = vminq_f16(v_min, va);
}
max = vmaxvq_f16(v_max);
min = vminvq_f16(v_min);
}
# endif
# endif
#endif
for (; i < n; i++) {
float tmp = src[i];
Expand All @@ -135,6 +187,12 @@ void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) {
scale = (max - min) / 255;
if (scale == 0) {
scale = 0.0001f;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// For FP16 in ARM we use FP16 accumulator
if constexpr (std::is_same_v<T, ov::float16>) {
scale = 0.05f;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @alvoron,
I have enabled the test cases. Since we use an FP16 accumulator on ARM, we need a larger scale value at the boundary condition where min == max (the test cases hit this condition).

Could you please review it? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ashwins990 It seems to me that if the constants exceed ~3300, we can still get inf.
I don't believe we will see such values in real LLM activations; however, I assume we could implement a more robust fallback here: scale = max(0.05f, abs(min) / 65504.0f)

}
#endif
}
zp = -min / scale;
#if defined(HAVE_AVX512F)
Expand Down
Loading
Loading