
Commit 3477ea1

ashwins990 and allnes authored
[ Aarch64 ] Paged Attention FP16 precision enablement. (#29219)
This development relates to feature request #26422. This PR enables f16 inference precision for the Paged Attention operator, with the key/value cache stored in u8 precision. Update: the implementation now uses KleidiAI instead of ACL. Attached are server benchmarking results on Graviton 3E (64 cores), comparing f16 performance against the f32 (reference) precision. ![Kleidi-oss-result](https://github.com/user-attachments/assets/4080ad7c-8896-46ce-85b1-80b94664fc25) --------- Co-authored-by: Nesterov Alexander <[email protected]>
1 parent 368e64d commit 3477ea1
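
For readers who want to exercise the new path, the precisions described above map onto the public runtime hints. A minimal sketch (not part of this commit; whether a given OpenVINO version honours these hints for Paged Attention depends on the plugin build):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder model path
    // Request f16 math on the CPU plugin and a u8 key/value cache, matching the
    // precisions this commit enables for Paged Attention on AArch64.
    auto compiled = core.compile_model(model,
                                       "CPU",
                                       ov::hint::inference_precision(ov::element::f16),
                                       ov::hint::kv_cache_precision(ov::element::u8));
    auto request = compiled.create_infer_request();
    (void)request;
    return 0;
}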

File tree

7 files changed: +582 -9 lines

cmake/developer_package/compile_flags/functions.cmake

Lines changed: 2 additions & 2 deletions
@@ -283,7 +283,7 @@ macro(ov_arm_sve_optimization_flags flags)
     endif()
 
     # Check for compiler SVE support
-    ov_check_compiler_supports_sve("-march=armv8-a+sve")
+    ov_check_compiler_supports_sve("-march=armv8-a+sve+fp16")
     if(OV_COMPILER_IS_INTEL_LLVM)
         message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
     elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
@@ -305,7 +305,7 @@ macro(ov_arm_sve_optimization_flags flags)
 
     # Add flag for SVE if supported
     if(CXX_SVE_FOUND)
-        list(APPEND ${flags} -march=armv8-a+sve)
+        list(APPEND ${flags} -march=armv8-a+sve+fp16)
     endif()
     if(NOT CMAKE_CL_64)
         list(APPEND ${flags} -ftree-vectorize)
cmake/developer_package/features.cmake

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ ov_dependent_option (ENABLE_AVX512F "Enable AVX512 optimizations" ON "X86_64 OR
 
 ov_dependent_option (ENABLE_NEON_FP16 "Enable ARM FP16 optimizations" ON "AARCH64" OFF)
 
-ov_dependent_option (ENABLE_SVE "Enable SVE optimizations" ON "AARCH64" OFF)
+ov_dependent_option (ENABLE_SVE "Enable SVE optimizations" ON "AARCH64 AND NOT APPLE" OFF)
 
 # Type of build, we add this as an explicit option to default it to ON
 get_property(BUILD_SHARED_LIBS_DEFAULT GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS)
@@ -115,7 +115,7 @@ if(ENABLE_AVX512F)
 endif()
 
 if(ENABLE_SVE)
-    ov_check_compiler_supports_sve("-march=armv8-a+sve")
+    ov_check_compiler_supports_sve("-march=armv8-a+sve+fp16")
 
     if(NOT CXX_HAS_SVE)
         set(ENABLE_SVE OFF CACHE BOOL "Enables ARM64 SVE support" FORCE)
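
Because ENABLE_SVE is now forced off on Apple AArch64 (Apple silicon has no SVE), code that would otherwise use the SVE helpers introduced below still needs a portable path. A minimal sketch of such a fallback (an assumption for illustration, not code from this commit):

#include <cstddef>

#include "openvino/core/type/float16.hpp"

// Scalar f16 -> f32 widening loop, mirroring what sve_utils::cvt_copy does on
// SVE-capable parts; suitable where ENABLE_SVE is OFF.
static void cvt_copy_scalar(float* dst, const ov::float16* src, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = static_cast<float>(src[i]);
    }
}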
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
// Copyright (C) 2024 FUJITSU LIMITED
// SPDX-License-Identifier: Apache-2.0
//
#include <arm_sve.h>

#include <cstddef>
#include <type_traits>

#include "openvino/core/type/float16.hpp"

namespace ov::intel_cpu::sve_utils {

template <typename T, typename... Args>
constexpr bool one_of(T val, Args... args) {
    return ((val == args) || ...);
}

// All-true predicate for elements of T_SIZE bits (8/16/32/64).
template <size_t T_SIZE>
svbool_t sve_predicate() {
    static_assert(one_of(T_SIZE, 8, 16, 32, 64), "Unexpected parameter size");
    if constexpr (8 == T_SIZE) {
        return svptrue_b8();
    } else if (16 == T_SIZE) {
        return svptrue_b16();
    } else if (32 == T_SIZE) {
        return svptrue_b32();
    } else if (64 == T_SIZE) {
        return svptrue_b64();
    }
}

// Tail predicate: lanes in [lower, higher) are active, for elements of T_SIZE bits.
template <typename T_TYPE, size_t T_SIZE>
svbool_t sve_predicate(T_TYPE lower, T_TYPE higher) {
    static_assert(one_of(T_SIZE, 8, 16, 32, 64), "Unexpected parameter size");
    if constexpr (8 == T_SIZE) {
        return svwhilelt_b8(lower, higher);
    } else if (16 == T_SIZE) {
        return svwhilelt_b16(lower, higher);
    } else if (32 == T_SIZE) {
        return svwhilelt_b32(lower, higher);
    } else if (64 == T_SIZE) {
        return svwhilelt_b64(lower, higher);
    }
}

// Number of lanes per SVE vector for elements of T_SIZE bits.
template <size_t T_SIZE>
size_t sve_vlen() {
    static_assert(one_of(T_SIZE, 8, 16, 32, 64), "Unexpected parameter size");
    if constexpr (8 == T_SIZE) {
        return svcntb();
    } else if (16 == T_SIZE) {
        return svcnth();
    } else if (32 == T_SIZE) {
        return svcntw();
    } else if (64 == T_SIZE) {
        return svcntd();
    }
}

// Copy n elements from src to dst; converts f16 -> f32 when TA is float and TB is ov::float16.
template <typename TA, typename TB>
static void cvt_copy(TA* dst, TB* src, size_t n) {
    size_t i = 0;
    if constexpr (std::is_same<TA, TB>::value) {
        // Same-type copy: full vectors first, then a predicated tail for the remainder.
        // The predicate helpers take the element width in bits, hence sizeof(TA) * 8.
        auto pg_dst = sve_predicate<sizeof(TA) * 8>();
        auto vlen = sve_vlen<sizeof(TA) * 8>();
        for (; i + vlen <= n; i += vlen) {
            auto vb = svld1(pg_dst, src + i);
            svst1(pg_dst, dst + i, vb);
        }
        auto pgt = sve_predicate<size_t, sizeof(TA) * 8>(i, n);
        auto vb = svld1(pgt, src + i);
        svst1(pgt, dst + i, vb);
        return;
    } else if constexpr (std::is_same<TA, float>::value && std::is_same<TB, ov::float16>::value) {
        // f16 -> f32: load half a vector of f16, duplicate each lane with zip1, then
        // widen the even lanes to f32; only full half-vector iterations are handled here.
        auto src_ptr = reinterpret_cast<float16_t*>(src);
        auto pg_vl2 = svwhilelt_b16(svcnth() / 2, svcnth());
        auto vlen = svcnth() / 2;
        auto pg_dst = svptrue_b32();
        for (; i + vlen <= n; i += vlen) {
            auto load_src = svld1_f16(pg_vl2, src_ptr + i);
            auto src_interleave = svzip1_f16(load_src, load_src);
            auto cvt_dst = svcvt_f32_f16_z(pg_dst, src_interleave);
            svst1(pg_dst, dst + i, cvt_dst);
        }
    }
}

}  // namespace ov::intel_cpu::sve_utils
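
A hypothetical call site for the helper above (not part of the commit; names are placeholders, and the header is assumed to be on the include path):

// Widen a half-precision buffer to f32 using the SVE helper. The f16 -> f32 branch
// of cvt_copy only processes full half-vector (svcnth()/2) iterations, so n is
// assumed to be a multiple of that length here.
inline void widen_buffer(float* dst_f32, ov::float16* src_f16, size_t n) {
    ov::intel_cpu::sve_utils::cvt_copy(dst_f32, src_f16, n);
}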
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
// Copyright (C) 2025 FUJITSU LIMITED
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <arm_neon.h>
#include <kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h>
#include <kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <openvino/core/type/element_type.hpp>

namespace ov::intel_cpu {

// Thin wrapper around the KleidiAI f16 GEMM micro-kernel: C = clamp(A * B + bias).
class KleidiGemm {
public:
    KleidiGemm(size_t M, size_t N, size_t K, size_t lda, size_t ldb, size_t ldc);
    void executeGemm(const void* a, const void* b, void* c);
    void packB(const float16_t* inp, const float16_t* bias, float16_t* packed_out);
    const size_t get_packed_rhs_size() const;

private:
    // Function table of the selected ukernel variant (6x16 tile, 8-wide K step).
    static constexpr kai_matmul_clamp_f16_f16_f16p_ukernel ukernel{
        kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla};
    size_t M, N, K;
    size_t lda, ldb, ldc;
    size_t nr, kr, sr;
    size_t packedRHSsize;
};

KleidiGemm::KleidiGemm(size_t _M, size_t _N, size_t _K, size_t _lda, size_t _ldb, size_t _ldc)
    : M(_M),
      N(_N),
      K(_K),
      lda(_lda),
      ldb(_ldb),
      ldc(_ldc),
      nr(ukernel.get_nr()),
      kr(ukernel.get_kr()),
      sr(ukernel.get_sr()),
      packedRHSsize(kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(_N, _K)) {}

const size_t KleidiGemm::get_packed_rhs_size() const {
    return packedRHSsize;
}

void KleidiGemm::packB(const float16_t* inp, const float16_t* bias, float16_t* packed_out) {
    // Packing only needs to be performed once if the contents of the bias and RHS matrices are constant.
    kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(1,
                                                      N,
                                                      K,
                                                      nr,
                                                      kr,
                                                      sr,                       // Packing arguments
                                                      ldb * sizeof(float16_t),  // RHS stride
                                                      inp,                      // RHS
                                                      bias,                     // Bias
                                                      NULL,                     // Scale
                                                      packed_out,               // RHS packed
                                                      0,
                                                      NULL);
}

void KleidiGemm::executeGemm(const void* a, const void* b, void* c) {
    const size_t m_step = ukernel.get_m_step();
    const size_t n_step = ukernel.get_n_step();
    for (size_t i_m_step = 0; i_m_step < M; i_m_step += m_step) {
        for (size_t i_n_step = 0; i_n_step < N; i_n_step += n_step) {
            // Byte offsets of the current tile in the LHS, packed RHS and DST buffers.
            const uint8_t* lhs_ptr =
                static_cast<const uint8_t*>(a) + (ukernel.get_lhs_packed_offset(i_m_step, lda * sizeof(float16_t)));
            const uint8_t* rhs_ptr = static_cast<const uint8_t*>(b) + (ukernel.get_rhs_packed_offset(i_n_step, K));
            uint8_t* dst_ptr =
                static_cast<uint8_t*>(c) + (ukernel.get_dst_offset(i_m_step, i_n_step, ldc * sizeof(float16_t)));
            const size_t actual_m = std::min(M - i_m_step, m_step);
            const size_t actual_n = std::min(N - i_n_step, n_step);

            ukernel.run_matmul(actual_m,
                               actual_n,
                               K,                        // Dimensions
                               lhs_ptr,                  // LHS
                               lda * sizeof(float16_t),  // LHS stride
                               rhs_ptr,                  // RHS packed
                               dst_ptr,                  // DST
                               ldc * sizeof(float16_t),  // DST stride (row)
                               sizeof(float16_t),        // DST stride (col)
                               -std::numeric_limits<float>::max(),
                               std::numeric_limits<float>::max()  // Min and max for the clamp operation
            );
        }
    }
}

}  // namespace ov::intel_cpu
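
A hypothetical usage sketch for the wrapper above (not part of the commit; shapes, strides, and buffer management are assumptions, and the header above plus <arm_neon.h> are assumed to be included):

#include <cstdint>
#include <vector>

// One f16 GEMM, C = clamp(A * B + bias), with A of shape MxK, B of shape KxN
// (both row-major) and C of shape MxN. B and bias are packed once and reused.
void run_f16_gemm(const float16_t* A, const float16_t* B, const float16_t* bias,
                  float16_t* C, size_t M, size_t N, size_t K) {
    ov::intel_cpu::KleidiGemm gemm(M, N, K, /*lda=*/K, /*ldb=*/N, /*ldc=*/N);
    std::vector<uint8_t> packed_b(gemm.get_packed_rhs_size());
    gemm.packB(B, bias, reinterpret_cast<float16_t*>(packed_b.data()));
    gemm.executeGemm(A, packed_b.data(), C);
}

Packing the RHS and bias up front is what lets executeGemm walk the output tile by tile with only offset arithmetic per step, which is the intended usage pattern of the KleidiAI packed-RHS micro-kernels.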
