
Commit 03b4c51

Aarch64-pa-f16-Kleidi
1 parent bdd1a28 commit 03b4c51

File tree

8 files changed: +575 −8 lines changed


cmake/developer_package/compile_flags/functions.cmake

Lines changed: 2 additions & 2 deletions
@@ -283,7 +283,7 @@ macro(ov_arm_sve_optimization_flags flags)
     endif()

     # Check for compiler SVE support
-    ov_check_compiler_supports_sve("-march=armv8-a+sve")
+    ov_check_compiler_supports_sve("-march=armv8-a+sve+fp16")
     if(OV_COMPILER_IS_INTEL_LLVM)
         message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
     elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
@@ -305,7 +305,7 @@ macro(ov_arm_sve_optimization_flags flags)

     # Add flag for SVE if supported
     if(CXX_SVE_FOUND)
-        list(APPEND ${flags} -march=armv8-a+sve)
+        list(APPEND ${flags} -march=armv8-a+sve+fp16)
     endif()
     if(NOT CMAKE_CL_64)
         list(APPEND ${flags} -ftree-vectorize)

cmake/developer_package/features.cmake

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ if(ENABLE_AVX512F)
 endif()

 if(ENABLE_SVE)
-    ov_check_compiler_supports_sve("-march=armv8-a+sve")
+    ov_check_compiler_supports_sve("-march=armv8-a+sve+fp16")

     if(NOT CXX_HAS_SVE)
         set(ENABLE_SVE OFF CACHE BOOL "Enables ARM64 SVE support" FORCE)
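
Note (not part of the commit): code compiled with -march=armv8-a+sve+fp16 can rely on the corresponding ACLE feature macros, so an SVE f16 kernel is typically guarded as in this minimal sketch:

    // Minimal sketch, assuming the build flags added above.
    #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    #    include <arm_sve.h>
    // SVE kernels operating on svfloat16_t can be compiled in this branch.
    #endif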

src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp

Lines changed: 6 additions & 0 deletions
@@ -74,10 +74,16 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co
     const auto pa_op = m.get_match_root();
     auto key_cache = ov::as_type_ptr<ov::op::v0::Parameter>(pa_op->get_input_node_shared_ptr(3));
     auto value_cache = ov::as_type_ptr<ov::op::v0::Parameter>(pa_op->get_input_node_shared_ptr(4));
+#if defined(OPENVINO_ARCH_ARM64)
+    auto format_cache_precision = [](ov::element::Type cache_precision, ov::element::Type infer_precision) {
+        return ov::element::u8;
+    };
+#else
     auto format_cache_precision = [](ov::element::Type cache_precision, ov::element::Type infer_precision) {
         return cache_precision == ov::element::f16 && infer_precision == ov::element::bf16 ? infer_precision
                                                                                            : cache_precision;
     };
+#endif
     auto init_cache_shape = [&](const size_t head_nums,
                                 const size_t head_size,
                                 const size_t block_size,
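
Effect of this change, shown as an illustrative example (the calls below are not from the commit): on ARM64 the PagedAttention KV cache precision is forced to u8 regardless of the configured value, while other targets keep the existing rule of promoting an f16 cache to bf16 only when inference runs in bf16.

    // Illustrative results of format_cache_precision after this change:
    // ARM64:  format_cache_precision(ov::element::f16, ov::element::f32)  -> ov::element::u8
    // Others: format_cache_precision(ov::element::f16, ov::element::bf16) -> ov::element::bf16
    //         format_cache_precision(ov::element::f16, ov::element::f32)  -> ov::element::f16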
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
// Copyright (C) 2024 FUJITSU LIMITED
// SPDX-License-Identifier: Apache-2.0
//
#include <arm_sve.h>

#include "openvino/core/type/float16.hpp"

#define SIZE_IN_BITS(t_var) sizeof(t_var) * 8
#define __ce(expr, bits, ...)     \
    if constexpr (expr == bits) { \
        __VA_ARGS__               \
    }

#define SVE_PREDICATE(var, t_var)                        \
    svbool_t var;                                        \
                                                         \
    __ce(SIZE_IN_BITS(t_var), 8, var = svptrue_b8();)    \
    __ce(SIZE_IN_BITS(t_var), 16, var = svptrue_b16();)  \
    __ce(SIZE_IN_BITS(t_var), 32, var = svptrue_b32();)  \
    __ce(SIZE_IN_BITS(t_var), 64, var = svptrue_b64();)

#define SVE_VLEN(var, t_var)                             \
    size_t var;                                          \
                                                         \
    __ce(SIZE_IN_BITS(t_var), 8, var = svcntb();)        \
    __ce(SIZE_IN_BITS(t_var), 16, var = svcnth();)       \
    __ce(SIZE_IN_BITS(t_var), 32, var = svcntw();)       \
    __ce(SIZE_IN_BITS(t_var), 64, var = svcntd();)

#define SVE_PREDICATE_WHILELT(var, t_var, arg1, arg2)               \
    svbool_t var;                                                   \
                                                                    \
    __ce(SIZE_IN_BITS(t_var), 8, var = svwhilelt_b8(arg1, arg2);)   \
    __ce(SIZE_IN_BITS(t_var), 16, var = svwhilelt_b16(arg1, arg2);) \
    __ce(SIZE_IN_BITS(t_var), 32, var = svwhilelt_b32(arg1, arg2);) \
    __ce(SIZE_IN_BITS(t_var), 64, var = svwhilelt_b64(arg1, arg2);)

namespace ov::Extensions::Cpu::XARCH {
static void cvt_copy(float* dst, ov::float16* src, size_t n) {
    auto src_ptr = reinterpret_cast<float16_t*>(src);
    // Half-length predicate: load svcnth()/2 f16 elements per iteration,
    // i.e. exactly as many as fit one f32 vector after widening.
    auto pg_vl2 = svwhilelt_b16(svcnth() / 2, svcnth());
    auto vlen = svcnth() / 2;
    auto pg_dst = svptrue_b32();
    size_t i = 0;
    for (; i + vlen <= n; i += vlen) {
        auto load_src = svld1_f16(pg_vl2, src_ptr + i);
        // Duplicate each lane (a0,a0,a1,a1,...) so every 32-bit container holds
        // the source f16 in its low halfword, where svcvt_f32_f16 reads it.
        auto src_interleave = svzip1_f16(load_src, load_src);
        auto cvt_dst = svcvt_f32_f16_z(pg_dst, src_interleave);
        svst1(pg_dst, dst + i, cvt_dst);
    }
    // Scalar tail for the remaining n % vlen elements.
    for (; i < n; i++) {
        dst[i] = src[i];
    }
}
}  // namespace ov::Extensions::Cpu::XARCH
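
The SIZE_IN_BITS/__ce helpers above select an SVE predicate or vector-length intrinsic from the element width at compile time. A minimal expansion sketch (not part of the commit; the macros are presumably consumed by other kernels in this change):

    // Sketch: dispatch for a 16-bit element type.
    void example() {
        float16_t x;
        SVE_PREDICATE(pg, x);  // svbool_t pg; if constexpr (16 == 16) { pg = svptrue_b16(); } ...
        SVE_VLEN(vl, x);       // size_t vl; vl = svcnth();
    }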
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
#include <arm_neon.h>
#include <kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h>
#include <kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h>

#include <algorithm>  // std::min
#include <cfloat>
#include <cstdint>    // uint8_t
#include <openvino/core/type/element_type.hpp>

namespace ov::intel_cpu {

class KleidiKernel {
public:
    KleidiKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb, size_t ldc);
    void executeGemm(void* a, void* b, void* c);
    void packB(float16_t* inp, float16_t* packed_out, float16_t* bias);
    const size_t get_packed_rhs_size() const;

private:
    // Dispatch table binding the KleidiAI f16 GEMM micro-kernel entry points.
    static constexpr kai_matmul_clamp_f16_f16_f16p_ukernel ukernel{
        kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
        kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla};
    size_t M, N, K;
    size_t lda, ldb, ldc;
    size_t nr, kr, sr;
    size_t packedRHSsize;
};

KleidiKernel::KleidiKernel(size_t _M, size_t _N, size_t _K, size_t _lda, size_t _ldb, size_t _ldc)
    : M(_M),
      N(_N),
      K(_K),
      lda(_lda),
      ldb(_ldb),
      ldc(_ldc),
      nr(ukernel.get_nr()),
      kr(ukernel.get_kr()),
      sr(ukernel.get_sr()),
      packedRHSsize(kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(_N, _K)) {}

const size_t KleidiKernel::get_packed_rhs_size() const {
    return packedRHSsize;
}

void KleidiKernel::packB(float16_t* inp, float16_t* packed_out, float16_t* bias) {
    // Packing only needs to be performed once if the contents of the bias and RHS matrices are constant.
    kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(1,
                                                      N,
                                                      K,
                                                      nr,
                                                      kr,
                                                      sr,                       // Packing arguments
                                                      ldb * sizeof(float16_t),  // RHS stride
                                                      inp,                      // RHS
                                                      bias,                     // Bias
                                                      NULL,                     // Scale
                                                      packed_out,               // RHS packed
                                                      0,
                                                      NULL);
}

void KleidiKernel::executeGemm(void* a, void* b, void* c) {
    const size_t m_step = ukernel.get_m_step();
    const size_t n_step = ukernel.get_n_step();
    // Walk the output in m_step x n_step tiles; the offset helpers translate
    // tile coordinates into byte offsets within the LHS, packed RHS and DST buffers.
    for (size_t i_m_step = 0; i_m_step < M; i_m_step += m_step) {
        for (size_t i_n_step = 0; i_n_step < N; i_n_step += n_step) {
            const uint8_t* lhs_ptr =
                (const uint8_t*)a + (ukernel.get_lhs_packed_offset(i_m_step, lda * sizeof(uint16_t)));
            const uint8_t* rhs_ptr = (const uint8_t*)b + (ukernel.get_rhs_packed_offset(i_n_step, K));
            uint8_t* dst_ptr = (uint8_t*)c + (ukernel.get_dst_offset(i_m_step, i_n_step, ldc * sizeof(uint16_t)));
            const size_t actual_m = std::min(M - i_m_step, m_step);
            const size_t actual_n = std::min(N - i_n_step, n_step);

            ukernel.run_matmul(actual_m,
                               actual_n,
                               K,                        // Dimensions
                               lhs_ptr,                  // LHS
                               lda * sizeof(float16_t),  // LHS stride
                               rhs_ptr,                  // RHS packed
                               dst_ptr,                  // DST
                               ldc * sizeof(float16_t),  // DST stride (row)
                               sizeof(float16_t),        // DST stride (col)
                               -FLT_MAX,
                               FLT_MAX                   // Min and max for the clamp operation
            );
        }
    }
}

}  // namespace ov::intel_cpu
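
A typical call sequence for this class, as a sketch (buffer names and the surrounding allocation are assumptions, not from the commit): pack the constant RHS and bias once, then run the GEMM per call.

    // Sketch, assuming f16 buffers a (M x K), b (K x N), bias (N) and dst (M x N):
    KleidiKernel gemm(M, N, K, /*lda=*/K, /*ldb=*/N, /*ldc=*/N);
    std::vector<uint8_t> packed(gemm.get_packed_rhs_size());
    gemm.packB(b, reinterpret_cast<float16_t*>(packed.data()), bias);  // once; weights constant
    gemm.executeGemm(a, packed.data(), dst);                           // per inference call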
