
Commit 84e44b3

Author: Vrajang Parikh
Message: Cleanup: Improve naming, and file structure
Parent: bc085c5

9 files changed, +404 -410 lines changed


onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 3 additions & 3 deletions
@@ -1243,10 +1243,10 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx;
 
 struct MLAS_QNBIT_LUT_GEMM_DISPATCH;
 
-extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2;
+extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchAvx2;
 
 #if defined(MLAS_TARGET_ARM64)
-extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon;
+extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchNeon;
 #endif
 
 //
@@ -1457,7 +1457,7 @@ struct MLAS_PLATFORM {
     const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};
 
     const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr};
-    const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGenKernel{nullptr};
+    const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGemmDispatch{nullptr};
 
     MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
     MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
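Note: the rename follows the existing MLAS convention visible in this header — each kernel family is exposed as a const dispatch struct of function pointers per architecture, plus a nullable pointer slot on MLAS_PLATFORM. The snippet below is a self-contained sketch of that pattern only; the member names are taken from the dispatch assignments later in this commit, and the "Sketch" types and signatures are hypothetical, not the real MLAS_QNBIT_LUT_GEMM_DISPATCH declaration.

    // Self-contained sketch; "Sketch" names are illustrative, not MLAS declarations.
    #include <cstddef>
    #include <cstdint>

    struct LutGemmDispatchSketch {
        // One function-pointer slot per kernel entry point; unset slots stay null.
        void (*GenerateLUT)(const float* B, std::int8_t* QLut, float* LutScales,
                            float* LutBiases, std::size_t K) = nullptr;
        void (*ComputeGemm)(const std::uint8_t* PackedB, const float* Scales,
                            const std::int8_t* QLut, float* C,
                            std::size_t M, std::size_t N, std::size_t K) = nullptr;
    };

    // One const table per architecture, defined in the matching kernel source file.
    extern const LutGemmDispatchSketch LutGemmDispatchSketchAvx2;
    extern const LutGemmDispatchSketch LutGemmDispatchSketchNeon;

    struct PlatformSketch {
        // Null means "no LUT GEMM kernels for this CPU".
        const LutGemmDispatchSketch* LutGemmDispatch = nullptr;
    };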

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 2 additions & 2 deletions
@@ -422,7 +422,7 @@ Return Value:
     this->RopeDispatch = &MlasRopeDispatchAvx2;
 
     // TODO(vraspar): check if this really goes here or if there are other platform reqs that we need to fulfill
-    this->LutGenKernel = &MlasLutGenKernelAvx2;
+    this->LutGemmDispatch = &MlasLutGemmDispatchAvx2;
 
     //
     // Check if the processor supports Hybrid core architecture.
@@ -655,7 +655,7 @@ Return Value:
     this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions);
 
     // Enable LUT-based GEMM for 2-bit quantization on ARM64
-    this->LutGenKernel = &MlasLutGenKernelNeon;
+    this->LutGemmDispatch = &MlasLutGemmDispatchNeon;
 
 #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
     this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon;
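Both hunks wire the same selection step: during platform initialization the per-architecture dispatch table is installed once, and CPUs without a matching kernel set simply leave the pointer null. A minimal sketch of that idea follows, with stand-in feature probes and stand-in types rather than the real MLAS platform detection:

    // Stand-in feature probes; a real build would query CPUID / hwcaps instead.
    static bool CpuHasAvx2() { return false; }
    static bool CpuIsArm64() { return true; }

    struct DispatchSketch { int id; };

    static const DispatchSketch kAvx2DispatchSketch{1};
    static const DispatchSketch kNeonDispatchSketch{2};

    struct PlatformInitSketch {
        const DispatchSketch* LutGemmDispatch = nullptr;

        PlatformInitSketch() {
            if (CpuHasAvx2()) {
                LutGemmDispatch = &kAvx2DispatchSketch;   // x64 path in the hunk above
            } else if (CpuIsArm64()) {
                LutGemmDispatch = &kNeonDispatchSketch;   // ARM64 path in the hunk above
            }
            // Otherwise the pointer stays null and LUT GEMM is reported unavailable.
        }
    };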

onnxruntime/core/mlas/lib/qlutgemm.cpp

Lines changed: 5 additions & 5 deletions
@@ -191,7 +191,7 @@ LutGemmPackQuantBData(
     const size_t kfactor = tmac_params.kfactor;
 
     // LUT GEMM requires a valid LUT dispatch implementation, so dispatch must be available
-    const auto* Dispatch = GetMlasPlatform().LutGenKernel;
+    const auto* Dispatch = GetMlasPlatform().LutGemmDispatch;
     if (Dispatch == nullptr || Dispatch->PackQuantBData == nullptr) {
         MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires LUT GEMM dispatch support");
     }
@@ -240,9 +240,9 @@ LutPackScalesAndZeroPoints(
     const size_t bm = tmac_params.bm;
 
     // LUT GEMM is only available for AVX2, so dispatch must be available
-    const auto* Dispatch = GetMlasPlatform().LutGenKernel;
+    const auto* Dispatch = GetMlasPlatform().LutGemmDispatch;
     if (Dispatch == nullptr || Dispatch->PackScalesAndZeroPoints == nullptr) {
-        MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires AVX2 dispatch");
+        MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires LUT GEMM dispatch support");
     }
 
     Dispatch->PackScalesAndZeroPoints(
@@ -320,7 +320,7 @@ MlasIsLutGemmAvailable(
     size_t BlkLen
 )
 {
-    const auto* lut_kernel = GetMlasPlatform().LutGenKernel;
+    const auto* lut_kernel = GetMlasPlatform().LutGemmDispatch;
     if (lut_kernel == nullptr ||
         lut_kernel->GenerateLUT == nullptr ||
         lut_kernel->ComputeGemm == nullptr ||
@@ -392,7 +392,7 @@ MlasLutGemm(
 )
 {
     // adapted from ggml_backend_tmac_mul_mat
-    const auto* Dispatch = GetMlasPlatform().LutGenKernel;
+    const auto* Dispatch = GetMlasPlatform().LutGemmDispatch;
     // This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm()
     if (Dispatch == nullptr || Dispatch->GenerateLUT == nullptr || Dispatch->ComputeGemm == nullptr) {
         MLAS_THROW_EX(std::runtime_error, "TMAC not supported in this configuration");
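These hunks show the caller-side contract: MlasIsLutGemmAvailable() verifies that every required slot of the dispatch table is populated, and MlasLutGemm() re-checks before dereferencing. A simplified, self-contained sketch of that guard pattern (types, signatures, and names are illustrative, not the MLAS API):

    #include <stdexcept>

    struct LutDispatchSketch {
        void (*GenerateLUT)() = nullptr;   // signatures elided for brevity
        void (*ComputeGemm)() = nullptr;
    };

    // Usable only if every required slot is populated.
    bool IsLutGemmAvailableSketch(const LutDispatchSketch* d)
    {
        return d != nullptr && d->GenerateLUT != nullptr && d->ComputeGemm != nullptr;
    }

    // Callers should have checked availability first, but re-verify before
    // dereferencing the function pointers.
    void LutGemmSketch(const LutDispatchSketch* d)
    {
        if (!IsLutGemmAvailableSketch(d)) {
            throw std::runtime_error("LUT GEMM not supported in this configuration");
        }
        d->GenerateLUT();   // quantize the activations into a lookup table
        d->ComputeGemm();   // run the table-lookup GEMM on the packed weights
    }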

onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp

Lines changed: 27 additions & 19 deletions
@@ -54,6 +54,12 @@ _mm256_addv_ps(const __m256 v)
 #define extract_low_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_castsi256_si128(v))
 #define extract_high_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_extracti128_si256(v, 1))
 
+namespace lutgemm_avx2
+{
+
+namespace
+{
+
 // Template classes for accumulation
 template <int N>
 struct SignedHalvingAdder {
@@ -324,9 +330,11 @@ lut_ctor_g4_int8_impl(
     *lut_biases = biases;
 }
 
-// based on lut_ctor_g4_int8_impl
+} // namespace
+
+// LutGemmGenerateLUT_CompFp32 - Entry point for LUT generation
 void
-GenerateLUT_avx2(
+LutGemmGenerateLUT_CompFp32(
     const float* b,
     int8_t* qlut,
     float* lut_scales,
@@ -495,10 +503,9 @@ tbl_g4_int8_float_update_impl(int32_t m, float* c, const int8_t* lut, const uint
     return 0;
 }
 
-// based on qgemm_lut_int8_g4
-// Simplified version with hardcoded configuration for 2-bit quantization
+// LutGemmCompute_CompFp32 - Entry point for GEMM computation
 void
-TMACComputeGemm_avx2(
+LutGemmCompute_CompFp32(
     const uint8_t* A, // Quantized packed weights
     const float* Scales, // Weight scales (and optionally zero-points)
     const int8_t* LUT, // Pre-computed quantized lookup table
@@ -651,11 +658,11 @@ TMACComputeGemm_avx2(
 }
 
 //
-// AVX2 optimized weight packing for T-MAC LUT GEMM
+// LutGemmPackQuantBData_CompFp32 - AVX2 optimized weight packing for T-MAC LUT GEMM
 // This performs the same transformation as the scalar version but uses SIMD operations
 //
 void
-PackQuantBData_avx2(
+LutGemmPackQuantBData_CompFp32(
     size_t N,
     size_t K,
     size_t bits,
@@ -864,12 +871,11 @@ PackQuantBData_avx2(
 }
 
 //
-// AVX2 optimized scales and zero points packing for T-MAC LUT GEMM
-// This performs the same transformation as the scalar version but uses SIMD operations
+// LutGemmPackScalesAndZeroPoints_CompFp32 - Scales and zero points packing
 //
 template <bool HasZeroPoint>
-static void
-PackScalesAndZeroPoints_avx2_impl(
+void
+LutGemmPackScalesAndZeroPoints_CompFp32_Impl(
     size_t N,
     size_t K,
     size_t bits,
@@ -984,7 +990,7 @@ PackScalesAndZeroPoints_avx2_impl(
 }
 
 void
-PackScalesAndZeroPoints_avx2(
+LutGemmPackScalesAndZeroPoints_CompFp32(
     size_t N,
     size_t K,
     size_t bits,
@@ -1002,25 +1008,27 @@ PackScalesAndZeroPoints_avx2(
     assert(bits == 2);
 
     if (HasZeroPoint) {
-        PackScalesAndZeroPoints_avx2_impl<true>(
+        LutGemmPackScalesAndZeroPoints_CompFp32_Impl<true>(
            N, K, bits, BlkLen, simd_n_out, bm,
            PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool
        );
    } else {
-        PackScalesAndZeroPoints_avx2_impl<false>(
+        LutGemmPackScalesAndZeroPoints_CompFp32_Impl<false>(
            N, K, bits, BlkLen, simd_n_out, bm,
            PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool
        );
    }
 }
 
+} // namespace lutgemm_avx2
+
 // Kernel dispatch structure definition.
 
-const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2 = []() {
+const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchAvx2 = []() {
     MLAS_QNBIT_LUT_GEMM_DISPATCH d;
-    d.GenerateLUT = GenerateLUT_avx2;
-    d.ComputeGemm = TMACComputeGemm_avx2;
-    d.PackQuantBData = PackQuantBData_avx2;
-    d.PackScalesAndZeroPoints = PackScalesAndZeroPoints_avx2;
+    d.GenerateLUT = lutgemm_avx2::LutGemmGenerateLUT_CompFp32;
+    d.ComputeGemm = lutgemm_avx2::LutGemmCompute_CompFp32;
+    d.PackQuantBData = lutgemm_avx2::LutGemmPackQuantBData_CompFp32;
+    d.PackScalesAndZeroPoints = lutgemm_avx2::LutGemmPackScalesAndZeroPoints_CompFp32;
     return d;
 }();
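The structural change in this file (mirrored in the NEON version below) is the namespace layout: helpers keep internal linkage inside an unnamed namespace, the entry points move into a per-architecture namespace and drop `static`, and only the dispatch constant remains at global scope. A compact, self-contained sketch of that layout, with hypothetical names:

    struct DispatchSketch {
        void (*Compute)() = nullptr;
    };

    namespace lutgemm_arch_sketch {
    namespace {

    // Internal helper: the unnamed namespace gives it internal linkage, so it cannot
    // collide with the identically structured helper in another architecture's file.
    void HelperKernel() { /* arch-specific details */ }

    }  // namespace

    // Entry point: lives in the per-architecture namespace instead of being `static`.
    void ComputeEntryPoint() { HelperKernel(); }

    }  // namespace lutgemm_arch_sketch

    // Only the dispatch constant is visible to the rest of the library.
    const DispatchSketch DispatchSketchForArch = []() {
        DispatchSketch d;
        d.Compute = lutgemm_arch_sketch::ComputeEntryPoint;
        return d;
    }();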

onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h

Lines changed: 1 addition & 1 deletion
@@ -23,4 +23,4 @@ Module Name:
 // External dispatch table for AVX2 LUT GEMM kernels.
 // Kernel functions are internal to the .cpp file and accessed via this dispatch.
 //
-extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2;
+extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchAvx2;

onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp

Lines changed: 31 additions & 21 deletions
@@ -46,6 +46,12 @@ Module Name:
 #define PRAGMA_UNROLL
 #endif
 
+namespace lutgemm_neon
+{
+
+namespace
+{
+
 //
 // Template classes for accumulation - adapted from llama.cpp tbl.cpp
 //
@@ -282,11 +288,13 @@ lut_ctor_g4_int8_impl_neon(
     *lut_biases = biases;
 }
 
+} // namespace
+
 //
-// GenerateLUT - Entry point for LUT generation
+// LutGemmGenerateLUT_CompFp32 - Entry point for LUT generation
 //
-static void
-GenerateLUT_neon(
+void
+LutGemmGenerateLUT_CompFp32(
     const float* b,
     int8_t* qlut,
     float* lut_scales,
@@ -620,10 +628,10 @@ tbl_g4_int8_float_update_impl_neon(
 }
 
 //
-// TMACComputeGemm - Entry point for GEMM computation
+// LutGemmCompute_CompFp32 - Entry point for GEMM computation
 //
-static void
-TMACComputeGemm_neon(
+void
+LutGemmCompute_CompFp32(
     const uint8_t* A,
     const float* Scales,
     const int8_t* LUT,
@@ -756,11 +764,11 @@ TMACComputeGemm_neon(
 }
 
 //
-// Weight packing for NEON (can use scalar or NEON implementation)
+// LutGemmPackQuantBData_CompFp32 - Weight packing for NEON
 // This is done during model load, so performance is less critical
 //
-static void
-PackQuantBData_neon(
+void
+LutGemmPackQuantBData_CompFp32(
     size_t N,
     size_t K,
     size_t bits,
@@ -917,11 +925,11 @@ PackQuantBData_neon(
 }
 
 //
-// Scales and zero points packing
+// LutGemmPackScalesAndZeroPoints_CompFp32 - Scales and zero points packing
 //
 template <bool HasZeroPoint>
-static void
-PackScalesAndZeroPoints_neon_impl(
+void
+LutGemmPackScalesAndZeroPoints_CompFp32_Impl(
     size_t N,
     size_t K,
     size_t bits,
@@ -991,8 +999,8 @@ PackScalesAndZeroPoints_neon_impl(
     );
 }
 
-static void
-PackScalesAndZeroPoints_neon(
+void
+LutGemmPackScalesAndZeroPoints_CompFp32(
     size_t N,
     size_t K,
     size_t bits,
@@ -1009,27 +1017,29 @@ PackScalesAndZeroPoints_neon(
     assert(bits == 2);
 
     if (HasZeroPoint) {
-        PackScalesAndZeroPoints_neon_impl<true>(
+        LutGemmPackScalesAndZeroPoints_CompFp32_Impl<true>(
            N, K, bits, BlkLen, simd_n_out, bm,
           PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool
        );
    } else {
-        PackScalesAndZeroPoints_neon_impl<false>(
+        LutGemmPackScalesAndZeroPoints_CompFp32_Impl<false>(
           N, K, bits, BlkLen, simd_n_out, bm,
           PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool
        );
    }
 }
 
+} // namespace lutgemm_neon
+
 //
 // Kernel dispatch structure definition
 //
-const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon = []() {
+const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchNeon = []() {
     MLAS_QNBIT_LUT_GEMM_DISPATCH d;
-    d.GenerateLUT = GenerateLUT_neon;
-    d.ComputeGemm = TMACComputeGemm_neon;
-    d.PackQuantBData = PackQuantBData_neon;
-    d.PackScalesAndZeroPoints = PackScalesAndZeroPoints_neon;
+    d.GenerateLUT = lutgemm_neon::LutGemmGenerateLUT_CompFp32;
+    d.ComputeGemm = lutgemm_neon::LutGemmCompute_CompFp32;
+    d.PackQuantBData = lutgemm_neon::LutGemmPackQuantBData_CompFp32;
+    d.PackScalesAndZeroPoints = lutgemm_neon::LutGemmPackScalesAndZeroPoints_CompFp32;
    return d;
 }();
 

onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h

Lines changed: 1 addition & 1 deletion
@@ -23,5 +23,5 @@ Module Name:
 // External dispatch table for ARM NEON LUT GEMM kernels.
 // Kernel functions are internal to the .cpp file and accessed via this dispatch.
 //
-extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon;
+extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchNeon;
 