Skip to content

Commit 92c1ed2

Browse files
JonathanC-ARMhariharans29Copilot
authored
Implement FP32 kleidiai Gemv (#26302)
### Description Implementation of a special SGEMM path which uses GEMV kernels in cases where M or N is 1. Additionally, this PR introduces the usage of a microkernel interface which utilizes typedefs provided by KleidiAI, so that we can simplify the code and remove things such as ternary operations for SME1 vs SME2 kernels. ### Indicative Performance In lieu of any production models in which GEMV is a large contributor to the network, I opted to create a mini model for testing which contains thousands of randomized matmul variants, with GEMV cases distributed throughout. <img width="1572" height="148" alt="image (6)" src="https://github.com/user-attachments/assets/451441e4-df5b-42d1-8c6e-ec8dd14161e6" /> Using the onnxruntime perf test I was able to halve the total inference time vs MLAS with this model. <img width="1200" height="900" alt="ort_ops_compare_gemv_no_2025-10-07_19-40-30_vs_gemv_2025-10-07_19-40-58" src="https://github.com/user-attachments/assets/ddef3bf3-796c-4f58-8712-361510e2a901" /> **_More benchmarks to come shortly_** --------- Signed-off-by: Jonathan Clohessy <Jonathan.Clohessy@arm.com> Signed-off-by: Jonathan Clohessy <jonathan.clohessy@arm.com> Co-authored-by: Hariharan Seshadri <shariharan91@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent fa70b18 commit 92c1ed2

File tree

5 files changed

+345
-67
lines changed

5 files changed

+345
-67
lines changed

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,20 @@
66

77
#include "kai_ukernel_interface.h"
88
#include "mlasi.h"
9+
#include "kleidiai/mlasi_kleidiai.h"
910

1011
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
1112
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
1213
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
1314
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
1415

16+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
17+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
18+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
19+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
20+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
21+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
22+
1523
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
1624
{kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
1725
kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -64,6 +72,56 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp
6472
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
6573
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};
6674

75+
// KleidiAI micro-kernel table for the FP32 GEMV path, SME variant.
// Aggregate-initialised with the entry points of the
// matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla kernel, in this order:
// m_step, n_step, nr, kr, sr, lhs_offset, rhs_packed_offset, dst_offset,
// dst_size, run. Returned by GetKleidiAISGemvUKernel() when SME2 is not
// reported by the CPU.
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme =
76+
{kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
77+
kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
78+
kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
79+
kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
80+
kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
81+
kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
82+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
83+
kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
84+
kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
85+
kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla};
86+
87+
// KleidiAI micro-kernel table for the FP32 GEMV path, SME2 variant.
// Aggregate-initialised with the entry points of the
// matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla kernel, in this order:
// m_step, n_step, nr, kr, sr, lhs_offset, rhs_packed_offset, dst_offset,
// dst_size, run. Returned by GetKleidiAISGemvUKernel() when the CPU
// reports SME2.
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme2 =
88+
{kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
89+
kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
90+
kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
91+
kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
92+
kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
93+
kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
94+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
95+
kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
96+
kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
97+
kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla};
98+
99+
// KleidiAI micro-kernel table for the FP32 GEMM path, SME variant.
// Aggregate-initialised with the entry points of the
// matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa kernel, in this
// order: m_step, n_step, mr, nr, kr, sr, lhs_packed_offset,
// rhs_packed_offset, dst_offset, dst_size, run. Unlike the GEMV tables this
// interface also carries `mr` and uses a packed-LHS offset getter.
// Returned by GetKleidiAISGemmUKernel() when SME2 is not reported by the CPU.
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme =
100+
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
101+
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
102+
kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
103+
kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
104+
kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
105+
kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
106+
kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
107+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
108+
kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
109+
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
110+
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa};
111+
112+
// KleidiAI micro-kernel table for the FP32 GEMM path, SME2 variant.
// Aggregate-initialised with the entry points of the
// matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa kernel, in this
// order: m_step, n_step, mr, nr, kr, sr, lhs_packed_offset,
// rhs_packed_offset, dst_offset, dst_size, run.
// Returned by GetKleidiAISGemmUKernel() when the CPU reports SME2.
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 =
113+
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
114+
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
115+
kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
116+
kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
117+
kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
118+
kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
119+
kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
120+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
121+
kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
122+
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
123+
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa};
124+
67125
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
68126
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
69127
return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
@@ -79,3 +137,19 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
79137
return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
80138
}
81139
}
140+
141+
// Returns the FP32 GEMM micro-kernel table matching the running CPU:
// sgemm_gemm_sme2 when MLAS_CPUIDINFO reports SME2, otherwise sgemm_gemm_sme.
// The returned reference is to a namespace-scope const object defined in
// this file.
// NOTE(review): the fallback branch does not itself test for SME support;
// presumably callers check SME availability before using this path - confirm.
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel() {
142+
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
143+
return sgemm_gemm_sme2;
144+
} else {
145+
return sgemm_gemm_sme;
146+
}
147+
}
148+
149+
// Returns the FP32 GEMV micro-kernel table matching the running CPU:
// sgemm_gemv_sme2 when MLAS_CPUIDINFO reports SME2, otherwise sgemm_gemv_sme.
// The returned reference is to a namespace-scope const object defined in
// this file.
// NOTE(review): the fallback branch does not itself test for SME support;
// presumably callers check SME availability before using this path - confirm.
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() {
150+
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
151+
return sgemm_gemv_sme2;
152+
} else {
153+
return sgemm_gemv_sme;
154+
}
155+
}

onnxruntime/core/mlas/lib/kai_ukernel_interface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@
88

99
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
1010

11+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h"
12+
13+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"
14+
1115
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
1216
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
17+
18+
// Accessors for the FP32 (SGEMM) KleidiAI micro-kernel tables.
// The GEMM accessor returns the f32p/f32p interface table, while the GEMV
// accessor returns the f32/f32p interface table.
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel();
19+
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel();

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ MlasGemmPackB(
7777
void* PackedB
7878
);
7979

80+
// Batched FP32 GEMV entry point of the KleidiAI backend (declaration only;
// the implementation is not visible here).
// NOTE(review): presumably returns true when the KleidiAI GEMV path handled
// the request and false when the caller should fall back - confirm against
// the definition.
// NOTE(review): M/N/K are presumably the GEMM dimensions, with this path
// intended for M == 1 or N == 1; Data presumably points to BatchSize
// parameter blocks - confirm.
bool
81+
MLASCALL
82+
MlasGemvBatch(
83+
CBLAS_TRANSPOSE TransA,
84+
CBLAS_TRANSPOSE TransB,
85+
size_t M,
86+
size_t N,
87+
size_t K,
88+
const MLAS_SGEMM_DATA_PARAMS* Data,
89+
size_t BatchSize
90+
);
91+
92+
8093
bool
8194
MLASCALL
8295
MlasGemmBatch(

0 commit comments

Comments
 (0)