Skip to content

Commit 76951ba

Browse files
committed
Init neon kernel for lut 2 bit gemm
1 parent a3ef325 commit 76951ba

File tree

6 files changed

+1015
-28
lines changed

6 files changed

+1015
-28
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ function(setup_mlas_source_for_windows)
101101
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
102102
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
103103
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
104+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h
105+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp
104106
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
105107
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
106108
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
@@ -470,6 +472,8 @@ else()
470472
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
471473
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
472474
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
475+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h
476+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp
473477
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
474478
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
475479
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,10 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH;
12451245

12461246
extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2;
12471247

1248+
#if defined(MLAS_TARGET_ARM64)
1249+
extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon;
1250+
#endif
1251+
12481252
//
12491253
// Rotary embedding dispatch structure.
12501254
//

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,9 @@ Return Value:
654654
this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? false : true;
655655
this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions);
656656

657+
// Enable LUT-based GEMM for 2-bit quantization on ARM64
658+
this->LutGenKernel = &MlasLutGenKernelNeon;
659+
657660
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
658661
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon;
659662
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon;

onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,17 @@ Module Name:
1010
1111
Abstract:
1212
13-
This module implements x64 AVX2 kernel functions for LUT-based n-bit
14-
quantized integer matrix multiplication.
13+
This module contains the dispatch table declaration for x64 AVX2
14+
LUT-based n-bit quantized integer matrix multiplication kernels.
15+
1516
--*/
1617

1718
#pragma once
18-
#include "qnbitgemm.h"
19-
20-
void
21-
GenerateLUT_avx2(
22-
int32_t group_size,
23-
int8_t lut,
24-
const float* b,
25-
float* scales,
26-
float* biases,
27-
int K
28-
);
29-
30-
void
31-
TMACComputeGemm_avx2(
32-
const void* A,
33-
const void* a_scales,
34-
const void* LUT,
35-
const void* LUT_Scales,
36-
const void* LUT_Biases,
37-
void* C,
38-
int bm,
39-
int K,
40-
int M,
41-
int N,
42-
size_t BlkLen
43-
);
19+
20+
#include "qlutgemm.h"
21+
22+
//
23+
// External dispatch table for AVX2 LUT GEMM kernels.
24+
// Kernel functions are internal to the .cpp file and accessed via this dispatch.
25+
//
26+
extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2;

0 commit comments

Comments
 (0)