File tree Expand file tree Collapse file tree 6 files changed +1015
-28
lines changed
onnxruntime/core/mlas/lib Expand file tree Collapse file tree 6 files changed +1015
-28
lines changed Original file line number Diff line number Diff line change @@ -101,6 +101,8 @@ function(setup_mlas_source_for_windows)
101101 ${MLAS_SRC_DIR} /qnbitgemm_kernel_neon.cpp
102102 ${MLAS_SRC_DIR} /sqnbitgemm_kernel_neon_fp32.cpp
103103 ${MLAS_SRC_DIR} /sqnbitgemm_kernel_neon_int8.cpp
104+ ${MLAS_SRC_DIR} /sqnbitgemm_lut_kernel_neon.h
105+ ${MLAS_SRC_DIR} /sqnbitgemm_lut_kernel_neon.cpp
104106 ${MLAS_SRC_DIR} /cast_kernel_neon.cpp
105107 ${MLAS_SRC_DIR} /hqnbitgemm_kernel_neon_fp16.cpp
106108 ${MLAS_SRC_DIR} /rotary_embedding_kernel_neon.h
@@ -470,6 +472,8 @@ else()
470472 ${MLAS_SRC_DIR} /qnbitgemm_kernel_neon.cpp
471473 ${MLAS_SRC_DIR} /sqnbitgemm_kernel_neon_fp32.cpp
472474 ${MLAS_SRC_DIR} /sqnbitgemm_kernel_neon_int8.cpp
475+ ${MLAS_SRC_DIR} /sqnbitgemm_lut_kernel_neon.h
476+ ${MLAS_SRC_DIR} /sqnbitgemm_lut_kernel_neon.cpp
473477 ${MLAS_SRC_DIR} /rotary_embedding_kernel_neon.h
474478 ${MLAS_SRC_DIR} /rotary_embedding_kernel_neon.cpp
475479 ${MLAS_SRC_DIR} /hgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change @@ -1245,6 +1245,10 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH;
12451245
12461246extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2;
12471247
1248+ #if defined(MLAS_TARGET_ARM64)
1249+ extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon;
1250+ #endif
1251+
12481252//
12491253// Rotary embedding dispatch structure.
12501254//
Original file line number Diff line number Diff line change @@ -654,6 +654,9 @@ Return Value:
654654 this ->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? false : true ;
655655 this ->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon (HasDotProductInstructions, HasI8MMInstructions);
656656
657+ // Enable LUT-based GEMM for 2-bit quantization on ARM64
658+ this ->LutGenKernel = &MlasLutGenKernelNeon;
659+
657660#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
658661 this ->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon;
659662 this ->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon;
Original file line number Diff line number Diff line change @@ -10,34 +10,17 @@ Module Name:
1010
1111Abstract:
1212
13- This module implements x64 AVX2 kernel functions for LUT-based n-bit
14- quantized integer matrix multiplication.
13+ This module contains the dispatch table declaration for x64 AVX2
14+ LUT-based n-bit quantized integer matrix multiplication kernels.
15+
1516--*/
1617
1718#pragma once
18- #include "qnbitgemm.h"
19-
20- void
21- GenerateLUT_avx2 (
22- int32_t group_size ,
23- int8_t lut ,
24- const float * b ,
25- float * scales ,
26- float * biases ,
27- int K
28- );
29-
30- void
31- TMACComputeGemm_avx2 (
32- const void * A ,
33- const void * a_scales ,
34- const void * LUT ,
35- const void * LUT_Scales ,
36- const void * LUT_Biases ,
37- void * C ,
38- int bm ,
39- int K ,
40- int M ,
41- int N ,
42- size_t BlkLen
43- );
19+
20+ #include "qlutgemm.h"
21+
22+ //
23+ // External dispatch table for AVX2 LUT GEMM kernels.
24+ // Kernel functions are internal to the .cpp file and accessed via this dispatch.
25+ //
26+ extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2 ;
You can’t perform that action at this time.
0 commit comments