Skip to content

Commit 51617ca

Browse files
Integrate KleidiAI BF16 SME2 Kernel Through Mlas SBGEMM Path
Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com>
1 parent f98c756 commit 51617ca

File tree

11 files changed

+496
-0
lines changed

11 files changed

+496
-0
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ function(setup_kleidiai)
279279
target_sources(onnxruntime_mlas PRIVATE
280280
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
281281
${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp
282+
${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp
282283
${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp
283284
${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp
284285
)

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,6 +1955,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
19551955
const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
19561956
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
19571957
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
1958+
bool BIsPacked = false; /**< Whether B is pre-packed */
19581959
};
19591960

19601961
/**

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
1313
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
1414

15+
#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
16+
1517
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
1618
{kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
1719
kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -64,6 +66,19 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp
6466
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
6567
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};
6668

69+
const kai_matmul_clamp_f32_bf16p_bf16p_ukernel sbgemm_gemm_sme2 =
70+
{kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
71+
kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
72+
kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
73+
kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
74+
kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
75+
kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
76+
kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
77+
kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
78+
kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
79+
kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
80+
kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa};
81+
6782
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
6883
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
6984
return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
@@ -79,3 +94,8 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
7994
return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
8095
}
8196
}
97+
98+
const kai_matmul_clamp_f32_bf16p_bf16p_ukernel& GetKleidiAISBGemmUKernel() {
99+
// Currently only an SME2 variant exists for the bfloat16/SBGEMM kernel, so it is returned unconditionally
100+
return sbgemm_gemm_sme2;
101+
}

onnxruntime/core/mlas/lib/kai_ukernel_interface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,9 @@
88

99
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
1010

11+
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
12+
1113
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
1214
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
15+
16+
const kai_matmul_clamp_f32_bf16p_bf16p_ukernel& GetKleidiAISBGemmUKernel();

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,36 @@ MlasGemmBatch(
9090
MLAS_THREADPOOL* ThreadPool
9191
);
9292

93+
#if defined(__aarch64__) && defined(__linux__)
94+
size_t
95+
MLASCALL
96+
MlasSBGemmPackBSize(
97+
size_t N,
98+
size_t K
99+
);
100+
101+
bool
102+
MLASCALL
103+
MlasSBGemmPackB(
104+
size_t N,
105+
size_t K,
106+
const float* B,
107+
size_t ldb,
108+
void* PackedB
109+
);
110+
111+
bool
112+
MLASCALL
113+
MlasSBGemmBatch(
114+
size_t M,
115+
size_t N,
116+
size_t K,
117+
const MLAS_SBGEMM_DATA_PARAMS* Data,
118+
size_t BatchSize,
119+
MLAS_THREADPOOL* ThreadPool
120+
);
121+
#endif
122+
93123
size_t
94124
MLASCALL
95125
MlasDynamicQgemmPackBSize(

0 commit comments

Comments
 (0)