1+ #include < arm_neon.h>
2+ #include < kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h>
3+ #include < kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h>
4+ #include < kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h>
5+
6+ #include < cfloat>
7+ #include < openvino/core/type/element_type.hpp>
8+
9+ namespace ov ::intel_cpu {
10+
11+ class KleidiKernel {
12+ public:
13+ KleidiKernel (size_t M, size_t N, size_t K, size_t lda, size_t ldb, size_t ldc);
14+ void executeGemm (void * a, void * b, void * c);
15+ void packB (float16_t * inp, float16_t * packed_out, float16_t * bias);
16+ const size_t get_packed_rhs_size () const ;
17+
18+ private:
19+ static constexpr kai_matmul_clamp_f16_f16_f16p_ukernel ukernel{
20+ kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
21+ kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
22+ kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
23+ kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
24+ kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
25+ kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
26+ kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
27+ kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
28+ kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
29+ kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla};
30+ size_t M, N, K;
31+ size_t lda, ldb, ldc;
32+ size_t nr, kr, sr;
33+ size_t packedRHSsize;
34+ };
35+
36+ KleidiKernel::KleidiKernel (size_t _M, size_t _N, size_t _K, size_t _lda, size_t _ldb, size_t _ldc)
37+ : M(_M),
38+ N (_N),
39+ K(_K),
40+ lda(_lda),
41+ ldb(_ldb),
42+ ldc(_ldc),
43+ nr(ukernel.get_nr()),
44+ kr(ukernel.get_kr()),
45+ sr(ukernel.get_sr()),
46+ packedRHSsize(kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(_N, _K)){};
47+
48+ const size_t KleidiKernel::get_packed_rhs_size () const {
49+ return packedRHSsize;
50+ }
51+
52+ void KleidiKernel::packB (float16_t * inp, float16_t * packed_out, float16_t * bias) {
53+ // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant.
54+ kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon (1 ,
55+ N,
56+ K,
57+ nr,
58+ kr,
59+ sr, // Packing arguments
60+ ldb * sizeof (float16_t ), // RHS stride
61+ inp, // RHS
62+ bias, // Bias
63+ NULL , // Scale
64+ packed_out, // RHS packed
65+ 0 ,
66+ NULL );
67+ }
68+
69+ void KleidiKernel::executeGemm (void * a, void * b, void * c) {
70+ const size_t m_step = ukernel.get_m_step ();
71+ const size_t n_step = ukernel.get_n_step ();
72+ for (size_t i_m_step = 0 ; i_m_step < M; i_m_step += m_step) {
73+ for (size_t i_n_step = 0 ; i_n_step < N; i_n_step += n_step) {
74+ const uint8_t * lhs_ptr =
75+ (const uint8_t *)a + (ukernel.get_lhs_packed_offset (i_m_step, lda * sizeof (uint16_t )));
76+ const uint8_t * rhs_ptr = (const uint8_t *)b + (ukernel.get_rhs_packed_offset (i_n_step, K));
77+ uint8_t * dst_ptr = (uint8_t *)c + (ukernel.get_dst_offset (i_m_step, i_n_step, ldc * sizeof (uint16_t )));
78+ const size_t actual_m = std::min (M - i_m_step, m_step);
79+ const size_t actual_n = std::min (N - i_n_step, n_step);
80+
81+ ukernel.run_matmul (actual_m,
82+ actual_n,
83+ K, // Dimensions
84+ lhs_ptr, // LHS
85+ lda * sizeof (float16_t ), // LHS stride
86+ rhs_ptr, // RHS packed
87+ dst_ptr, // DST
88+ ldc * sizeof (float16_t ), // DST stride (row)
89+ sizeof (float16_t ), // DST stride (col)
90+ -FLT_MAX,
91+ FLT_MAX // Min and max for the clamp operation
92+ );
93+ }
94+ }
95+ }
96+
97+ } // namespace ov::intel_cpu