From f520ad457f0bcdf46875137d7a47f4683275705a Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 21 Jan 2026 23:18:34 +0000 Subject: [PATCH 01/10] Add AVX2 LUT weight packing for SQNBitGemm --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 179 +-------- onnxruntime/core/mlas/lib/qlutgemm.h | 39 ++ .../mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | 340 ++++++++++++++++++ 3 files changed, 398 insertions(+), 160 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index f029e539f02a1..cd285be6dc78c 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -163,111 +163,14 @@ LutGemmPackQuantBData( const size_t bm = tmac_params.bm; const size_t kfactor = tmac_params.kfactor; - assert(BlkLen % g == 0); - assert((BlkLen / g) % kfactor == 0); - - const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 - assert(bm % mgroup == 0); - assert(bm % bits == 0); - - std::unique_ptr buf(new uint8_t[N * bits * (K / g)]); - memset(buf.get(), 0, N * bits * (K / g)); - - const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - size_t im = static_cast(tid); - for (size_t ik = 0; ik < K; ++ik) { - size_t idx = (im * K + ik); - size_t num_elem_per_byte = 8 / bits; - size_t elem_idx = idx % num_elem_per_byte; - - uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); - - for (size_t ib = 0; ib < bits; ++ib) { - size_t new_ik = ik / g; - size_t shft_left = ik % g; - buf[im * bits * K / g + ib * K / g + new_ik] += static_cast(((v >> ib) & 1) << shft_left); - } - } - } - ); - - // Now buf contains the bit planes grouped by g along K - // Next, we need to do a multi-reshape/transpose into the final layout - - const size_t c0_fac2 = K / g; - const size_t c0_fac1 = simd_n_out * c0_fac2; - const size_t c0_fac0 = bits * c0_fac1; - - const size_t c1_nb2 = K / g; - const size_t c1_nb1 = simd_n_in * c1_nb2; - const size_t c1_nb0 = ngroups_per_elem * c1_nb1; - const size_t c1_fac2 = K / g; - const size_t c1_fac1 = ngroups_per_elem * c1_fac2; - const size_t c1_fac0 = simd_n_in * c1_fac1; - - const size_t c2_nb4 = kfactor; - const size_t c2_nb3 = K / g / kfactor * c2_nb4; - const size_t c2_nb2 = ngroups_per_elem * c2_nb3; - const size_t c2_nb1 = simd_n_in * c2_nb2; - const size_t c2_nb0 = bm / mgroup * c2_nb1; - const size_t c2_fac3 = simd_n_in * ngroups_per_elem; - const size_t c2_fac2 = kfactor * c2_fac3; - const size_t c2_fac1 = bm / mgroup * c2_fac2; - const size_t c2_fac0 = K / g / kfactor * c2_fac1; - - const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); - memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed? 
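// Layout note (illustrative sketch of the index math above): the first pass filled
// buf with shape [N][bits][K/g], where byte buf[im][ib][ik/g] packs bit `ib` of
// g = 4 consecutive weights, LSB first (shift = ik % g). For example, with bits = 2
// and one group w0..w3 = {3, 1, 0, 2}:
//   bit-plane 0 byte = 1 | (1 << 1) | (0 << 2) | (0 << 3) = 0x03
//   bit-plane 1 byte = 1 | (0 << 1) | (0 << 2) | (1 << 3) = 0x09
// The loop below applies the three reshape/transpose steps and ORs ngroups_per_elem
// of these group bytes into each packed output byte (shift = new_ing * g), which is
// why the destination is zeroed first.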
- - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - size_t im = static_cast(tid); - for (size_t ib = 0; ib < bits; ib++) { - for (size_t ik = 0; ik < K / g; ik++) { - // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) - size_t new_im = im / simd_n_out; - size_t new_isno = im % simd_n_out; - size_t new_ib = ib; - size_t new_ik = ik; - size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; - - // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) - new_im = new_idx / c1_nb0; - size_t new_ing = (new_idx % c1_nb0) / c1_nb1; - size_t new_isni = (new_idx % c1_nb1) / c1_nb2; - new_ik = (new_idx % c1_nb2); - new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; - - // # 0 1 2 3 4 5 - // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) - new_im = new_idx / c2_nb0; - size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; - new_isni = (new_idx % c2_nb1) / c2_nb2; - new_ing = (new_idx % c2_nb2) / c2_nb3; - new_ik = (new_idx % c2_nb3) / c2_nb4; - size_t new_ikf = (new_idx % c2_nb4); - new_idx = new_im * c2_fac0 + - new_ik * c2_fac1 + - new_ibm * c2_fac2 + - new_ikf * c2_fac3 + - new_isni * ngroups_per_elem + - new_ing; - new_idx = new_idx / ngroups_per_elem; - size_t buf_idx = im * bits * K / g + ib * K / g + ik; - uint8_t buf_val = buf[buf_idx]; - - // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) - PackedQuantBDataBegin[new_idx] = static_cast( - static_cast(PackedQuantBDataBegin[new_idx]) + - (buf_val << (new_ing * g)) - ); - } - } - } + // LUT GEMM is only available for AVX2, so dispatch must be available + const auto* Dispatch = GetMlasPlatform().LutGenKernel; + assert(Dispatch && Dispatch->PackQuantBData && "PackQuantBData requires AVX2 dispatch"); + + Dispatch->PackQuantBData( + N, K, bits, g, ngroups_per_elem, + simd_n_in, simd_n_out, bm, kfactor, + QuantBDataBegin, PackedQuantBDataBegin, ThreadPool ); } @@ -298,67 +201,23 @@ LutPackScalesAndZeroPoints( bool HasZeroPoint, float* PackedQuantBZPBegin, const float* QuantBScale, - const uint8_t* QuantBZeroPoint + const uint8_t* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool ) { const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint); const size_t bits = tmac_params.bits; const size_t simd_n_out = tmac_params.simd_n_out; const size_t bm = tmac_params.bm; - const size_t num_elem_per_byte = 8 / bits; - - // ZP array is column-major packed, with per-column alignment to byte boundary - const size_t row_blks = K / BlkLen; // number of blocks per column - const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; - - for (size_t im = 0; im < N; im += 1) { - for (size_t ik = 0; ik < K; ik += BlkLen) { - size_t idx = (im * K + ik) / BlkLen; // linear block index for scale (scale is NOT packed) - float scale = QuantBScale[idx]; - float zp = 0.0f; - if (HasZeroPoint) { - size_t blk_in_col = ik / BlkLen; // block index within column - size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; - size_t elem_idx = blk_in_col % num_elem_per_byte; - uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1); - - // The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit). - // Thus, need to correct for the actual ZP relative to the midpoint. 
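// Worked example (illustrative): for bits = 2, midpoint = 1 << (bits - 1) = 2, so a
// stored zero point v = 1 with scale 0.05 yields zp = (1 - 2) * 0.05 = -0.05.
// When HasZeroPoint is set, the packed layout below stores simd_n_out scales
// followed by the simd_n_out corresponding zero points per group, so the kernel
// reads each scale and its zero point from adjacent locations.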
- - int midpoint = 1 << (bits - 1); // 2 for 2-bit - zp = static_cast(static_cast(v) - midpoint) * scale; - } - - // TODO(vraspar): fix when k < BlkLen and nb1 is 0 - size_t nb1 = K / BlkLen; - size_t nb0 = bm / bits * nb1; - - size_t new_im, new_ibm, new_ik; - if (nb1 == 0) { - new_im = 0; - new_ibm = 0; - new_ik = 0; - } else { - new_im = idx / nb0; - new_ibm = (idx % nb0) / nb1; - new_ik = (idx % nb1); - } - - if (HasZeroPoint) { - size_t new_isimd = new_ibm % simd_n_out; - size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; - size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; - size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; - - PackedQuantBZPBegin[new_idx_scale] = scale; - PackedQuantBZPBegin[new_idx_zero] = zp; - } else { - size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; - PackedQuantBZPBegin[new_idx] = scale; - } - } - } + // LUT GEMM is only available for AVX2, so dispatch must be available + const auto* Dispatch = GetMlasPlatform().LutGenKernel; + assert(Dispatch && Dispatch->PackScalesAndZeroPoints && "PackScalesAndZeroPoints requires AVX2 dispatch"); + + Dispatch->PackScalesAndZeroPoints( + N, K, bits, BlkLen, simd_n_out, bm, HasZeroPoint, + PackedQuantBZPBegin, QuantBScale, QuantBZeroPoint, ThreadPool + ); } // Internal helper: calculates the offset to scales in the packed buffer @@ -418,7 +277,7 @@ MlasLutGemmPack( if (QuantBScale != nullptr) { size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint); float* scales_dest = reinterpret_cast(PackedBuf + scales_offset); - LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint); + LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint, ThreadPool); } } diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h index ef4d01a2c5809..a64947d8e523c 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.h +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -70,6 +70,41 @@ typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)( bool HasZeroPoint ); +// +// Function signature for packing quantized B data +// +typedef void(MLAS_QNBIT_LUT_PACK_QUANTB_DATA)( + size_t N, + size_t K, + size_t bits, + size_t g, + size_t ngroups_per_elem, + size_t simd_n_in, + size_t simd_n_out, + size_t bm, + size_t kfactor, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +// +// Function signature for packing scales and zero points +// +typedef void(MLAS_QNBIT_LUT_PACK_SCALES_AND_ZP)( + size_t N, + size_t K, + size_t bits, + size_t BlkLen, + size_t simd_n_out, + size_t bm, + bool HasZeroPoint, + float* PackedScalesBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool +); + // // Kernel dispatch structure. 
// @@ -81,4 +116,8 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH { MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr; MLAS_QNBIT_LUT_GEMM_COMPUTE* ComputeGemm = nullptr; + + MLAS_QNBIT_LUT_PACK_QUANTB_DATA* PackQuantBData = nullptr; + + MLAS_QNBIT_LUT_PACK_SCALES_AND_ZP* PackScalesAndZeroPoints = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index b54f051ca1504..d7e3ee404a12f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -20,6 +20,8 @@ Module Name: --*/ #include +#include +#include #include #include // AVX2 intrinsics @@ -661,11 +663,349 @@ TMACComputeGemm_avx2( delete[] CBits; } +// +// AVX2 optimized weight packing for T-MAC LUT GEMM +// This performs the same transformation as the scalar version but uses SIMD operations +// +void +PackQuantBData_avx2( + size_t N, + size_t K, + size_t bits, + size_t g, + size_t ngroups_per_elem, + size_t simd_n_in, + size_t simd_n_out, + size_t bm, + size_t kfactor, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + // Only optimized for 2-bit, g=4, ngroups_per_elem=2 + assert(bits == 2 && g == 4 && ngroups_per_elem == 2); + + const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 + const size_t K_div_g = K / g; + + // Phase 1: Bit-plane decomposition with grouping by g=4 + // For 2-bit: each input byte has 4 elements, we extract bit planes and group 4 consecutive bits + std::unique_ptr buf(new uint8_t[N * bits * K_div_g]); + + // Masks for 2-bit extraction + const __m256i mask_2bit = _mm256_set1_epi8(0x03); // mask for 2-bit values + const __m256i mask_bit0 = _mm256_set1_epi8(0x01); // bit 0 of each 2-bit element + + // Phase 1: Parallelize over N (each thread processes one row) + MlasTrySimpleParallel( + ThreadPool, static_cast(N), + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + + // Process in chunks of 32 bytes (128 2-bit elements = 32 groups of 4) + const uint8_t* src_row = reinterpret_cast(QuantBDataBegin) + (im * K / 4); + uint8_t* dst_bit0 = buf.get() + im * bits * K_div_g; + uint8_t* dst_bit1 = dst_bit0 + K_div_g; + + size_t ik = 0; + // Process 128 elements at a time (32 bytes input = 128 2-bit elements = 32 output bytes per bit plane) + for (; ik + 128 <= K; ik += 128) { + // Load 32 bytes = 128 2-bit elements + __m256i packed = _mm256_loadu_si256(reinterpret_cast(src_row + ik / 4)); + + // Extract each of 4 positions within each byte + // pos0: bits 0-1, pos1: bits 2-3, pos2: bits 4-5, pos3: bits 6-7 + __m256i pos0 = _mm256_and_si256(packed, mask_2bit); + __m256i pos1 = _mm256_and_si256(_mm256_srli_epi16(packed, 2), mask_2bit); + __m256i pos2 = _mm256_and_si256(_mm256_srli_epi16(packed, 4), mask_2bit); + __m256i pos3 = _mm256_srli_epi16(packed, 6); + + // For g=4: we need to group 4 consecutive elements + // Each output byte contains bits from 4 consecutive input elements, shifted by their position + // Output for bit0: (pos0_bit0 << 0) | (pos1_bit0 << 1) | (pos2_bit0 << 2) | (pos3_bit0 << 3) + + // Extract bit 0 from each position + __m256i b0_pos0 = _mm256_and_si256(pos0, mask_bit0); // bit0 at position 0 + __m256i b0_pos1 = _mm256_and_si256(pos1, mask_bit0); // bit0 at position 0 + __m256i b0_pos2 = _mm256_and_si256(pos2, mask_bit0); // bit0 at position 0 + __m256i b0_pos3 = _mm256_and_si256(pos3, mask_bit0); // bit0 at position 0 + + // Combine: shift each position's bit to its final location + 
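// Safety note: after the mask_bit0 AND, every byte in b0_pos0..b0_pos3 is 0 or 1,
// so the 16-bit-lane shifts below (by at most 3) cannot carry bits across byte
// boundaries. Each output byte therefore holds bit 0 of four consecutive elements
// in its low nibble, matching the scalar ((v >> ib) & 1) << (ik % g) packing.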
__m256i bit0_out = _mm256_or_si256( + _mm256_or_si256(b0_pos0, _mm256_slli_epi16(b0_pos1, 1)), + _mm256_or_si256(_mm256_slli_epi16(b0_pos2, 2), _mm256_slli_epi16(b0_pos3, 3)) + ); + + // Extract bit 1 from each position and shift down first + __m256i b1_pos0 = _mm256_and_si256(_mm256_srli_epi16(pos0, 1), mask_bit0); + __m256i b1_pos1 = _mm256_and_si256(_mm256_srli_epi16(pos1, 1), mask_bit0); + __m256i b1_pos2 = _mm256_and_si256(_mm256_srli_epi16(pos2, 1), mask_bit0); + __m256i b1_pos3 = _mm256_and_si256(_mm256_srli_epi16(pos3, 1), mask_bit0); + + // Combine for bit 1 plane + __m256i bit1_out = _mm256_or_si256( + _mm256_or_si256(b1_pos0, _mm256_slli_epi16(b1_pos1, 1)), + _mm256_or_si256(_mm256_slli_epi16(b1_pos2, 2), _mm256_slli_epi16(b1_pos3, 3)) + ); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst_bit0 + ik / g), bit0_out); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst_bit1 + ik / g), bit1_out); + } + + // Handle remaining elements with scalar code + for (; ik < K; ++ik) { + size_t idx = ik; + size_t num_elem_per_byte = 4; // 2-bit: 4 elements per byte + size_t elem_idx = idx % num_elem_per_byte; + uint8_t v = src_row[idx / num_elem_per_byte] >> (elem_idx * bits); + + size_t new_ik = ik / g; + size_t shft_left = ik % g; + dst_bit0[new_ik] += static_cast(((v >> 0) & 1) << shft_left); + dst_bit1[new_ik] += static_cast(((v >> 1) & 1) << shft_left); + } + } + ); + + // Phase 2: Multi-reshape/transpose into final layout + // Precompute factors and simplify index math to avoid expensive div/mod in inner loops. + // const size_t bm_div_bits = bm / bits; + const size_t bm_div_mgroup = bm / mgroup; + + const size_t c2_fac3_div = simd_n_in; + const size_t c2_fac2_div = kfactor * c2_fac3_div; + const size_t c2_fac1_div = bm_div_mgroup * c2_fac2_div; + const size_t c2_fac0_div = K_div_g * bm_div_mgroup * simd_n_in; + + const size_t PackedQuantBDataSize = (N * bits) * (K_div_g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); + auto* packed_u8 = reinterpret_cast(PackedQuantBDataBegin); + + // Phase 2: Parallelize over N - each thread handles all (bits, K/g) work for its assigned rows + // This ensures no write conflicts since each im writes to disjoint output regions + MlasTrySimpleParallel( + ThreadPool, static_cast(N), + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + const size_t im0 = im / simd_n_out; + const size_t isno = im - im0 * simd_n_out; + const size_t x_base = simd_n_out * (im0 * bits) + isno; + + for (size_t ib = 0; ib < bits; ib++) { + const size_t x = x_base + ib * simd_n_out; + const size_t new_im1 = x / mgroup; + const size_t y = x - new_im1 * mgroup; + const size_t new_ing = y / simd_n_in; + const size_t new_isni = y - new_ing * simd_n_in; + + const size_t new_im2 = new_im1 / bm_div_mgroup; + const size_t new_ibm = new_im1 - new_im2 * bm_div_mgroup; + + const size_t base_im = new_im2 * c2_fac0_div + new_ibm * c2_fac2_div + new_isni; + const size_t buf_base = im * bits * K_div_g + ib * K_div_g; + + const uint8_t shift = static_cast(new_ing * g); + const size_t stride = c2_fac3_div; + + for (size_t ik = 0; ik < K_div_g; ik += kfactor) { + const size_t new_ik = ik / kfactor; + const size_t base_k = base_im + new_ik * c2_fac1_div; + const size_t buf_k = buf_base + ik; + + uint8_t* dst = packed_u8 + base_k; + const uint8_t* src = buf.get() + buf_k; + + if (kfactor == 8) { + dst[stride * 0] = static_cast(dst[stride * 0] + (src[0] << shift)); + dst[stride * 1] = static_cast(dst[stride * 1] + (src[1] << shift)); + dst[stride * 2] = 
static_cast(dst[stride * 2] + (src[2] << shift)); + dst[stride * 3] = static_cast(dst[stride * 3] + (src[3] << shift)); + dst[stride * 4] = static_cast(dst[stride * 4] + (src[4] << shift)); + dst[stride * 5] = static_cast(dst[stride * 5] + (src[5] << shift)); + dst[stride * 6] = static_cast(dst[stride * 6] + (src[6] << shift)); + dst[stride * 7] = static_cast(dst[stride * 7] + (src[7] << shift)); + } else if (kfactor == 16) { + dst[stride * 0] = static_cast(dst[stride * 0] + (src[0] << shift)); + dst[stride * 1] = static_cast(dst[stride * 1] + (src[1] << shift)); + dst[stride * 2] = static_cast(dst[stride * 2] + (src[2] << shift)); + dst[stride * 3] = static_cast(dst[stride * 3] + (src[3] << shift)); + dst[stride * 4] = static_cast(dst[stride * 4] + (src[4] << shift)); + dst[stride * 5] = static_cast(dst[stride * 5] + (src[5] << shift)); + dst[stride * 6] = static_cast(dst[stride * 6] + (src[6] << shift)); + dst[stride * 7] = static_cast(dst[stride * 7] + (src[7] << shift)); + dst[stride * 8] = static_cast(dst[stride * 8] + (src[8] << shift)); + dst[stride * 9] = static_cast(dst[stride * 9] + (src[9] << shift)); + dst[stride * 10] = static_cast(dst[stride * 10] + (src[10] << shift)); + dst[stride * 11] = static_cast(dst[stride * 11] + (src[11] << shift)); + dst[stride * 12] = static_cast(dst[stride * 12] + (src[12] << shift)); + dst[stride * 13] = static_cast(dst[stride * 13] + (src[13] << shift)); + dst[stride * 14] = static_cast(dst[stride * 14] + (src[14] << shift)); + dst[stride * 15] = static_cast(dst[stride * 15] + (src[15] << shift)); + } else { + for (size_t ikf = 0; ikf < kfactor; ikf++) { + dst[stride * ikf] = static_cast(dst[stride * ikf] + (src[ikf] << shift)); + } + } + } + } + } + ); +} + +// +// AVX2 optimized scales and zero points packing for T-MAC LUT GEMM +// This performs the same transformation as the scalar version but uses SIMD operations +// +template +static void +PackScalesAndZeroPoints_avx2_impl( + size_t N, + size_t K, + size_t bits, + size_t BlkLen, + size_t simd_n_out, + size_t bm, + float* PackedScalesBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool +) +{ + const size_t num_elem_per_byte = 8 / bits; // 4 for 2-bit + const size_t row_blks = K / BlkLen; // number of blocks per column + const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; + + const size_t nb1 = K / BlkLen; + const size_t bm_div_bits = bm / bits; + const int midpoint = 1 << (bits - 1); // 2 for 2-bit + const uint8_t bits_mask = static_cast((1 << bits) - 1); + + const size_t TotalBlocks = N * row_blks; + ptrdiff_t MaxThreads = MlasGetMaximumThreadCount(ThreadPool); + + if (N >= static_cast(MaxThreads) || row_blks <= 1) { + MlasTrySimpleParallel( + ThreadPool, static_cast(N), + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + const size_t new_im = (bm_div_bits > 0) ? (im / bm_div_bits) : 0; + const size_t new_ibm = (bm_div_bits > 0) ? 
(im - new_im * bm_div_bits) : 0; + + if constexpr (HasZeroPoint) { + const size_t new_isimd = new_ibm % simd_n_out; + const size_t new_ibm_div_simd = new_ibm / simd_n_out; + const size_t outer_base = new_im * (bm_div_bits * nb1 / simd_n_out) + new_ibm_div_simd; + const size_t outer_stride = bm_div_bits / simd_n_out; + + for (size_t blk_in_col = 0; blk_in_col < row_blks; blk_in_col++) { + const size_t idx = im * nb1 + blk_in_col; + const float scale = QuantBScale[idx]; + + size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; + size_t elem_idx = blk_in_col % num_elem_per_byte; + uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & bits_mask; + float zp = static_cast(static_cast(v) - midpoint) * scale; + + const size_t new_idx_outer = outer_base + blk_in_col * outer_stride; + const size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + const size_t new_idx_zero = new_idx_scale + simd_n_out; + + PackedScalesBegin[new_idx_scale] = scale; + PackedScalesBegin[new_idx_zero] = zp; + } + } else { + const size_t base_idx = new_im * bm_div_bits * nb1 + new_ibm; + const size_t stride_idx = bm_div_bits; + + for (size_t blk_in_col = 0; blk_in_col < row_blks; blk_in_col++) { + const size_t idx = im * nb1 + blk_in_col; + const float scale = QuantBScale[idx]; + const size_t new_idx = base_idx + blk_in_col * stride_idx; + PackedScalesBegin[new_idx] = scale; + } + } + } + ); + } else { + MlasTrySimpleParallel( + ThreadPool, static_cast(TotalBlocks), + [&](ptrdiff_t tid) { + const size_t block_idx = static_cast(tid); + const size_t im = block_idx / row_blks; + const size_t blk_in_col = block_idx - im * row_blks; + + const size_t new_im = (bm_div_bits > 0) ? (im / bm_div_bits) : 0; + const size_t new_ibm = (bm_div_bits > 0) ? 
(im - new_im * bm_div_bits) : 0; + + const size_t idx = im * nb1 + blk_in_col; + const float scale = QuantBScale[idx]; + + if constexpr (HasZeroPoint) { + const size_t new_isimd = new_ibm % simd_n_out; + const size_t new_ibm_div_simd = new_ibm / simd_n_out; + const size_t outer_base = new_im * (bm_div_bits * nb1 / simd_n_out) + new_ibm_div_simd; + const size_t outer_stride = bm_div_bits / simd_n_out; + + size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; + size_t elem_idx = blk_in_col % num_elem_per_byte; + uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & bits_mask; + float zp = static_cast(static_cast(v) - midpoint) * scale; + + const size_t new_idx_outer = outer_base + blk_in_col * outer_stride; + const size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + const size_t new_idx_zero = new_idx_scale + simd_n_out; + + PackedScalesBegin[new_idx_scale] = scale; + PackedScalesBegin[new_idx_zero] = zp; + } else { + const size_t base_idx = new_im * bm_div_bits * nb1 + new_ibm; + const size_t new_idx = base_idx + blk_in_col * bm_div_bits; + PackedScalesBegin[new_idx] = scale; + } + } + ); + } +} + +void +PackScalesAndZeroPoints_avx2( + size_t N, + size_t K, + size_t bits, + size_t BlkLen, + size_t simd_n_out, + size_t bm, + bool HasZeroPoint, + float* PackedScalesBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool +) +{ + // Only optimized for 2-bit quantization + assert(bits == 2); + + if (HasZeroPoint) { + PackScalesAndZeroPoints_avx2_impl( + N, K, bits, BlkLen, simd_n_out, bm, + PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool + ); + } else { + PackScalesAndZeroPoints_avx2_impl( + N, K, bits, BlkLen, simd_n_out, bm, + PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool + ); + } +} + // Kernel dispatch structure definition. const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2 = []() { MLAS_QNBIT_LUT_GEMM_DISPATCH d; d.GenerateLUT = GenerateLUT_avx2; d.ComputeGemm = TMACComputeGemm_avx2; + d.PackQuantBData = PackQuantBData_avx2; + d.PackScalesAndZeroPoints = PackScalesAndZeroPoints_avx2; return d; }(); From b52a9473e7dc336ac05d33d87550cf36e0ab58b1 Mon Sep 17 00:00:00 2001 From: vraspar Date: Thu, 22 Jan 2026 12:05:30 -0800 Subject: [PATCH 02/10] lut profiling --- onnxruntime/test/mlas/bench/bench_lutgemm.cpp | 186 ++++++++++++++++++ .../test/mlas/unittest/test_sqlutgemm.cpp | 5 +- 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/mlas/bench/bench_lutgemm.cpp diff --git a/onnxruntime/test/mlas/bench/bench_lutgemm.cpp b/onnxruntime/test/mlas/bench/bench_lutgemm.cpp new file mode 100644 index 0000000000000..72657fec1c4f1 --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_lutgemm.cpp @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
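// These benchmarks cover the two halves of the LUT GEMM path: LUTGEMM_PACK times
// repeated MlasLutGemmPack calls (weight repacking), and LUTGEMM_COMPUTE times
// MlasLutGemm on pre-packed weights. Benchmark arguments follow the orders given in
// lutgemm_bench_arg_names and lutgemm_compute_arg_names below.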
+ +#include "mlas_q4.h" +#include "mlas_qnbit.h" +#include "bench_util.h" +#include "core/util/thread_utils.h" + +#include + +static const std::vector lutgemm_bench_arg_names = {"BlkLen", "N", "K", "Threads", "HasZP"}; +static const std::vector lutgemm_compute_arg_names = {"BlkLen", "M", "N", "K", "Threads", "HasZP"}; + +template +void LUTGEMM_PACK(benchmark::State& state) { + if (state.range(0) <= 0) throw std::invalid_argument("BlkLen must be greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must be greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must be greater than 0!"); + if (state.range(3) <= 0) throw std::invalid_argument("Threads must be greater than 0!"); + + const size_t BlkLen = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); + const size_t Threads = static_cast(state.range(3)); + const bool HasZeroPoint = static_cast(state.range(4)); + + if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) { + state.SkipWithMessage("LUT GEMM is not available with the given configuration."); + return; + } + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(Threads); + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes( + static_cast(BlkLen), true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); + std::vector QuantBData(QuantBDataSizeInBytes); + std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBZeroPoint(HasZeroPoint ? QuantBZeroPointSizeInBytes : 0); + + MlasQuantizeBlockwise( + QuantBData.data(), QuantBScale.data(), + HasZeroPoint ? QuantBZeroPoint.data() : nullptr, + B.data(), static_cast(BlkLen), true, + static_cast(K), static_cast(N), static_cast(N), + tp.get()); + + MlasClearLutGemmKernelConfig(); + MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + + size_t PackedBufSize = MlasLutGemmPackedSize(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + std::vector PackedBuf(PackedBufSize); + + MlasLutGemmPack( + N, K, BlkBitWidth, BlkLen, HasZeroPoint, + reinterpret_cast(QuantBData.data()), + QuantBScale.data(), + HasZeroPoint ? QuantBZeroPoint.data() : nullptr, + PackedBuf.data(), + tp.get()); + + for (auto _ : state) { + MlasLutGemmPack( + N, K, BlkBitWidth, BlkLen, HasZeroPoint, + reinterpret_cast(QuantBData.data()), + QuantBScale.data(), + HasZeroPoint ? 
QuantBZeroPoint.data() : nullptr, + PackedBuf.data(), + tp.get()); + } +} + +template +void LUTGEMM_COMPUTE(benchmark::State& state) { + if (state.range(0) <= 0) throw std::invalid_argument("BlkLen must be greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("M must be greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("N must be greater than 0!"); + if (state.range(3) <= 0) throw std::invalid_argument("K must be greater than 0!"); + if (state.range(4) <= 0) throw std::invalid_argument("Threads must be greater than 0!"); + + const size_t BlkLen = static_cast(state.range(0)); + const size_t M = static_cast(state.range(1)); + const size_t N = static_cast(state.range(2)); + const size_t K = static_cast(state.range(3)); + const size_t Threads = static_cast(state.range(4)); + const bool HasZeroPoint = static_cast(state.range(5)); + + if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) { + state.SkipWithMessage("LUT GEMM is not available with the given configuration."); + return; + } + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(Threads); + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + auto A = RandomVectorUniform(M * K, -1.0f, 1.0f); + std::vector C(M * N); + + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes( + static_cast(BlkLen), true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); + std::vector QuantBData(QuantBDataSizeInBytes); + std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBZeroPoint(HasZeroPoint ? QuantBZeroPointSizeInBytes : 0); + + MlasQuantizeBlockwise( + QuantBData.data(), QuantBScale.data(), + HasZeroPoint ? QuantBZeroPoint.data() : nullptr, + B.data(), static_cast(BlkLen), true, + static_cast(K), static_cast(N), static_cast(N), + tp.get()); + + MlasClearLutGemmKernelConfig(); + MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + + size_t PackedBufSize = MlasLutGemmPackedSize(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + std::vector PackedBuf(PackedBufSize); + + MlasLutGemmPack( + N, K, BlkBitWidth, BlkLen, HasZeroPoint, + reinterpret_cast(QuantBData.data()), + QuantBScale.data(), + HasZeroPoint ? 
QuantBZeroPoint.data() : nullptr, + PackedBuf.data(), + tp.get()); + + MlasLutGemm(A.data(), BlkLen, PackedBuf.data(), C.data(), + static_cast(K), static_cast(M), static_cast(N), + HasZeroPoint, tp.get()); + + for (auto _ : state) { + MlasLutGemm(A.data(), BlkLen, PackedBuf.data(), C.data(), + static_cast(K), static_cast(M), static_cast(N), + HasZeroPoint, tp.get()); + } +} + +static void LutGemmPackArgs(benchmark::internal::Benchmark* b) { + b->ArgNames(lutgemm_bench_arg_names); + b->ArgsProduct({ + {128}, // BlkLen + {4096}, // N + {4096}, // K + {8}, // Threads + {int64_t{false}}, // HasZeroPoint + }); +} + +static void LutGemmComputeArgs(benchmark::internal::Benchmark* b) { + b->ArgNames(lutgemm_compute_arg_names); + b->ArgsProduct({ + {128}, // BlkLen + {1, 32}, // M + {4096}, // N + {4096}, // K + {8}, // Threads + {int64_t{false}}, // HasZeroPoint + }); +} + +[[maybe_unused]] static const bool benchmarks_registered = []() { + const bool is_lutgemm_supported = MlasIsLutGemmAvailable(4096, 4096, 2, 128); + if (is_lutgemm_supported) { + BENCHMARK(LUTGEMM_PACK<2>)->Apply(LutGemmPackArgs)->UseRealTime(); + BENCHMARK(LUTGEMM_COMPUTE<2>)->Apply(LutGemmComputeArgs)->UseRealTime(); + return true; + } + return false; +}(); diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp index 12ec5ec78f599..181fda23f299d 100644 --- a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp @@ -175,7 +175,7 @@ class SQLutGemmShortExecuteTest : public MlasTestFixture Date: Mon, 26 Jan 2026 13:34:36 -0800 Subject: [PATCH 03/10] Fix AVX2 dispatch error handling and improve memory initialization in LUT GEMM functions --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 22 ++++++++++++++----- .../mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | 2 ++ onnxruntime/test/mlas/bench/bench_lutgemm.cpp | 22 +++++++++---------- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index cd285be6dc78c..53215e81f74e8 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -165,8 +165,10 @@ LutGemmPackQuantBData( // LUT GEMM is only available for AVX2, so dispatch must be available const auto* Dispatch = GetMlasPlatform().LutGenKernel; - assert(Dispatch && Dispatch->PackQuantBData && "PackQuantBData requires AVX2 dispatch"); - + if (Dispatch == nullptr || Dispatch->PackQuantBData == nullptr) { + MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires AVX2 dispatch"); + } + Dispatch->PackQuantBData( N, K, bits, g, ngroups_per_elem, simd_n_in, simd_n_out, bm, kfactor, @@ -212,8 +214,10 @@ LutPackScalesAndZeroPoints( // LUT GEMM is only available for AVX2, so dispatch must be available const auto* Dispatch = GetMlasPlatform().LutGenKernel; - assert(Dispatch && Dispatch->PackScalesAndZeroPoints && "PackScalesAndZeroPoints requires AVX2 dispatch"); - + if (Dispatch == nullptr || Dispatch->PackScalesAndZeroPoints == nullptr) { + MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires AVX2 dispatch"); + } + Dispatch->PackScalesAndZeroPoints( N, K, bits, BlkLen, simd_n_out, bm, HasZeroPoint, PackedQuantBZPBegin, QuantBScale, QuantBZeroPoint, ThreadPool @@ -290,7 +294,11 @@ MlasIsLutGemmAvailable( ) { const auto* lut_kernel = GetMlasPlatform().LutGenKernel; - if (lut_kernel == nullptr || lut_kernel->GenerateLUT == nullptr || lut_kernel->ComputeGemm == nullptr) { + if (lut_kernel == nullptr 
|| + lut_kernel->GenerateLUT == nullptr || + lut_kernel->ComputeGemm == nullptr || + lut_kernel->PackQuantBData == nullptr || + lut_kernel->PackScalesAndZeroPoints == nullptr) { return false; } @@ -359,7 +367,9 @@ MlasLutGemm( // adapted from ggml_backend_tmac_mul_mat const auto* Dispatch = GetMlasPlatform().LutGenKernel; // This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm() - assert(Dispatch && Dispatch->GenerateLUT && "TMAC not supported in this configuration."); + if (Dispatch == nullptr || Dispatch->GenerateLUT == nullptr || Dispatch->ComputeGemm == nullptr) { + MLAS_THROW_EX(std::runtime_error, "TMAC not supported in this configuration"); + } // Calculate scales offset from packed buffer // TODO(vraspar): support other bitwidths diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index d7e3ee404a12f..862431c6d2c42 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -692,6 +692,7 @@ PackQuantBData_avx2( // Phase 1: Bit-plane decomposition with grouping by g=4 // For 2-bit: each input byte has 4 elements, we extract bit planes and group 4 consecutive bits std::unique_ptr buf(new uint8_t[N * bits * K_div_g]); + memset(buf.get(), 0, N * bits * K_div_g); // Masks for 2-bit extraction const __m256i mask_2bit = _mm256_set1_epi8(0x03); // mask for 2-bit values @@ -808,6 +809,7 @@ PackQuantBData_avx2( const uint8_t shift = static_cast(new_ing * g); const size_t stride = c2_fac3_div; + assert(K_div_g % kfactor == 0 && "K_div_g must be divisible by kfactor"); for (size_t ik = 0; ik < K_div_g; ik += kfactor) { const size_t new_ik = ik / kfactor; const size_t base_k = base_im + new_ik * c2_fac1_div; diff --git a/onnxruntime/test/mlas/bench/bench_lutgemm.cpp b/onnxruntime/test/mlas/bench/bench_lutgemm.cpp index 72657fec1c4f1..890b16c85e610 100644 --- a/onnxruntime/test/mlas/bench/bench_lutgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_lutgemm.cpp @@ -155,23 +155,23 @@ void LUTGEMM_COMPUTE(benchmark::State& state) { static void LutGemmPackArgs(benchmark::internal::Benchmark* b) { b->ArgNames(lutgemm_bench_arg_names); b->ArgsProduct({ - {128}, // BlkLen - {4096}, // N - {4096}, // K - {8}, // Threads - {int64_t{false}}, // HasZeroPoint + {128}, // BlkLen + {4096}, // N + {4096}, // K + {8}, // Threads + {int64_t{false}}, // HasZeroPoint }); } static void LutGemmComputeArgs(benchmark::internal::Benchmark* b) { b->ArgNames(lutgemm_compute_arg_names); b->ArgsProduct({ - {128}, // BlkLen - {1, 32}, // M - {4096}, // N - {4096}, // K - {8}, // Threads - {int64_t{false}}, // HasZeroPoint + {128}, // BlkLen + {1, 32}, // M + {4096}, // N + {4096}, // K + {8}, // Threads + {int64_t{false}}, // HasZeroPoint }); } From 70cb824afd2599d3a9ae408c86479fce9f85ed6f Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 26 Jan 2026 15:40:06 -0800 Subject: [PATCH 04/10] PackQuantBData_avx2: Instead of entire buffer, zero-initialize only necessary tail bytes --- .../core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index 862431c6d2c42..a2b1b11b88c57 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -692,7 +692,6 @@ PackQuantBData_avx2( // Phase 1: Bit-plane 
decomposition with grouping by g=4 // For 2-bit: each input byte has 4 elements, we extract bit planes and group 4 consecutive bits std::unique_ptr buf(new uint8_t[N * bits * K_div_g]); - memset(buf.get(), 0, N * bits * K_div_g); // Masks for 2-bit extraction const __m256i mask_2bit = _mm256_set1_epi8(0x03); // mask for 2-bit values @@ -754,6 +753,15 @@ PackQuantBData_avx2( _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst_bit1 + ik / g), bit1_out); } + // Zero-initialize only the tail bytes that will be updated with "+=" + // to avoid touching the full buffer. + const size_t tail_new_ik = ik / g; + if (tail_new_ik < K_div_g) { + const size_t tail_len = K_div_g - tail_new_ik; + std::memset(dst_bit0 + tail_new_ik, 0, tail_len); + std::memset(dst_bit1 + tail_new_ik, 0, tail_len); + } + // Handle remaining elements with scalar code for (; ik < K; ++ik) { size_t idx = ik; From 6d1f6c1af5874d3b4b4e106a18d7a457c6a225bd Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 26 Jan 2026 16:53:38 -0800 Subject: [PATCH 05/10] Change parallelization in PackQuantBData_avx2 to prevent race conditions by processing tiles of input values. --- .../mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | 133 +++++++++--------- 1 file changed, 70 insertions(+), 63 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index a2b1b11b88c57..15f32e6a5e845 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -791,74 +791,81 @@ PackQuantBData_avx2( memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); auto* packed_u8 = reinterpret_cast(PackedQuantBDataBegin); - // Phase 2: Parallelize over N - each thread handles all (bits, K/g) work for its assigned rows - // This ensures no write conflicts since each im writes to disjoint output regions + // Phase 2: Parallelize over tiles of im values that share output bytes. + // Consecutive im0 values (ngroups_per_elem of them) write to the same output bytes + // with different shifts, so they must be processed by the same thread to avoid races. 
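// Illustration (assuming the 2-bit AVX2 configuration used here, simd_n_in = 16 and
// simd_n_out = 8): rows whose im0 values are 2t and 2t + 1, i.e. rows simd_n_out
// apart, fall into the same mgroup of 32 bit-plane rows and therefore target the
// same packed byte, one accumulating its group byte at shift 0 and the other at
// shift g = 4. Keeping both rows in one tile keeps the read-modify-write below
// race free.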
+ const size_t im_per_tile = ngroups_per_elem * simd_n_out; + const size_t num_tiles = (N + im_per_tile - 1) / im_per_tile; MlasTrySimpleParallel( - ThreadPool, static_cast(N), + ThreadPool, static_cast(num_tiles), [&](ptrdiff_t tid) { - size_t im = static_cast(tid); - const size_t im0 = im / simd_n_out; - const size_t isno = im - im0 * simd_n_out; - const size_t x_base = simd_n_out * (im0 * bits) + isno; - - for (size_t ib = 0; ib < bits; ib++) { - const size_t x = x_base + ib * simd_n_out; - const size_t new_im1 = x / mgroup; - const size_t y = x - new_im1 * mgroup; - const size_t new_ing = y / simd_n_in; - const size_t new_isni = y - new_ing * simd_n_in; - - const size_t new_im2 = new_im1 / bm_div_mgroup; - const size_t new_ibm = new_im1 - new_im2 * bm_div_mgroup; - - const size_t base_im = new_im2 * c2_fac0_div + new_ibm * c2_fac2_div + new_isni; - const size_t buf_base = im * bits * K_div_g + ib * K_div_g; - - const uint8_t shift = static_cast(new_ing * g); - const size_t stride = c2_fac3_div; - - assert(K_div_g % kfactor == 0 && "K_div_g must be divisible by kfactor"); - for (size_t ik = 0; ik < K_div_g; ik += kfactor) { - const size_t new_ik = ik / kfactor; - const size_t base_k = base_im + new_ik * c2_fac1_div; - const size_t buf_k = buf_base + ik; - - uint8_t* dst = packed_u8 + base_k; - const uint8_t* src = buf.get() + buf_k; - - if (kfactor == 8) { - dst[stride * 0] = static_cast(dst[stride * 0] + (src[0] << shift)); - dst[stride * 1] = static_cast(dst[stride * 1] + (src[1] << shift)); - dst[stride * 2] = static_cast(dst[stride * 2] + (src[2] << shift)); - dst[stride * 3] = static_cast(dst[stride * 3] + (src[3] << shift)); - dst[stride * 4] = static_cast(dst[stride * 4] + (src[4] << shift)); - dst[stride * 5] = static_cast(dst[stride * 5] + (src[5] << shift)); - dst[stride * 6] = static_cast(dst[stride * 6] + (src[6] << shift)); - dst[stride * 7] = static_cast(dst[stride * 7] + (src[7] << shift)); - } else if (kfactor == 16) { - dst[stride * 0] = static_cast(dst[stride * 0] + (src[0] << shift)); - dst[stride * 1] = static_cast(dst[stride * 1] + (src[1] << shift)); - dst[stride * 2] = static_cast(dst[stride * 2] + (src[2] << shift)); - dst[stride * 3] = static_cast(dst[stride * 3] + (src[3] << shift)); - dst[stride * 4] = static_cast(dst[stride * 4] + (src[4] << shift)); - dst[stride * 5] = static_cast(dst[stride * 5] + (src[5] << shift)); - dst[stride * 6] = static_cast(dst[stride * 6] + (src[6] << shift)); - dst[stride * 7] = static_cast(dst[stride * 7] + (src[7] << shift)); - dst[stride * 8] = static_cast(dst[stride * 8] + (src[8] << shift)); - dst[stride * 9] = static_cast(dst[stride * 9] + (src[9] << shift)); - dst[stride * 10] = static_cast(dst[stride * 10] + (src[10] << shift)); - dst[stride * 11] = static_cast(dst[stride * 11] + (src[11] << shift)); - dst[stride * 12] = static_cast(dst[stride * 12] + (src[12] << shift)); - dst[stride * 13] = static_cast(dst[stride * 13] + (src[13] << shift)); - dst[stride * 14] = static_cast(dst[stride * 14] + (src[14] << shift)); - dst[stride * 15] = static_cast(dst[stride * 15] + (src[15] << shift)); - } else { - for (size_t ikf = 0; ikf < kfactor; ikf++) { - dst[stride * ikf] = static_cast(dst[stride * ikf] + (src[ikf] << shift)); + const size_t im_start = static_cast(tid) * im_per_tile; + const size_t im_end = std::min(im_start + im_per_tile, N); + + for (size_t im = im_start; im < im_end; ++im) { + const size_t im0 = im / simd_n_out; + const size_t isno = im - im0 * simd_n_out; + const size_t x_base = simd_n_out * (im0 * bits) + 
isno; + + for (size_t ib = 0; ib < bits; ib++) { + const size_t x = x_base + ib * simd_n_out; + const size_t new_im1 = x / mgroup; + const size_t y = x - new_im1 * mgroup; + const size_t new_ing = y / simd_n_in; + const size_t new_isni = y - new_ing * simd_n_in; + + const size_t new_im2 = new_im1 / bm_div_mgroup; + const size_t new_ibm = new_im1 - new_im2 * bm_div_mgroup; + + const size_t base_im = new_im2 * c2_fac0_div + new_ibm * c2_fac2_div + new_isni; + const size_t buf_base = im * bits * K_div_g + ib * K_div_g; + + const uint8_t shift = static_cast(new_ing * g); + const size_t stride = c2_fac3_div; + + assert(K_div_g % kfactor == 0 && "K_div_g must be divisible by kfactor"); + for (size_t ik = 0; ik < K_div_g; ik += kfactor) { + const size_t new_ik = ik / kfactor; + const size_t base_k = base_im + new_ik * c2_fac1_div; + const size_t buf_k = buf_base + ik; + + uint8_t* dst = packed_u8 + base_k; + const uint8_t* src = buf.get() + buf_k; + + if (kfactor == 8) { + dst[stride * 0] = static_cast(dst[stride * 0] + (src[0] << shift)); + dst[stride * 1] = static_cast(dst[stride * 1] + (src[1] << shift)); + dst[stride * 2] = static_cast(dst[stride * 2] + (src[2] << shift)); + dst[stride * 3] = static_cast(dst[stride * 3] + (src[3] << shift)); + dst[stride * 4] = static_cast(dst[stride * 4] + (src[4] << shift)); + dst[stride * 5] = static_cast(dst[stride * 5] + (src[5] << shift)); + dst[stride * 6] = static_cast(dst[stride * 6] + (src[6] << shift)); + dst[stride * 7] = static_cast(dst[stride * 7] + (src[7] << shift)); + } else if (kfactor == 16) { + dst[stride * 0] = static_cast(dst[stride * 0] + (src[0] << shift)); + dst[stride * 1] = static_cast(dst[stride * 1] + (src[1] << shift)); + dst[stride * 2] = static_cast(dst[stride * 2] + (src[2] << shift)); + dst[stride * 3] = static_cast(dst[stride * 3] + (src[3] << shift)); + dst[stride * 4] = static_cast(dst[stride * 4] + (src[4] << shift)); + dst[stride * 5] = static_cast(dst[stride * 5] + (src[5] << shift)); + dst[stride * 6] = static_cast(dst[stride * 6] + (src[6] << shift)); + dst[stride * 7] = static_cast(dst[stride * 7] + (src[7] << shift)); + dst[stride * 8] = static_cast(dst[stride * 8] + (src[8] << shift)); + dst[stride * 9] = static_cast(dst[stride * 9] + (src[9] << shift)); + dst[stride * 10] = static_cast(dst[stride * 10] + (src[10] << shift)); + dst[stride * 11] = static_cast(dst[stride * 11] + (src[11] << shift)); + dst[stride * 12] = static_cast(dst[stride * 12] + (src[12] << shift)); + dst[stride * 13] = static_cast(dst[stride * 13] + (src[13] << shift)); + dst[stride * 14] = static_cast(dst[stride * 14] + (src[14] << shift)); + dst[stride * 15] = static_cast(dst[stride * 15] + (src[15] << shift)); + } else { + for (size_t ikf = 0; ikf < kfactor; ikf++) { + dst[stride * ikf] = static_cast(dst[stride * ikf] + (src[ikf] << shift)); + } } } } - } + } // end for im } ); } From d213a75fe09bfcdaab1ee1ef3bd381d269682d38 Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 3 Feb 2026 11:16:43 -0800 Subject: [PATCH 06/10] Update comment Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 53215e81f74e8..efc7a17763e9f 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -163,10 +163,10 @@ LutGemmPackQuantBData( const size_t bm = tmac_params.bm; const size_t kfactor = 
tmac_params.kfactor; - // LUT GEMM is only available for AVX2, so dispatch must be available + // LUT GEMM requires a valid LUT dispatch implementation, so dispatch must be available const auto* Dispatch = GetMlasPlatform().LutGenKernel; if (Dispatch == nullptr || Dispatch->PackQuantBData == nullptr) { - MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires AVX2 dispatch"); + MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires LUT GEMM dispatch support"); } Dispatch->PackQuantBData( From a3ef325145af57ac9129525d2be191c59866c3fd Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 4 Feb 2026 19:23:33 +0000 Subject: [PATCH 07/10] Cleanup --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 45 +++++++++++---- .../mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | 56 +++++++++---------- 2 files changed, 62 insertions(+), 39 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index efc7a17763e9f..73e185075cc19 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -22,20 +22,47 @@ Module Name: #include #include +#include #include -#include #include #include +/** T-MAC GEMM kernel config key - struct-based for type safety and performance */ +struct TMACConfigKey { + size_t M; + size_t N; + size_t nbits; + size_t block_size; + bool has_zero_point; + + bool operator==(const TMACConfigKey& other) const { + return M == other.M && N == other.N && nbits == other.nbits && + block_size == other.block_size && has_zero_point == other.has_zero_point; + } +}; + +struct TMACConfigKeyHash { + size_t operator()(const TMACConfigKey& k) const { + // Combine hash values using a simple mixing function + size_t h = std::hash{}(k.M); + h ^= std::hash{}(k.N) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(k.nbits) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(k.block_size) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(k.has_zero_point) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } +}; + /** T-MAC GEMM kernel Config */ -static std::unordered_map tmac_kernel_configs; +static std::unordered_map tmac_kernel_configs; const MlasTMACKernelParams& MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point) { - std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0"); - if (tmac_kernel_configs.count(key)) { - return tmac_kernel_configs[key]; + TMACConfigKey key{M, N, nbits, block_size, has_zero_point}; + auto it = tmac_kernel_configs.find(key); + if (it != tmac_kernel_configs.end()) { + return it->second; } MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized"); } @@ -49,7 +76,7 @@ MlasClearLutGemmKernelConfig() void MLASCALL MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point) { - std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0"); + TMACConfigKey key{M, N, nbits, block_size, has_zero_point}; if (tmac_kernel_configs.count(key)) { return; } @@ -489,10 +516,8 @@ MlasLutGemm( size_t scales_size_per_tile = 0; if (scales_size_total % n_tiles_num != 0) { - // Sanity: scales should partition evenly across tiles. If they don't, choose floor division - // and document that callers must layout scales accordingly. - // Prefer to error loudly in debug builds. 
- fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tiles_num=%zu; using floor division.\n", scales_size_total, n_tiles_num); + // Scales must partition evenly across tiles. Callers must ensure proper layout. + MLAS_THROW_EX(std::runtime_error, "scales_size_total must be divisible by n_tiles_num"); } scales_size_per_tile = scales_size_total / n_tiles_num; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index 15f32e6a5e845..cc8308675e399 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -122,6 +122,7 @@ struct SignedHalvingAdder<2> { template struct SignedWideningAdder { + static_assert(N > 0, "N parameter exists for API compatibility with SignedHalvingAdder"); __m256i lhs_low = _mm256_setzero_si256(); __m256i lhs_high = _mm256_setzero_si256(); @@ -494,13 +495,6 @@ tbl_g4_int8_float_update_impl(int32_t m, float* c, const int8_t* lut, const uint return 0; } -int32_t -tbl_int32_reset(int32_t m, int32_t* c) -{ - memset(c, 0, m * sizeof(int32_t)); - return 0; -} - // based on qgemm_lut_int8_g4 // Simplified version with hardcoded configuration for 2-bit quantization void @@ -551,19 +545,14 @@ TMACComputeGemm_avx2( assert(K % (kfactor * g) == 0); assert(BlkLen % g == 0); - // Validate configuration - assert(bm % bits == 0); - assert(K % (kfactor * g) == 0); - assert(BlkLen % g == 0); - // ==================== ALLOCATE BUFFERS ==================== - // Use float for now (can be changed to _Float16 if needed) + // Use unique_ptr for exception safety (RAII ensures cleanup on all exit paths) - float* CBits = new float[bm]; - float* C_global = new float[m]; + std::unique_ptr CBits(new float[bm]); + std::unique_ptr C_global(new float[m]); // Reset accumulator buffer to zero - tbl_int32_reset(bm * sizeof(float) / sizeof(int32_t), reinterpret_cast(CBits)); + std::memset(CBits.get(), 0, bm * sizeof(float)); // ==================== CALCULATE LOOP PARAMETERS ==================== const int32_t k_outer_max = K / (kfactor * g); @@ -608,43 +597,43 @@ TMACComputeGemm_avx2( // For standard 2-bit, kfactor=16, BlkLen=64: actk = 64/4 = 16 if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && has_zero_point && !one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && !one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } // actk == 8 variants (for BlkLen=32) else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + 
static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } // kfactor == 8 variants else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { tbl_g4_int8_float_update_impl( - static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases ); } else { // No matching kernel template found @@ -656,11 +645,9 @@ TMACComputeGemm_avx2( // Gather bit-plane results into final output // Only support 2-bit in this implementation // TODO(vraspar): extend to other bit-widths - tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, C); + tbl_g4_int8_float_gather_bit2_impl(m, C_global.get(), CBits.get(), C); - // ==================== CLEANUP ==================== - delete[] C_global; - delete[] CBits; + // unique_ptr automatically handles cleanup via RAII } // @@ -794,6 +781,12 @@ PackQuantBData_avx2( // Phase 2: Parallelize over tiles of im values that share output bytes. // Consecutive im0 values (ngroups_per_elem of them) write to the same output bytes // with different shifts, so they must be processed by the same thread to avoid races. + // + // Thread-safety invariant: Each tile k maps exclusively to output indices where + // new_im1 = k. For tile k processing im ∈ [k*im_per_tile, (k+1)*im_per_tile), + // x = simd_n_out * (im0 * bits) + isno + ib * simd_n_out ranges over [32k, 32k+31], + // so new_im1 = x / mgroup = x / 32 = k (since mgroup = 32). Different tiles write + // to disjoint new_im1 values, ensuring no data races between parallel threads. 
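// Concrete instance (with simd_n_out = 8, the value the [32k, 32k+31] range above
// implies, so im_per_tile = 16): tile k = 1 covers im in [16, 31]; then im0 is 2 or 3,
// x = 16 * im0 + isno + 8 * ib lies in [32, 63], and new_im1 = x / 32 = 1 for every
// element of the tile, so no other tile touches the bytes with new_im1 == 1.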
const size_t im_per_tile = ngroups_per_elem * simd_n_out; const size_t num_tiles = (N + im_per_tile - 1) / im_per_tile; MlasTrySimpleParallel( @@ -889,6 +882,11 @@ PackScalesAndZeroPoints_avx2_impl( MLAS_THREADPOOL* ThreadPool ) { + // Validate that QuantBZeroPoint is provided when HasZeroPoint is true + if constexpr (HasZeroPoint) { + assert(QuantBZeroPoint != nullptr && "QuantBZeroPoint must not be null when HasZeroPoint is true"); + } + const size_t num_elem_per_byte = 8 / bits; // 4 for 2-bit const size_t row_blks = K / BlkLen; // number of blocks per column const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; From 76951ba6c807a4753179bd63444af58083c769c9 Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 4 Feb 2026 22:22:59 +0000 Subject: [PATCH 08/10] Init neon kernel for lut 2 bit gemm --- cmake/onnxruntime_mlas.cmake | 4 + onnxruntime/core/mlas/lib/mlasi.h | 4 + onnxruntime/core/mlas/lib/platform.cpp | 3 + .../mlas/lib/sqnbitgemm_lut_kernel_avx2.h | 39 +- .../mlas/lib/sqnbitgemm_lut_kernel_neon.cpp | 966 ++++++++++++++++++ .../mlas/lib/sqnbitgemm_lut_kernel_neon.h | 27 + 6 files changed, 1015 insertions(+), 28 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index d7dcde945e6d7..8c97dfbfbc1da 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -101,6 +101,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp ${MLAS_SRC_DIR}/cast_kernel_neon.cpp ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h @@ -470,6 +472,8 @@ else() ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index e75ca3dc90e60..d6299af5d8f80 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1245,6 +1245,10 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH; extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2; +#if defined(MLAS_TARGET_ARM64) +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon; +#endif + // // Rotary embedding dispatch structure. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index b913b1c3b8c26..b7242dd4704db 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -654,6 +654,9 @@ Return Value: this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? 
false : true; this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions); + // Enable LUT-based GEMM for 2-bit quantization on ARM64 + this->LutGenKernel = &MlasLutGenKernelNeon; + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h index e66eec6fd67ea..e900080403d51 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h @@ -10,34 +10,17 @@ Module Name: Abstract: - This module implements x64 AVX2 kernel functions for LUT-based n-bit - quantized integer matrix multiplication. + This module contains the dispatch table declaration for x64 AVX2 + LUT-based n-bit quantized integer matrix multiplication kernels. + --*/ #pragma once -#include "qnbitgemm.h" - -void -GenerateLUT_avx2( - int32_t group_size, - int8_t lut, - const float* b, - float* scales, - float* biases, - int K -); - -void -TMACComputeGemm_avx2( - const void* A, - const void* a_scales, - const void* LUT, - const void* LUT_Scales, - const void* LUT_Biases, - void* C, - int bm, - int K, - int M, - int N, - size_t BlkLen -); + +#include "qlutgemm.h" + +// +// External dispatch table for AVX2 LUT GEMM kernels. +// Kernel functions are internal to the .cpp file and accessed via this dispatch. +// +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp new file mode 100644 index 0000000000000..7732969314f52 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp @@ -0,0 +1,966 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_lut_kernel_neon.cpp + +Abstract: + + This module implements ARM64 NEON kernel functions for LUT-based quantized + n-bit integer matrix multiplication. + + It provides optimized ARM NEON implementations for lookup table generation, + GEMM computation, and related operations on quantized weight and activation + matrices. 
+ + Inspired by T-MAC implementation in llama.cpp (https://github.com/microsoft/T-MAC) + +--*/ + +#if defined(MLAS_TARGET_ARM64) + +#include + +#include +#include +#include +#include +#include +#include + +#include "mlasi.h" +#include "qlutgemm.h" +#include "qnbitgemm.h" +#include "sqnbitgemm_q8_block.h" + +// Conditional pragma unroll for compiler compatibility +#if defined(__clang__) +#define PRAGMA_UNROLL _Pragma("unroll") +#else +#define PRAGMA_UNROLL +#endif + +// +// Template classes for accumulation - adapted from llama.cpp tbl.cpp +// + +// Fast aggregation using halving add (vrhaddq_s8) +// Used when ActK is a power of 2 for faster accumulation +template +struct SignedHalvingAdder { + SignedHalvingAdder adder; + int8x16_t lhs; + + inline void push(int8x16_t v, int k) + { + if (k < N / 2) { + adder.push(v, k); + if (k == N / 2 - 1) { + lhs = adder.get(); + } + } else { + adder.push(v, k - N / 2); + if (k == N - 1) { + lhs = vrhaddq_s8(lhs, adder.get()); + } + } + } + + inline int8x16_t get() + { + return lhs; + } + + inline int16x8_t get_low() + { + return vmovl_s8(vget_low_s8(lhs)); + } + + inline int16x8_t get_high() + { + return vmovl_high_s8(lhs); + } +}; + +template <> +struct SignedHalvingAdder<2> { + int8x16_t lhs; + + inline void push(int8x16_t v, int k) + { + if (k == 0) { + lhs = v; + } else { + lhs = vrhaddq_s8(lhs, v); + } + } + + inline int8x16_t get() + { + return lhs; + } + + inline int16x8_t get_low() + { + return vmovl_s8(vget_low_s8(lhs)); + } + + inline int16x8_t get_high() + { + return vmovl_high_s8(lhs); + } +}; + +// Widening adder for accuracy (no fast aggregation) +// Used when precision is more important than speed +template +struct SignedWideningAdder { + static_assert(N > 0, "N parameter exists for API compatibility with SignedHalvingAdder"); + int16x8_t lhs_low = vdupq_n_s16(0); + int16x8_t lhs_high = vdupq_n_s16(0); + + inline void push(int8x16_t v, int k) + { + if (k == 0) { + lhs_low = vmovl_s8(vget_low_s8(v)); + lhs_high = vmovl_high_s8(v); + } else { + lhs_low = vaddq_s16(lhs_low, vmovl_s8(vget_low_s8(v))); + lhs_high = vaddq_s16(lhs_high, vmovl_high_s8(v)); + } + } + + inline int16x8_t get_low() + { + return lhs_low; + } + + inline int16x8_t get_high() + { + return lhs_high; + } +}; + +template +using SignedAdder = typename std::conditional, SignedWideningAdder>::type; + +// Template for computing log2 at compile time +template +struct mylog2 { + enum { + value = 1 + mylog2::value + }; +}; + +template <> +struct mylog2<0> { + enum { + value = -1 + }; +}; + +// Template for computing bias scale at compile time +template +constexpr int +get_bias_scale() +{ + // The bias scale will be added to the first bit + // 15 = (1/2 + 1 + 2 + 4) / (1/2) + // 7 = (1/2 + 1 + 2) / (1/2) + // 3 = (1/2 + 1) / (1/2) + // 1 = (1/2) / (1/2) + return 3; // For 2-bit quantization +} + +// +// Partial max computation for LUT scale calculation +// +static inline void +partial_max_g4_int8_k8_neon(float* lut_scales, const float* b) +{ + // Process 8 groups of 4 floats each (strided by 4) + float32x4_t max_abs = vdupq_n_f32(0.0f); + + for (int i = 0; i < 8; i++) { + // Load 4 consecutive floats from position i*4 + float32x4_t vals = vld1q_f32(b + i * 4); + float32x4_t abs_vals = vabsq_f32(vals); + max_abs = vmaxq_f32(max_abs, abs_vals); + } + + // Horizontal max across the vector + float max_val = vmaxvq_f32(max_abs); + float scales = max_val / 127.0f; + *lut_scales = std::max(*lut_scales, scales); +} + +// +// LUT construction for int8 quantized activations +// Builds 
16-entry lookup tables for groups of 4 activation values +// +static inline void +lut_ctor_g4_int8_impl_neon( + int32_t act_k, + int8_t* qlut, + const float* b, + float* lut_scales, + float* lut_biases +) +{ + float32x4_t vec_lut[16]; + float biases = 0.0f; + float scales = *lut_scales; + float t_scales = scales ? 1.0f / scales : 0.0f; + + for (int k = 0; k < act_k / 32; ++k) { + // Load 4 groups of 8 floats (strided pattern) + // ORT uses contiguous float layout, so we load and rearrange + float32x4_t vec_b0, vec_b1, vec_b2, vec_b3; + + // Load first 4 elements from each group of 4 + // Pattern: b[k*32 + i*4 + j] where i=0..7, j=0..3 + // We need vec_b0 = {b[0], b[4], b[8], b[12], b[16], b[20], b[24], b[28]} etc. + // For NEON with float32, we work with 4 elements at a time + + // Simplified: process 4 lanes at a time + for (int lane = 0; lane < 2; lane++) { + const float* base = b + k * 32 + lane * 16; + + // Load 4 values with stride 4 + float b0_vals[4] = {base[0], base[4], base[8], base[12]}; + float b1_vals[4] = {base[1], base[5], base[9], base[13]}; + float b2_vals[4] = {base[2], base[6], base[10], base[14]}; + float b3_vals[4] = {base[3], base[7], base[11], base[15]}; + + vec_b0 = vld1q_f32(b0_vals); + vec_b1 = vld1q_f32(b1_vals); + vec_b2 = vld1q_f32(b2_vals); + vec_b3 = vld1q_f32(b3_vals); + + // Build 16-entry LUT: each entry is ±b0 ±b1 ±b2 ±b3 + PRAGMA_UNROLL + for (int g = 1; g < 16; g += 2) { + vec_lut[g] = vec_b0; + if (g & 0b0010) { + vec_lut[g] = vaddq_f32(vec_lut[g], vec_b1); + } else { + vec_lut[g] = vsubq_f32(vec_lut[g], vec_b1); + } + if (g & 0b0100) { + vec_lut[g] = vaddq_f32(vec_lut[g], vec_b2); + } else { + vec_lut[g] = vsubq_f32(vec_lut[g], vec_b2); + } + if (g & 0b1000) { + vec_lut[g] = vaddq_f32(vec_lut[g], vec_b3); + } else { + vec_lut[g] = vsubq_f32(vec_lut[g], vec_b3); + } + } + + // Symmetric: vec_lut[g] = -vec_lut[15 - g] + PRAGMA_UNROLL + for (int g = 0; g < 16; g += 2) { + vec_lut[g] = vnegq_f32(vec_lut[15 - g]); + } + + // Accumulate bias + biases += vaddvq_f32(vec_lut[0]); + + // Scale and quantize + PRAGMA_UNROLL + for (int g = 0; g < 16; ++g) { + vec_lut[g] = vmulq_n_f32(vec_lut[g], t_scales); + } + + // Convert to int8 and store + int8_t* qlut_dst = qlut + k * 128 + lane * 64; // 8 * 16 / 2 = 64 + + PRAGMA_UNROLL + for (int g = 0; g < 16; ++g) { + // Round and convert to int32 + int32x4_t i32 = vcvtnq_s32_f32(vec_lut[g]); + // Narrow to int16 + int16x4_t i16 = vqmovn_s32(i32); + // Narrow to int8 + int8x8_t i8 = vqmovn_s16(vcombine_s16(i16, i16)); + + // Store individual lanes with proper layout + qlut_dst[g + 0 * 16] = vget_lane_s8(i8, 0); + qlut_dst[g + 1 * 16] = vget_lane_s8(i8, 1); + qlut_dst[g + 2 * 16] = vget_lane_s8(i8, 2); + qlut_dst[g + 3 * 16] = vget_lane_s8(i8, 3); + } + } + } + + *lut_scales = scales; + *lut_biases = biases; +} + +// +// GenerateLUT - Entry point for LUT generation +// +static void +GenerateLUT_neon( + const float* b, + int8_t* qlut, + float* lut_scales, + float* lut_biases, + size_t M, + size_t K, + size_t N, + size_t act_group_size +) +{ + (void)M; // silence unused parameter warning + (void)N; // silence unused parameter warning + + const int32_t kk_outer_max = static_cast(K / act_group_size); + const int32_t ags_div32 = static_cast(act_group_size / 32); + + // Phase 1: Compute partial max for each activation group + for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { + lut_scales[kk_outer] = 0.0f; + for (int32_t k_outer = 0; k_outer < ags_div32; ++k_outer) { + 
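+            // Each call folds one 32-float slice of activation group kk_outer into
+            // lut_scales[kk_outer], which keeps the running maximum candidate scale.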
partial_max_g4_int8_k8_neon(&lut_scales[kk_outer], &b[(kk_outer * act_group_size) + (k_outer * 32)]); + } + } + + // Phase 2: Build quantized LUT + for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { + lut_ctor_g4_int8_impl_neon( + static_cast(act_group_size), + &qlut[k_outer_1 * act_group_size * 4], + &b[k_outer_1 * act_group_size], + &lut_scales[k_outer_1], + &lut_biases[k_outer_1] + ); + } +} + +// +// Bit gathering for 2-bit results +// +static inline void +tbl_g4_int8_float_gather_bit2_impl_neon(int32_t m, float* C_global, float* CBits, float* C) +{ + constexpr int32_t bits = 2; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = m_c_outer * 32 * bits; + int32_t cse_var_1 = m_c_outer * 32; + + PRAGMA_UNROLL + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = bit_offset_0 + 8; + C_global[cse_var_1 + m_c_inner] = + (CBits[cse_var_2 + bit_offset_0] * 0.5f) + CBits[cse_var_2 + bit_offset_1]; + } + } + + // Copy to output + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + PRAGMA_UNROLL + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } +} + +// +// Core GEMM compute kernel using table lookup +// +template +inline int32_t +tbl_g4_int8_float_update_impl_neon( + int32_t m, + float* c, + const int8_t* lut, + const uint8_t* a, + const float* scales, + const float* lut_scales, + const float* lut_biases +) +{ + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + int8x16_t vec_lut[K]; + + // Load LUT tables + PRAGMA_UNROLL + for (int k = 0; k < K; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + + SignedAdder adder_bot, adder_top; + + for (int i = 0; i < m / 2; i += 16) { + float32x4_t vec_c0 = vdupq_n_f32(0.0f); + float32x4_t vec_c1 = vdupq_n_f32(0.0f); + float32x4_t vec_c2 = vdupq_n_f32(0.0f); + float32x4_t vec_c3 = vdupq_n_f32(0.0f); + float32x4_t vec_c4 = vdupq_n_f32(0.0f); + float32x4_t vec_c5 = vdupq_n_f32(0.0f); + float32x4_t vec_c6 = vdupq_n_f32(0.0f); + float32x4_t vec_c7 = vdupq_n_f32(0.0f); + + float partial_sum = 0.0f; + + PRAGMA_UNROLL + for (int kk = 0; kk < K; kk += ActK) { + PRAGMA_UNROLL + for (int k = 0; k < ActK; k++) { + // Load packed 4-bit indices + uint8x16_t vec_as = vld1q_u8(a + i * K + (kk + k) * 16); + + // Extract nibbles + uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); // Lower 4 bits + uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4); // Upper 4 bits + + // TABLE LOOKUP - THE KEY OPERATION + int8x16_t vec_v_bot = vqtbl1q_s8(vec_lut[kk + k], vreinterpretq_u8_s8(vreinterpretq_s8_u8(vec_a_bot))); + int8x16_t vec_v_top = vqtbl1q_s8(vec_lut[kk + k], vreinterpretq_u8_s8(vreinterpretq_s8_u8(vec_a_top))); + + adder_bot.push(vec_v_bot, k); + adder_top.push(vec_v_top, k); + } + + // Widen to int16 + int16x8_t vec_v_bot_low = adder_bot.get_low(); + int16x8_t vec_v_bot_high = adder_bot.get_high(); + int16x8_t vec_v_top_low = adder_top.get_low(); + int16x8_t vec_v_top_high = adder_top.get_high(); + + // Convert to float32 (need to widen int16 -> int32 -> float32) + float32x4_t vec_v_bot_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_bot_low))); + float32x4_t vec_v_bot_low_high = vcvtq_f32_s32(vmovl_high_s16(vec_v_bot_low)); + float32x4_t vec_v_bot_high_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_bot_high))); + float32x4_t vec_v_bot_high_high = 
vcvtq_f32_s32(vmovl_high_s16(vec_v_bot_high)); + float32x4_t vec_v_top_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_top_low))); + float32x4_t vec_v_top_low_high = vcvtq_f32_s32(vmovl_high_s16(vec_v_top_low)); + float32x4_t vec_v_top_high_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_top_high))); + float32x4_t vec_v_top_high_high = vcvtq_f32_s32(vmovl_high_s16(vec_v_top_high)); + + float lut_s = lut_scales[kk / ActK]; + float lut_b = lut_biases[kk / ActK]; + + if (ZeroPoint) { + partial_sum += lut_b; + } + + if (FastAggregation) { + lut_s = lut_s * ActK; + lut_b -= lut_s * (mylog2::value / 4 * get_bias_scale()); + } + + // FMA operations with conditional bias +#define lut_fma(vs, ib) \ + (((ib) % Bits) ? vmulq_n_f32((vs), lut_s) : vmlaq_n_f32(vdupq_n_f32(lut_b), (vs), lut_s)) + + if (kk == 0) { + vec_c0 = lut_fma(vec_v_bot_low_low, (i / 4)); + vec_c1 = lut_fma(vec_v_bot_low_high, (i / 4 + 1)); + vec_c2 = lut_fma(vec_v_bot_high_low, (i / 4 + 2)); + vec_c3 = lut_fma(vec_v_bot_high_high, (i / 4 + 3)); + vec_c4 = lut_fma(vec_v_top_low_low, (i / 4 + 4)); + vec_c5 = lut_fma(vec_v_top_low_high, (i / 4 + 5)); + vec_c6 = lut_fma(vec_v_top_high_low, (i / 4 + 6)); + vec_c7 = lut_fma(vec_v_top_high_high, (i / 4 + 7)); + } else { + vec_c0 = vaddq_f32(vec_c0, lut_fma(vec_v_bot_low_low, (i / 4))); + vec_c1 = vaddq_f32(vec_c1, lut_fma(vec_v_bot_low_high, (i / 4 + 1))); + vec_c2 = vaddq_f32(vec_c2, lut_fma(vec_v_bot_high_low, (i / 4 + 2))); + vec_c3 = vaddq_f32(vec_c3, lut_fma(vec_v_bot_high_high, (i / 4 + 3))); + vec_c4 = vaddq_f32(vec_c4, lut_fma(vec_v_top_low_low, (i / 4 + 4))); + vec_c5 = vaddq_f32(vec_c5, lut_fma(vec_v_top_low_high, (i / 4 + 5))); + vec_c6 = vaddq_f32(vec_c6, lut_fma(vec_v_top_high_low, (i / 4 + 6))); + vec_c7 = vaddq_f32(vec_c7, lut_fma(vec_v_top_high_high, (i / 4 + 7))); + } +#undef lut_fma + } + + // Apply weight scales and store + if (ZeroPoint) { + float32x4_t vec_s0 = vld1q_f32(scales + ((i / 4) / Bits) * 16); + float32x4_t vec_s1 = vld1q_f32(scales + ((i / 4 + 1) / Bits) * 16); + float32x4_t vec_s2 = vld1q_f32(scales + ((i / 4 + 2) / Bits) * 16 + 4); + float32x4_t vec_s3 = vld1q_f32(scales + ((i / 4 + 3) / Bits) * 16 + 4); + + vec_c0 = vfmaq_f32(vld1q_f32(c + i * 2), vec_c0, vec_s0); + vec_c1 = vfmaq_f32(vld1q_f32(c + i * 2 + 4), vec_c1, vec_s1); + vec_c2 = vfmaq_f32(vld1q_f32(c + i * 2 + 8), vec_c2, vec_s2); + vec_c3 = vfmaq_f32(vld1q_f32(c + i * 2 + 12), vec_c3, vec_s3); + + float32x4_t vec_z0 = vld1q_f32(scales + ((i / 4) / Bits) * 16 + 8); + float32x4_t vec_z1 = vld1q_f32(scales + ((i / 4 + 1) / Bits) * 16 + 8); + float32x4_t vec_z2 = vld1q_f32(scales + ((i / 4 + 2) / Bits) * 16 + 12); + float32x4_t vec_z3 = vld1q_f32(scales + ((i / 4 + 3) / Bits) * 16 + 12); + + partial_sum *= 2; + +#define add_zero(cs, zs, ib) \ + (((ib) % Bits) ? 
(cs) : vfmaq_n_f32((cs), (zs), partial_sum)) + + vst1q_f32(c + i * 2, add_zero(vec_c0, vec_z0, (i / 4))); + vst1q_f32(c + i * 2 + 4, add_zero(vec_c1, vec_z1, (i / 4 + 1))); + vst1q_f32(c + i * 2 + 8, add_zero(vec_c2, vec_z2, (i / 4 + 2))); + vst1q_f32(c + i * 2 + 12, add_zero(vec_c3, vec_z3, (i / 4 + 3))); + vst1q_f32(c + i * 2 + 16, vec_c4); + vst1q_f32(c + i * 2 + 20, vec_c5); + vst1q_f32(c + i * 2 + 24, vec_c6); + vst1q_f32(c + i * 2 + 28, vec_c7); +#undef add_zero + } else if (OneScale) { + float single_scale = scales[0]; + float32x4_t vec_s = vdupq_n_f32(single_scale); + + vst1q_f32(c + i * 2, vfmaq_f32(vld1q_f32(c + i * 2), vec_c0, vec_s)); + vst1q_f32(c + i * 2 + 4, vfmaq_f32(vld1q_f32(c + i * 2 + 4), vec_c1, vec_s)); + vst1q_f32(c + i * 2 + 8, vfmaq_f32(vld1q_f32(c + i * 2 + 8), vec_c2, vec_s)); + vst1q_f32(c + i * 2 + 12, vfmaq_f32(vld1q_f32(c + i * 2 + 12), vec_c3, vec_s)); + vst1q_f32(c + i * 2 + 16, vfmaq_f32(vld1q_f32(c + i * 2 + 16), vec_c4, vec_s)); + vst1q_f32(c + i * 2 + 20, vfmaq_f32(vld1q_f32(c + i * 2 + 20), vec_c5, vec_s)); + vst1q_f32(c + i * 2 + 24, vfmaq_f32(vld1q_f32(c + i * 2 + 24), vec_c6, vec_s)); + vst1q_f32(c + i * 2 + 28, vfmaq_f32(vld1q_f32(c + i * 2 + 28), vec_c7, vec_s)); + } else { + float32x4_t vec_s0 = vld1q_f32(scales + ((i / 4) / Bits) * 8); + float32x4_t vec_s1 = vld1q_f32(scales + ((i / 4 + 1) / Bits) * 8); + float32x4_t vec_s2 = vld1q_f32(scales + ((i / 4 + 2) / Bits) * 8 + 4); + float32x4_t vec_s3 = vld1q_f32(scales + ((i / 4 + 3) / Bits) * 8 + 4); + + vst1q_f32(c + i * 2, vfmaq_f32(vld1q_f32(c + i * 2), vec_c0, vec_s0)); + vst1q_f32(c + i * 2 + 4, vfmaq_f32(vld1q_f32(c + i * 2 + 4), vec_c1, vec_s1)); + vst1q_f32(c + i * 2 + 8, vfmaq_f32(vld1q_f32(c + i * 2 + 8), vec_c2, vec_s2)); + vst1q_f32(c + i * 2 + 12, vfmaq_f32(vld1q_f32(c + i * 2 + 12), vec_c3, vec_s3)); + vst1q_f32(c + i * 2 + 16, vfmaq_f32(vld1q_f32(c + i * 2 + 16), vec_c4, vec_s0)); + vst1q_f32(c + i * 2 + 20, vfmaq_f32(vld1q_f32(c + i * 2 + 20), vec_c5, vec_s1)); + vst1q_f32(c + i * 2 + 24, vfmaq_f32(vld1q_f32(c + i * 2 + 24), vec_c6, vec_s2)); + vst1q_f32(c + i * 2 + 28, vfmaq_f32(vld1q_f32(c + i * 2 + 28), vec_c7, vec_s3)); + } + } + + return 0; +} + +// +// TMACComputeGemm - Entry point for GEMM computation +// +static void +TMACComputeGemm_neon( + const uint8_t* A, + const float* Scales, + const int8_t* LUT, + const float* LUT_Scales, + const float* LUT_Biases, + float* C, + int K, + int M, + int N, + size_t BlkLen, + bool HasZeroPoint +) +{ + // Validate batch size + if (N != 1) { + MLAS_THROW_EX(std::runtime_error, "N > 1 is not supported yet"); + } + + // Get kernel config + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(M, K, 2, BlkLen, HasZeroPoint); + + // Configuration + bool has_zero_point = tmac_params.has_zero_point; + bool one_scale = tmac_params.one_scale; + + const int32_t bits = static_cast(tmac_params.bits); + const int32_t g = static_cast(tmac_params.g); + const int32_t ngroups_per_elem = static_cast(tmac_params.ngroups_per_elem); + const int32_t kfactor = static_cast(tmac_params.kfactor); + + const bool has_scale = tmac_params.has_scale; + + const int32_t q_group_size = static_cast(tmac_params.q_group_size); + const int32_t act_group_size = static_cast(tmac_params.act_group_size); + const int32_t actk = static_cast(tmac_params.actk); + + const int32_t bm = static_cast(tmac_params.bm); + int32_t m = bm / bits; + + // Validate configuration + assert(bm % bits == 0); + assert(K % (kfactor * g) == 0); + assert(BlkLen % g == 0); + + // Allocate 
buffers + std::unique_ptr CBits(new float[bm]); + std::unique_ptr C_global(new float[m]); + + std::memset(CBits.get(), 0, bm * sizeof(float)); + + // Calculate loop parameters + const int32_t k_outer_max = K / (kfactor * g); + const int32_t scale_gs = q_group_size / (kfactor * g); + + int32_t scale_idx_shfr = 0; + if (scale_gs == 1) { + scale_idx_shfr = 0; + } else if (scale_gs == 2) { + scale_idx_shfr = 1; + } else if (scale_gs == 4) { + scale_idx_shfr = 2; + } else if (scale_gs == 8) { + scale_idx_shfr = 3; + } else { + MLAS_THROW_EX(std::runtime_error, "Unsupported scale_gs configuration"); + } + + // Main computation loop + for (int32_t k_outer = 0; k_outer < k_outer_max; k_outer++) { + const uint8_t* a = A + k_outer * bm * kfactor / ngroups_per_elem; + + const float* scales = one_scale ? reinterpret_cast(Scales) : + (has_zero_point ? reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m * 2 : + reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m); + + const int8_t* lut = reinterpret_cast(LUT) + k_outer * kfactor * (1 << g); + const float* lut_scales = reinterpret_cast(LUT_Scales) + + (k_outer * kfactor * g / act_group_size); + const float* lut_biases = reinterpret_cast(LUT_Biases) + + (k_outer * kfactor * g / act_group_size); + + // Select appropriate kernel template + if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } + // actk == 8 variants (for BlkLen=32) + else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } + // kfactor == 8 variants + else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl_neon( + static_cast(bm), CBits.get(), lut, a, scales, lut_scales, lut_biases + ); + } else { + MLAS_THROW_EX(std::runtime_error, "No matching kernel found for T-MAC GEMM"); + } + } + + // Gather results + tbl_g4_int8_float_gather_bit2_impl_neon(m, C_global.get(), CBits.get(), C); +} + +// +// Weight packing for NEON (can use 
scalar or NEON implementation) +// This is done during model load, so performance is less critical +// +static void +PackQuantBData_neon( + size_t N, + size_t K, + size_t bits, + size_t g, + size_t ngroups_per_elem, + size_t simd_n_in, + size_t simd_n_out, + size_t bm, + size_t kfactor, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + // Only optimized for 2-bit, g=4, ngroups_per_elem=2 + assert(bits == 2 && g == 4 && ngroups_per_elem == 2); + + const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 + const size_t K_div_g = K / g; + + // Phase 1: Bit-plane decomposition + std::unique_ptr buf(new uint8_t[N * bits * K_div_g]); + + // Parallelize over N + MlasTrySimpleParallel( + ThreadPool, static_cast(N), + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + + const uint8_t* src_row = reinterpret_cast(QuantBDataBegin) + (im * K / 4); + uint8_t* dst_bit0 = buf.get() + im * bits * K_div_g; + uint8_t* dst_bit1 = dst_bit0 + K_div_g; + + // Initialize to zero + std::memset(dst_bit0, 0, K_div_g); + std::memset(dst_bit1, 0, K_div_g); + + // NEON-accelerated bit extraction + size_t ik = 0; + const uint8x16_t mask_2bit = vdupq_n_u8(0x03); + const uint8x16_t mask_bit0 = vdupq_n_u8(0x01); + + // Process 64 elements at a time (16 bytes input = 64 2-bit elements) + for (; ik + 64 <= K; ik += 64) { + uint8x16_t packed = vld1q_u8(src_row + ik / 4); + + // Extract each of 4 positions + uint8x16_t pos0 = vandq_u8(packed, mask_2bit); + uint8x16_t pos1 = vandq_u8(vshrq_n_u8(packed, 2), mask_2bit); + uint8x16_t pos2 = vandq_u8(vshrq_n_u8(packed, 4), mask_2bit); + uint8x16_t pos3 = vshrq_n_u8(packed, 6); + + // Extract bit 0 from each position + uint8x16_t b0_pos0 = vandq_u8(pos0, mask_bit0); + uint8x16_t b0_pos1 = vandq_u8(pos1, mask_bit0); + uint8x16_t b0_pos2 = vandq_u8(pos2, mask_bit0); + uint8x16_t b0_pos3 = vandq_u8(pos3, mask_bit0); + + // Combine for bit 0 plane + uint8x16_t bit0_out = vorrq_u8( + vorrq_u8(b0_pos0, vshlq_n_u8(b0_pos1, 1)), + vorrq_u8(vshlq_n_u8(b0_pos2, 2), vshlq_n_u8(b0_pos3, 3)) + ); + + // Extract bit 1 from each position + uint8x16_t b1_pos0 = vandq_u8(vshrq_n_u8(pos0, 1), mask_bit0); + uint8x16_t b1_pos1 = vandq_u8(vshrq_n_u8(pos1, 1), mask_bit0); + uint8x16_t b1_pos2 = vandq_u8(vshrq_n_u8(pos2, 1), mask_bit0); + uint8x16_t b1_pos3 = vandq_u8(vshrq_n_u8(pos3, 1), mask_bit0); + + // Combine for bit 1 plane + uint8x16_t bit1_out = vorrq_u8( + vorrq_u8(b1_pos0, vshlq_n_u8(b1_pos1, 1)), + vorrq_u8(vshlq_n_u8(b1_pos2, 2), vshlq_n_u8(b1_pos3, 3)) + ); + + vst1q_u8(dst_bit0 + ik / g, bit0_out); + vst1q_u8(dst_bit1 + ik / g, bit1_out); + } + + // Handle remaining elements with scalar code + for (; ik < K; ++ik) { + size_t idx = ik; + size_t num_elem_per_byte = 4; + size_t elem_idx = idx % num_elem_per_byte; + uint8_t v = src_row[idx / num_elem_per_byte] >> (elem_idx * bits); + + size_t new_ik = ik / g; + size_t shft_left = ik % g; + dst_bit0[new_ik] += static_cast(((v >> 0) & 1) << shft_left); + dst_bit1[new_ik] += static_cast(((v >> 1) & 1) << shft_left); + } + } + ); + + // Phase 2: Multi-reshape/transpose into final layout + const size_t bm_div_mgroup = bm / mgroup; + + const size_t c2_fac3_div = simd_n_in; + const size_t c2_fac2_div = kfactor * c2_fac3_div; + const size_t c2_fac1_div = bm_div_mgroup * c2_fac2_div; + const size_t c2_fac0_div = K_div_g * bm_div_mgroup * simd_n_in; + + const size_t PackedQuantBDataSize = (N * bits) * (K_div_g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); + 
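+    // Packed output format (illustrative, for the 2-bit / g = 4 case handled here):
+    // every output byte receives two 4-bit bit-plane groups, the new_ing = 0 group in
+    // the low nibble and the new_ing = 1 group shifted left by g into the high nibble.
+    // For reference, a source byte 0xE4 (weights {0, 1, 2, 3}) contributes 0x0A to the
+    // bit-0 plane and 0x0C to the bit-1 plane in the Phase 1 buffer above.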
auto* packed_u8 = reinterpret_cast(PackedQuantBDataBegin); + + const size_t im_per_tile = ngroups_per_elem * simd_n_out; + const size_t num_tiles = (N + im_per_tile - 1) / im_per_tile; + + MlasTrySimpleParallel( + ThreadPool, static_cast(num_tiles), + [&](ptrdiff_t tid) { + const size_t im_start = static_cast(tid) * im_per_tile; + const size_t im_end = std::min(im_start + im_per_tile, N); + + for (size_t im = im_start; im < im_end; ++im) { + const size_t im0 = im / simd_n_out; + const size_t isno = im - im0 * simd_n_out; + const size_t x_base = simd_n_out * (im0 * bits) + isno; + + for (size_t ib = 0; ib < bits; ib++) { + const size_t x = x_base + ib * simd_n_out; + const size_t new_im1 = x / mgroup; + const size_t y = x - new_im1 * mgroup; + const size_t new_ing = y / simd_n_in; + const size_t new_isni = y - new_ing * simd_n_in; + + const size_t new_im2 = new_im1 / bm_div_mgroup; + const size_t new_ibm = new_im1 - new_im2 * bm_div_mgroup; + + const size_t base_im = new_im2 * c2_fac0_div + new_ibm * c2_fac2_div + new_isni; + const size_t buf_base = im * bits * K_div_g + ib * K_div_g; + + const uint8_t shift = static_cast(new_ing * g); + const size_t stride = c2_fac3_div; + + for (size_t ik = 0; ik < K_div_g; ik += kfactor) { + const size_t new_ik = ik / kfactor; + const size_t base_k = base_im + new_ik * c2_fac1_div; + const size_t buf_k = buf_base + ik; + + uint8_t* dst = packed_u8 + base_k; + const uint8_t* src = buf.get() + buf_k; + + for (size_t ikf = 0; ikf < kfactor; ikf++) { + dst[stride * ikf] = static_cast(dst[stride * ikf] + (src[ikf] << shift)); + } + } + } + } + } + ); +} + +// +// Scales and zero points packing +// +template +static void +PackScalesAndZeroPoints_neon_impl( + size_t N, + size_t K, + size_t bits, + size_t BlkLen, + size_t simd_n_out, + size_t bm, + float* PackedScalesBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool +) +{ + if constexpr (HasZeroPoint) { + assert(QuantBZeroPoint != nullptr); + } + + const size_t num_elem_per_byte = 8 / bits; + const size_t row_blks = K / BlkLen; + const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; + + const size_t nb1 = K / BlkLen; + const size_t bm_div_bits = bm / bits; + const int midpoint = 1 << (bits - 1); + const uint8_t bits_mask = static_cast((1 << bits) - 1); + + MlasTrySimpleParallel( + ThreadPool, static_cast(N), + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + const size_t new_im = (bm_div_bits > 0) ? (im / bm_div_bits) : 0; + const size_t new_ibm = (bm_div_bits > 0) ? 
(im - new_im * bm_div_bits) : 0; + + if constexpr (HasZeroPoint) { + const size_t new_isimd = new_ibm % simd_n_out; + const size_t new_ibm_div_simd = new_ibm / simd_n_out; + const size_t outer_base = new_im * (bm_div_bits * nb1 / simd_n_out) + new_ibm_div_simd; + const size_t outer_stride = bm_div_bits / simd_n_out; + + for (size_t blk_in_col = 0; blk_in_col < row_blks; blk_in_col++) { + const size_t idx = im * nb1 + blk_in_col; + const float scale = QuantBScale[idx]; + + size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; + size_t elem_idx = blk_in_col % num_elem_per_byte; + uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & bits_mask; + float zp = static_cast(static_cast(v) - midpoint) * scale; + + const size_t new_idx_outer = outer_base + blk_in_col * outer_stride; + const size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + const size_t new_idx_zero = new_idx_scale + simd_n_out; + + PackedScalesBegin[new_idx_scale] = scale; + PackedScalesBegin[new_idx_zero] = zp; + } + } else { + const size_t base_idx = new_im * bm_div_bits * nb1 + new_ibm; + const size_t stride_idx = bm_div_bits; + + for (size_t blk_in_col = 0; blk_in_col < row_blks; blk_in_col++) { + const size_t idx = im * nb1 + blk_in_col; + const float scale = QuantBScale[idx]; + const size_t new_idx = base_idx + blk_in_col * stride_idx; + PackedScalesBegin[new_idx] = scale; + } + } + } + ); +} + +static void +PackScalesAndZeroPoints_neon( + size_t N, + size_t K, + size_t bits, + size_t BlkLen, + size_t simd_n_out, + size_t bm, + bool HasZeroPoint, + float* PackedScalesBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(bits == 2); + + if (HasZeroPoint) { + PackScalesAndZeroPoints_neon_impl( + N, K, bits, BlkLen, simd_n_out, bm, + PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool + ); + } else { + PackScalesAndZeroPoints_neon_impl( + N, K, bits, BlkLen, simd_n_out, bm, + PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool + ); + } +} + +// +// Kernel dispatch structure definition +// +const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon = []() { + MLAS_QNBIT_LUT_GEMM_DISPATCH d; + d.GenerateLUT = GenerateLUT_neon; + d.ComputeGemm = TMACComputeGemm_neon; + d.PackQuantBData = PackQuantBData_neon; + d.PackScalesAndZeroPoints = PackScalesAndZeroPoints_neon; + return d; +}(); + +#endif // MLAS_TARGET_ARM64 diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h new file mode 100644 index 0000000000000..f8710c8a90c0b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h @@ -0,0 +1,27 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_lut_kernel_neon.h + +Abstract: + + This module contains the dispatch table declaration for ARM64 NEON + LUT-based n-bit quantized integer matrix multiplication kernels. + +--*/ + +#pragma once + +#include "qlutgemm.h" + +// +// External dispatch table for ARM NEON LUT GEMM kernels. +// Kernel functions are internal to the .cpp file and accessed via this dispatch. 
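+// On ARM64 builds, platform.cpp installs this table as the platform LutGenKernel
+// (this->LutGenKernel = &MlasLutGenKernelNeon), so callers reach these kernels
+// through the platform dispatch.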
+// +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon; + From bc085c569eec323e631e0f56f6a7f7ba6dc662dc Mon Sep 17 00:00:00 2001 From: Vrajang Parikh Date: Fri, 6 Feb 2026 21:32:27 +0000 Subject: [PATCH 09/10] working neon kernels --- .../mlas/lib/sqnbitgemm_lut_kernel_neon.cpp | 444 +++++---- .../unittest/test_sqlutgemm_components.cpp | 876 ++++++++++++++++++ 2 files changed, 1133 insertions(+), 187 deletions(-) create mode 100644 onnxruntime/test/mlas/unittest/test_sqlutgemm_components.cpp diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp index 7732969314f52..23cc244eb28fd 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp @@ -21,6 +21,8 @@ Module Name: --*/ +#include "mlas.h" + #if defined(MLAS_TARGET_ARM64) #include @@ -177,30 +179,33 @@ get_bias_scale() } // -// Partial max computation for LUT scale calculation +// Partial max computation for LUT scale calculation - SCALAR VERSION +// Computes: max(|b0| + |b1| + |b2| + |b3|) for 8 groups of 4 consecutive elements +// This is a direct port of the AVX2 algorithm to scalar for correctness verification // static inline void partial_max_g4_int8_k8_neon(float* lut_scales, const float* b) { - // Process 8 groups of 4 floats each (strided by 4) - float32x4_t max_abs = vdupq_n_f32(0.0f); - - for (int i = 0; i < 8; i++) { - // Load 4 consecutive floats from position i*4 - float32x4_t vals = vld1q_f32(b + i * 4); - float32x4_t abs_vals = vabsq_f32(vals); - max_abs = vmaxq_f32(max_abs, abs_vals); + // 8 groups of 4 consecutive elements each + // Groups: {0-3}, {4-7}, {8-11}, {12-15}, {16-19}, {20-23}, {24-27}, {28-31} + float max_abssum = 0.0f; + + for (int group = 0; group < 8; ++group) { + float abssum = std::abs(b[group * 4 + 0]) + + std::abs(b[group * 4 + 1]) + + std::abs(b[group * 4 + 2]) + + std::abs(b[group * 4 + 3]); + max_abssum = std::max(max_abssum, abssum); } - - // Horizontal max across the vector - float max_val = vmaxvq_f32(max_abs); - float scales = max_val / 127.0f; + + float scales = max_abssum / 127.0f; *lut_scales = std::max(*lut_scales, scales); } // -// LUT construction for int8 quantized activations -// Builds 16-entry lookup tables for groups of 4 activation values +// LUT construction - SCALAR VERSION +// This is a direct port of the AVX2 algorithm for correctness verification +// Output layout matches AVX2: qlut[k * 128 + group * 16 + lut_entry] // static inline void lut_ctor_g4_int8_impl_neon( @@ -211,93 +216,68 @@ lut_ctor_g4_int8_impl_neon( float* lut_biases ) { - float32x4_t vec_lut[16]; float biases = 0.0f; float scales = *lut_scales; float t_scales = scales ? 1.0f / scales : 0.0f; for (int k = 0; k < act_k / 32; ++k) { - // Load 4 groups of 8 floats (strided pattern) - // ORT uses contiguous float layout, so we load and rearrange - float32x4_t vec_b0, vec_b1, vec_b2, vec_b3; - - // Load first 4 elements from each group of 4 - // Pattern: b[k*32 + i*4 + j] where i=0..7, j=0..3 - // We need vec_b0 = {b[0], b[4], b[8], b[12], b[16], b[20], b[24], b[28]} etc. 
- // For NEON with float32, we work with 4 elements at a time - - // Simplified: process 4 lanes at a time - for (int lane = 0; lane < 2; lane++) { - const float* base = b + k * 32 + lane * 16; - - // Load 4 values with stride 4 - float b0_vals[4] = {base[0], base[4], base[8], base[12]}; - float b1_vals[4] = {base[1], base[5], base[9], base[13]}; - float b2_vals[4] = {base[2], base[6], base[10], base[14]}; - float b3_vals[4] = {base[3], base[7], base[11], base[15]}; - - vec_b0 = vld1q_f32(b0_vals); - vec_b1 = vld1q_f32(b1_vals); - vec_b2 = vld1q_f32(b2_vals); - vec_b3 = vld1q_f32(b3_vals); - - // Build 16-entry LUT: each entry is ±b0 ±b1 ±b2 ±b3 - PRAGMA_UNROLL + // For each of 8 groups of 4 consecutive elements + // Group g contains elements: b[k*32 + g*4 + 0..3] + float lut[16][8]; // lut[lut_entry][group] + + for (int group = 0; group < 8; ++group) { + // Get the 4 elements in this group + float b0 = b[k * 32 + group * 4 + 0]; + float b1 = b[k * 32 + group * 4 + 1]; + float b2 = b[k * 32 + group * 4 + 2]; + float b3 = b[k * 32 + group * 4 + 3]; + + // Build 16-entry LUT using ±b0 ±b1 ±b2 ±b3 + // Odd entries first (g = 1, 3, 5, ..., 15) for (int g = 1; g < 16; g += 2) { - vec_lut[g] = vec_b0; + float val = b0; if (g & 0b0010) { - vec_lut[g] = vaddq_f32(vec_lut[g], vec_b1); + val += b1; } else { - vec_lut[g] = vsubq_f32(vec_lut[g], vec_b1); + val -= b1; } if (g & 0b0100) { - vec_lut[g] = vaddq_f32(vec_lut[g], vec_b2); + val += b2; } else { - vec_lut[g] = vsubq_f32(vec_lut[g], vec_b2); + val -= b2; } if (g & 0b1000) { - vec_lut[g] = vaddq_f32(vec_lut[g], vec_b3); + val += b3; } else { - vec_lut[g] = vsubq_f32(vec_lut[g], vec_b3); + val -= b3; } + lut[g][group] = val; } - - // Symmetric: vec_lut[g] = -vec_lut[15 - g] - PRAGMA_UNROLL + + // Even entries: lut[g] = -lut[15 - g] for (int g = 0; g < 16; g += 2) { - vec_lut[g] = vnegq_f32(vec_lut[15 - g]); - } - - // Accumulate bias - biases += vaddvq_f32(vec_lut[0]); - - // Scale and quantize - PRAGMA_UNROLL - for (int g = 0; g < 16; ++g) { - vec_lut[g] = vmulq_n_f32(vec_lut[g], t_scales); + lut[g][group] = -lut[15 - g][group]; } - - // Convert to int8 and store - int8_t* qlut_dst = qlut + k * 128 + lane * 64; // 8 * 16 / 2 = 64 - - PRAGMA_UNROLL + } + + // Accumulate bias from lut[0] (sum across all 8 groups) + for (int group = 0; group < 8; ++group) { + biases += lut[0][group]; + } + + // Scale and quantize, then store + // Output layout: qlut[k * 128 + group * 16 + lut_entry] + for (int group = 0; group < 8; ++group) { for (int g = 0; g < 16; ++g) { - // Round and convert to int32 - int32x4_t i32 = vcvtnq_s32_f32(vec_lut[g]); - // Narrow to int16 - int16x4_t i16 = vqmovn_s32(i32); - // Narrow to int8 - int8x8_t i8 = vqmovn_s16(vcombine_s16(i16, i16)); - - // Store individual lanes with proper layout - qlut_dst[g + 0 * 16] = vget_lane_s8(i8, 0); - qlut_dst[g + 1 * 16] = vget_lane_s8(i8, 1); - qlut_dst[g + 2 * 16] = vget_lane_s8(i8, 2); - qlut_dst[g + 3 * 16] = vget_lane_s8(i8, 3); + float scaled = lut[g][group] * t_scales; + // Round to nearest, clamp to int8 range + int32_t rounded = static_cast(std::round(scaled)); + rounded = std::max(-128, std::min(127, rounded)); + qlut[k * 128 + group * 16 + g] = static_cast(rounded); } } } - + *lut_scales = scales; *lut_biases = biases; } @@ -376,7 +356,8 @@ tbl_g4_int8_float_gather_bit2_impl_neon(int32_t m, float* C_global, float* CBits } // -// Core GEMM compute kernel using table lookup +// Core GEMM compute kernel using table lookup - NEON FP32 VERSION +// Adapted from llama.cpp T-MAC FP16 NEON to 
use FP32 // template inline int32_t @@ -393,23 +374,18 @@ tbl_g4_int8_float_update_impl_neon( const uint8x16_t vec_mask = vdupq_n_u8(0x0f); int8x16_t vec_lut[K]; - // Load LUT tables + // Load LUT vectors PRAGMA_UNROLL for (int k = 0; k < K; k++) { vec_lut[k] = vld1q_s8(lut + k * 16); } SignedAdder adder_bot, adder_top; - + for (int i = 0; i < m / 2; i += 16) { - float32x4_t vec_c0 = vdupq_n_f32(0.0f); - float32x4_t vec_c1 = vdupq_n_f32(0.0f); - float32x4_t vec_c2 = vdupq_n_f32(0.0f); - float32x4_t vec_c3 = vdupq_n_f32(0.0f); - float32x4_t vec_c4 = vdupq_n_f32(0.0f); - float32x4_t vec_c5 = vdupq_n_f32(0.0f); - float32x4_t vec_c6 = vdupq_n_f32(0.0f); - float32x4_t vec_c7 = vdupq_n_f32(0.0f); + // For FP32, we need 8 vectors of 4 floats each to cover 32 outputs + // (compared to FP16's 4 vectors of 8 floats) + float32x4_t vec_c0, vec_c1, vec_c2, vec_c3, vec_c4, vec_c5, vec_c6, vec_c7; float partial_sum = 0.0f; @@ -417,36 +393,37 @@ tbl_g4_int8_float_update_impl_neon( for (int kk = 0; kk < K; kk += ActK) { PRAGMA_UNROLL for (int k = 0; k < ActK; k++) { - // Load packed 4-bit indices + // Load 16 packed bytes containing 32 4-bit indices uint8x16_t vec_as = vld1q_u8(a + i * K + (kk + k) * 16); - - // Extract nibbles - uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); // Lower 4 bits - uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4); // Upper 4 bits - - // TABLE LOOKUP - THE KEY OPERATION - int8x16_t vec_v_bot = vqtbl1q_s8(vec_lut[kk + k], vreinterpretq_u8_s8(vreinterpretq_s8_u8(vec_a_bot))); - int8x16_t vec_v_top = vqtbl1q_s8(vec_lut[kk + k], vreinterpretq_u8_s8(vreinterpretq_s8_u8(vec_a_top))); - - adder_bot.push(vec_v_bot, k); - adder_top.push(vec_v_top, k); + uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4); + uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); + + // Table lookup - get int8 values from LUT + // Note: vqtbl1q_s8 takes uint8x16_t as index type + int8x16_t vec_v_bot_tmp = vqtbl1q_s8(vec_lut[kk + k], vec_a_bot); + int8x16_t vec_v_top_tmp = vqtbl1q_s8(vec_lut[kk + k], vec_a_top); + + // Accumulate using appropriate adder + adder_bot.push(vec_v_bot_tmp, k); + adder_top.push(vec_v_top_tmp, k); } - // Widen to int16 - int16x8_t vec_v_bot_low = adder_bot.get_low(); - int16x8_t vec_v_bot_high = adder_bot.get_high(); - int16x8_t vec_v_top_low = adder_top.get_low(); - int16x8_t vec_v_top_high = adder_top.get_high(); - - // Convert to float32 (need to widen int16 -> int32 -> float32) - float32x4_t vec_v_bot_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_bot_low))); - float32x4_t vec_v_bot_low_high = vcvtq_f32_s32(vmovl_high_s16(vec_v_bot_low)); - float32x4_t vec_v_bot_high_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_bot_high))); - float32x4_t vec_v_bot_high_high = vcvtq_f32_s32(vmovl_high_s16(vec_v_bot_high)); - float32x4_t vec_v_top_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_top_low))); - float32x4_t vec_v_top_low_high = vcvtq_f32_s32(vmovl_high_s16(vec_v_top_low)); - float32x4_t vec_v_top_high_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec_v_top_high))); - float32x4_t vec_v_top_high_high = vcvtq_f32_s32(vmovl_high_s16(vec_v_top_high)); + // Get accumulated int16 values + int16x8_t sum_bot_low = adder_bot.get_low(); // bot elements 0-7 + int16x8_t sum_bot_high = adder_bot.get_high(); // bot elements 8-15 + int16x8_t sum_top_low = adder_top.get_low(); // top elements 0-7 + int16x8_t sum_top_high = adder_top.get_high(); // top elements 8-15 + + // Convert to FP32 - each int16x8_t becomes two float32x4_t + // vec_v_*_lo = first 4 elements, vec_v_*_hi = last 4 elements + 
float32x4_t vec_v_bot_low_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(sum_bot_low))); + float32x4_t vec_v_bot_low_hi = vcvtq_f32_s32(vmovl_high_s16(sum_bot_low)); + float32x4_t vec_v_bot_high_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(sum_bot_high))); + float32x4_t vec_v_bot_high_hi = vcvtq_f32_s32(vmovl_high_s16(sum_bot_high)); + float32x4_t vec_v_top_low_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(sum_top_low))); + float32x4_t vec_v_top_low_hi = vcvtq_f32_s32(vmovl_high_s16(sum_top_low)); + float32x4_t vec_v_top_high_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(sum_top_high))); + float32x4_t vec_v_top_high_hi = vcvtq_f32_s32(vmovl_high_s16(sum_top_high)); float lut_s = lut_scales[kk / ActK]; float lut_b = lut_biases[kk / ActK]; @@ -457,95 +434,188 @@ tbl_g4_int8_float_update_impl_neon( if (FastAggregation) { lut_s = lut_s * ActK; - lut_b -= lut_s * (mylog2::value / 4 * get_bias_scale()); + lut_b -= lut_s * (mylog2::value / 4.0f * get_bias_scale()); } - // FMA operations with conditional bias -#define lut_fma(vs, ib) \ - (((ib) % Bits) ? vmulq_n_f32((vs), lut_s) : vmlaq_n_f32(vdupq_n_f32(lut_b), (vs), lut_s)) + float32x4_t vec_lut_s = vdupq_n_f32(lut_s); + float32x4_t vec_lut_b = vdupq_n_f32(lut_b); + + // lut_fma: ((ib % Bits) ? (v * lut_s) : (v * lut_s + lut_b)) + // ib for each group: + // Group 0 (c0,c1): ib = i/4 + // Group 1 (c2,c3): ib = i/4 + 1 + // Group 2 (c4,c5): ib = i/4 + 2 + // Group 3 (c6,c7): ib = i/4 + 3 + + int ib0 = i / 4; + int ib1 = i / 4 + 1; + int ib2 = i / 4 + 2; + int ib3 = i / 4 + 3; + +#define LUT_FMA(vec_v, ib_val) \ + (((ib_val) % Bits) ? vmulq_f32(vec_v, vec_lut_s) : vmlaq_f32(vec_lut_b, vec_v, vec_lut_s)) if (kk == 0) { - vec_c0 = lut_fma(vec_v_bot_low_low, (i / 4)); - vec_c1 = lut_fma(vec_v_bot_low_high, (i / 4 + 1)); - vec_c2 = lut_fma(vec_v_bot_high_low, (i / 4 + 2)); - vec_c3 = lut_fma(vec_v_bot_high_high, (i / 4 + 3)); - vec_c4 = lut_fma(vec_v_top_low_low, (i / 4 + 4)); - vec_c5 = lut_fma(vec_v_top_low_high, (i / 4 + 5)); - vec_c6 = lut_fma(vec_v_top_high_low, (i / 4 + 6)); - vec_c7 = lut_fma(vec_v_top_high_high, (i / 4 + 7)); + vec_c0 = LUT_FMA(vec_v_bot_low_lo, ib0); + vec_c1 = LUT_FMA(vec_v_bot_low_hi, ib0); + vec_c2 = LUT_FMA(vec_v_bot_high_lo, ib1); + vec_c3 = LUT_FMA(vec_v_bot_high_hi, ib1); + vec_c4 = LUT_FMA(vec_v_top_low_lo, ib2); + vec_c5 = LUT_FMA(vec_v_top_low_hi, ib2); + vec_c6 = LUT_FMA(vec_v_top_high_lo, ib3); + vec_c7 = LUT_FMA(vec_v_top_high_hi, ib3); } else { - vec_c0 = vaddq_f32(vec_c0, lut_fma(vec_v_bot_low_low, (i / 4))); - vec_c1 = vaddq_f32(vec_c1, lut_fma(vec_v_bot_low_high, (i / 4 + 1))); - vec_c2 = vaddq_f32(vec_c2, lut_fma(vec_v_bot_high_low, (i / 4 + 2))); - vec_c3 = vaddq_f32(vec_c3, lut_fma(vec_v_bot_high_high, (i / 4 + 3))); - vec_c4 = vaddq_f32(vec_c4, lut_fma(vec_v_top_low_low, (i / 4 + 4))); - vec_c5 = vaddq_f32(vec_c5, lut_fma(vec_v_top_low_high, (i / 4 + 5))); - vec_c6 = vaddq_f32(vec_c6, lut_fma(vec_v_top_high_low, (i / 4 + 6))); - vec_c7 = vaddq_f32(vec_c7, lut_fma(vec_v_top_high_high, (i / 4 + 7))); + vec_c0 = vaddq_f32(vec_c0, LUT_FMA(vec_v_bot_low_lo, ib0)); + vec_c1 = vaddq_f32(vec_c1, LUT_FMA(vec_v_bot_low_hi, ib0)); + vec_c2 = vaddq_f32(vec_c2, LUT_FMA(vec_v_bot_high_lo, ib1)); + vec_c3 = vaddq_f32(vec_c3, LUT_FMA(vec_v_bot_high_hi, ib1)); + vec_c4 = vaddq_f32(vec_c4, LUT_FMA(vec_v_top_low_lo, ib2)); + vec_c5 = vaddq_f32(vec_c5, LUT_FMA(vec_v_top_low_hi, ib2)); + vec_c6 = vaddq_f32(vec_c6, LUT_FMA(vec_v_top_high_lo, ib3)); + vec_c7 = vaddq_f32(vec_c7, LUT_FMA(vec_v_top_high_hi, ib3)); } -#undef lut_fma +#undef LUT_FMA 
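+        // At this point vec_c0..vec_c7 accumulate, per activation block, the
+        // table-lookup sums scaled by that block's lut_s; the lut_b bias term is
+        // folded in only for rows whose bit-plane index (ib % Bits) is 0.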
} // Apply weight scales and store if (ZeroPoint) { - float32x4_t vec_s0 = vld1q_f32(scales + ((i / 4) / Bits) * 16); - float32x4_t vec_s1 = vld1q_f32(scales + ((i / 4 + 1) / Bits) * 16); - float32x4_t vec_s2 = vld1q_f32(scales + ((i / 4 + 2) / Bits) * 16 + 4); - float32x4_t vec_s3 = vld1q_f32(scales + ((i / 4 + 3) / Bits) * 16 + 4); - - vec_c0 = vfmaq_f32(vld1q_f32(c + i * 2), vec_c0, vec_s0); - vec_c1 = vfmaq_f32(vld1q_f32(c + i * 2 + 4), vec_c1, vec_s1); - vec_c2 = vfmaq_f32(vld1q_f32(c + i * 2 + 8), vec_c2, vec_s2); - vec_c3 = vfmaq_f32(vld1q_f32(c + i * 2 + 12), vec_c3, vec_s3); - - float32x4_t vec_z0 = vld1q_f32(scales + ((i / 4) / Bits) * 16 + 8); - float32x4_t vec_z1 = vld1q_f32(scales + ((i / 4 + 1) / Bits) * 16 + 8); - float32x4_t vec_z2 = vld1q_f32(scales + ((i / 4 + 2) / Bits) * 16 + 12); - float32x4_t vec_z3 = vld1q_f32(scales + ((i / 4 + 3) / Bits) * 16 + 12); - partial_sum *= 2; + float32x4_t vec_ps = vdupq_n_f32(partial_sum); + + // For ZeroPoint mode, scales are interleaved with zero points + // scales[base+0..7] = scales, scales[base+8..15] = zero_points + int base0 = ((i / 4) / Bits) * 16; + int base1 = ((i / 4 + 1) / Bits) * 16; + int base2 = ((i / 4 + 2) / Bits) * 16; + int base3 = ((i / 4 + 3) / Bits) * 16; + + // Load scales (first 8 of each 16-element group) + float32x4_t s0_lo = vld1q_f32(scales + base0); + float32x4_t s0_hi = vld1q_f32(scales + base0 + 4); + float32x4_t s1_lo = vld1q_f32(scales + base1); + float32x4_t s1_hi = vld1q_f32(scales + base1 + 4); + float32x4_t s2_lo = vld1q_f32(scales + base2); + float32x4_t s2_hi = vld1q_f32(scales + base2 + 4); + float32x4_t s3_lo = vld1q_f32(scales + base3); + float32x4_t s3_hi = vld1q_f32(scales + base3 + 4); + + // Load zero points (second 8 of each 16-element group) + float32x4_t z0_lo = vld1q_f32(scales + base0 + 8); + float32x4_t z0_hi = vld1q_f32(scales + base0 + 12); + float32x4_t z1_lo = vld1q_f32(scales + base1 + 8); + float32x4_t z1_hi = vld1q_f32(scales + base1 + 12); + float32x4_t z2_lo = vld1q_f32(scales + base2 + 8); + float32x4_t z2_hi = vld1q_f32(scales + base2 + 12); + float32x4_t z3_lo = vld1q_f32(scales + base3 + 8); + float32x4_t z3_hi = vld1q_f32(scales + base3 + 12); + + // Load previous C values + float32x4_t prev0 = vld1q_f32(c + i * 2); + float32x4_t prev1 = vld1q_f32(c + i * 2 + 4); + float32x4_t prev2 = vld1q_f32(c + i * 2 + 8); + float32x4_t prev3 = vld1q_f32(c + i * 2 + 12); + float32x4_t prev4 = vld1q_f32(c + i * 2 + 16); + float32x4_t prev5 = vld1q_f32(c + i * 2 + 20); + float32x4_t prev6 = vld1q_f32(c + i * 2 + 24); + float32x4_t prev7 = vld1q_f32(c + i * 2 + 28); + + // result = prev + acc * scale + (zero * partial_sum if ib % Bits == 0) + int ib0 = i / 4; + int ib1 = i / 4 + 1; + int ib2 = i / 4 + 2; + int ib3 = i / 4 + 3; + + vec_c0 = vmlaq_f32(prev0, vec_c0, s0_lo); + vec_c1 = vmlaq_f32(prev1, vec_c1, s0_hi); + vec_c2 = vmlaq_f32(prev2, vec_c2, s1_lo); + vec_c3 = vmlaq_f32(prev3, vec_c3, s1_hi); + vec_c4 = vmlaq_f32(prev4, vec_c4, s2_lo); + vec_c5 = vmlaq_f32(prev5, vec_c5, s2_hi); + vec_c6 = vmlaq_f32(prev6, vec_c6, s3_lo); + vec_c7 = vmlaq_f32(prev7, vec_c7, s3_hi); + + if ((ib0 % Bits) == 0) { + vec_c0 = vmlaq_f32(vec_c0, z0_lo, vec_ps); + vec_c1 = vmlaq_f32(vec_c1, z0_hi, vec_ps); + } + if ((ib1 % Bits) == 0) { + vec_c2 = vmlaq_f32(vec_c2, z1_lo, vec_ps); + vec_c3 = vmlaq_f32(vec_c3, z1_hi, vec_ps); + } + if ((ib2 % Bits) == 0) { + vec_c4 = vmlaq_f32(vec_c4, z2_lo, vec_ps); + vec_c5 = vmlaq_f32(vec_c5, z2_hi, vec_ps); + } + if ((ib3 % Bits) == 0) { + vec_c6 = vmlaq_f32(vec_c6, 
z3_lo, vec_ps); + vec_c7 = vmlaq_f32(vec_c7, z3_hi, vec_ps); + } -#define add_zero(cs, zs, ib) \ - (((ib) % Bits) ? (cs) : vfmaq_n_f32((cs), (zs), partial_sum)) - - vst1q_f32(c + i * 2, add_zero(vec_c0, vec_z0, (i / 4))); - vst1q_f32(c + i * 2 + 4, add_zero(vec_c1, vec_z1, (i / 4 + 1))); - vst1q_f32(c + i * 2 + 8, add_zero(vec_c2, vec_z2, (i / 4 + 2))); - vst1q_f32(c + i * 2 + 12, add_zero(vec_c3, vec_z3, (i / 4 + 3))); + // Store results + vst1q_f32(c + i * 2, vec_c0); + vst1q_f32(c + i * 2 + 4, vec_c1); + vst1q_f32(c + i * 2 + 8, vec_c2); + vst1q_f32(c + i * 2 + 12, vec_c3); vst1q_f32(c + i * 2 + 16, vec_c4); vst1q_f32(c + i * 2 + 20, vec_c5); vst1q_f32(c + i * 2 + 24, vec_c6); vst1q_f32(c + i * 2 + 28, vec_c7); -#undef add_zero } else if (OneScale) { - float single_scale = scales[0]; - float32x4_t vec_s = vdupq_n_f32(single_scale); - - vst1q_f32(c + i * 2, vfmaq_f32(vld1q_f32(c + i * 2), vec_c0, vec_s)); - vst1q_f32(c + i * 2 + 4, vfmaq_f32(vld1q_f32(c + i * 2 + 4), vec_c1, vec_s)); - vst1q_f32(c + i * 2 + 8, vfmaq_f32(vld1q_f32(c + i * 2 + 8), vec_c2, vec_s)); - vst1q_f32(c + i * 2 + 12, vfmaq_f32(vld1q_f32(c + i * 2 + 12), vec_c3, vec_s)); - vst1q_f32(c + i * 2 + 16, vfmaq_f32(vld1q_f32(c + i * 2 + 16), vec_c4, vec_s)); - vst1q_f32(c + i * 2 + 20, vfmaq_f32(vld1q_f32(c + i * 2 + 20), vec_c5, vec_s)); - vst1q_f32(c + i * 2 + 24, vfmaq_f32(vld1q_f32(c + i * 2 + 24), vec_c6, vec_s)); - vst1q_f32(c + i * 2 + 28, vfmaq_f32(vld1q_f32(c + i * 2 + 28), vec_c7, vec_s)); + float32x4_t vec_s = vdupq_n_f32(scales[0]); + + vec_c0 = vaddq_f32(vld1q_f32(c + i * 2), vmulq_f32(vec_c0, vec_s)); + vec_c1 = vaddq_f32(vld1q_f32(c + i * 2 + 4), vmulq_f32(vec_c1, vec_s)); + vec_c2 = vaddq_f32(vld1q_f32(c + i * 2 + 8), vmulq_f32(vec_c2, vec_s)); + vec_c3 = vaddq_f32(vld1q_f32(c + i * 2 + 12), vmulq_f32(vec_c3, vec_s)); + vec_c4 = vaddq_f32(vld1q_f32(c + i * 2 + 16), vmulq_f32(vec_c4, vec_s)); + vec_c5 = vaddq_f32(vld1q_f32(c + i * 2 + 20), vmulq_f32(vec_c5, vec_s)); + vec_c6 = vaddq_f32(vld1q_f32(c + i * 2 + 24), vmulq_f32(vec_c6, vec_s)); + vec_c7 = vaddq_f32(vld1q_f32(c + i * 2 + 28), vmulq_f32(vec_c7, vec_s)); + + vst1q_f32(c + i * 2, vec_c0); + vst1q_f32(c + i * 2 + 4, vec_c1); + vst1q_f32(c + i * 2 + 8, vec_c2); + vst1q_f32(c + i * 2 + 12, vec_c3); + vst1q_f32(c + i * 2 + 16, vec_c4); + vst1q_f32(c + i * 2 + 20, vec_c5); + vst1q_f32(c + i * 2 + 24, vec_c6); + vst1q_f32(c + i * 2 + 28, vec_c7); } else { - float32x4_t vec_s0 = vld1q_f32(scales + ((i / 4) / Bits) * 8); - float32x4_t vec_s1 = vld1q_f32(scales + ((i / 4 + 1) / Bits) * 8); - float32x4_t vec_s2 = vld1q_f32(scales + ((i / 4 + 2) / Bits) * 8 + 4); - float32x4_t vec_s3 = vld1q_f32(scales + ((i / 4 + 3) / Bits) * 8 + 4); - - vst1q_f32(c + i * 2, vfmaq_f32(vld1q_f32(c + i * 2), vec_c0, vec_s0)); - vst1q_f32(c + i * 2 + 4, vfmaq_f32(vld1q_f32(c + i * 2 + 4), vec_c1, vec_s1)); - vst1q_f32(c + i * 2 + 8, vfmaq_f32(vld1q_f32(c + i * 2 + 8), vec_c2, vec_s2)); - vst1q_f32(c + i * 2 + 12, vfmaq_f32(vld1q_f32(c + i * 2 + 12), vec_c3, vec_s3)); - vst1q_f32(c + i * 2 + 16, vfmaq_f32(vld1q_f32(c + i * 2 + 16), vec_c4, vec_s0)); - vst1q_f32(c + i * 2 + 20, vfmaq_f32(vld1q_f32(c + i * 2 + 20), vec_c5, vec_s1)); - vst1q_f32(c + i * 2 + 24, vfmaq_f32(vld1q_f32(c + i * 2 + 24), vec_c6, vec_s2)); - vst1q_f32(c + i * 2 + 28, vfmaq_f32(vld1q_f32(c + i * 2 + 28), vec_c7, vec_s3)); + // Symmetric quantization without zero points + int base0 = ((i / 4) / Bits) * 8; + int base1 = ((i / 4 + 1) / Bits) * 8; + int base2 = ((i / 4 + 2) / Bits) * 8; + int base3 = ((i / 4 + 3) 
/ Bits) * 8; + + float32x4_t s0_lo = vld1q_f32(scales + base0); + float32x4_t s0_hi = vld1q_f32(scales + base0 + 4); + float32x4_t s1_lo = vld1q_f32(scales + base1); + float32x4_t s1_hi = vld1q_f32(scales + base1 + 4); + float32x4_t s2_lo = vld1q_f32(scales + base2); + float32x4_t s2_hi = vld1q_f32(scales + base2 + 4); + float32x4_t s3_lo = vld1q_f32(scales + base3); + float32x4_t s3_hi = vld1q_f32(scales + base3 + 4); + + vec_c0 = vmlaq_f32(vld1q_f32(c + i * 2), vec_c0, s0_lo); + vec_c1 = vmlaq_f32(vld1q_f32(c + i * 2 + 4), vec_c1, s0_hi); + vec_c2 = vmlaq_f32(vld1q_f32(c + i * 2 + 8), vec_c2, s1_lo); + vec_c3 = vmlaq_f32(vld1q_f32(c + i * 2 + 12), vec_c3, s1_hi); + vec_c4 = vmlaq_f32(vld1q_f32(c + i * 2 + 16), vec_c4, s2_lo); + vec_c5 = vmlaq_f32(vld1q_f32(c + i * 2 + 20), vec_c5, s2_hi); + vec_c6 = vmlaq_f32(vld1q_f32(c + i * 2 + 24), vec_c6, s3_lo); + vec_c7 = vmlaq_f32(vld1q_f32(c + i * 2 + 28), vec_c7, s3_hi); + + vst1q_f32(c + i * 2, vec_c0); + vst1q_f32(c + i * 2 + 4, vec_c1); + vst1q_f32(c + i * 2 + 8, vec_c2); + vst1q_f32(c + i * 2 + 12, vec_c3); + vst1q_f32(c + i * 2 + 16, vec_c4); + vst1q_f32(c + i * 2 + 20, vec_c5); + vst1q_f32(c + i * 2 + 24, vec_c6); + vst1q_f32(c + i * 2 + 28, vec_c7); } } - + return 0; } diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm_components.cpp b/onnxruntime/test/mlas/unittest/test_sqlutgemm_components.cpp new file mode 100644 index 0000000000000..0048e8f8d5d06 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqlutgemm_components.cpp @@ -0,0 +1,876 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqlutgemm_components.cpp + +Abstract: + + Component tests for MLAS LUT-based n-bit GEMM (TMAC/LUT path). + These tests verify individual components (weight packing, scale packing) + by comparing SIMD implementations against scalar reference implementations. + + The scalar reference implementations are copied from the ORT main branch + qlutgemm.cpp to serve as ground truth. + +--*/ + +#include "test_util.h" +#include "mlas_qnbit.h" +#include "mlas_q4.h" + +#include +#include +#include +#include +#include + +// +// ============================================================================ +// SCALAR REFERENCE IMPLEMENTATIONS +// These are copied from onnxruntime main branch qlutgemm.cpp to serve as +// platform-independent ground truth for testing SIMD implementations. +// ============================================================================ +// + +namespace ScalarReference { + +/** + * @brief Calculates packed quantized B data size (same as LutGemmPackQuantBDataSize) + */ +static size_t +PackQuantBDataSize( + size_t N, + size_t bits, + size_t K, + size_t g, + size_t ngroups_per_elem +) +{ + return (N * bits) * (K / g / ngroups_per_elem); +} + +/** + * @brief Calculates packed scales/zp size in floats + */ +static size_t +PackScalesAndZeroPointsSize( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint +) +{ + if (HasZeroPoint) { + return N * K / BlkLen * 2; + } else { + return N * K / BlkLen; + } +} + +/** + * @brief Calculate the aligned offset to scales in the packed buffer + */ +static size_t +PackedScalesOffset( + size_t PackedQuantBDataSize +) +{ + constexpr size_t kAlignment = 64; + return ((PackedQuantBDataSize + kAlignment - 1) / kAlignment) * kAlignment; +} + +/** + * @brief Scalar reference implementation for packing quantized B data. + * + * This performs the T-MAC weight transformation: + * 1. 
Bit-plane decomposition with g=4 grouping + * 2. Multi-stage reshape/transpose for LUT-optimized layout + * + * Copied from: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/mlas/lib/qlutgemm.cpp + */ +static void +PackQuantBData_Reference( + size_t N, + size_t K, + size_t bits, + size_t g, + size_t ngroups_per_elem, + size_t simd_n_in, + size_t simd_n_out, + size_t bm, + size_t kfactor, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin +) +{ + assert(bits == 2 && g == 4 && ngroups_per_elem == 2); + + const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 + assert(bm % mgroup == 0); + assert(bm % bits == 0); + + std::unique_ptr buf(new uint8_t[N * bits * (K / g)]); + memset(buf.get(), 0, N * bits * (K / g)); + + // Phase 1: Bit-plane decomposition + for (size_t im = 0; im < N; ++im) { + for (size_t ik = 0; ik < K; ++ik) { + size_t idx = (im * K + ik); + size_t num_elem_per_byte = 8 / bits; + size_t elem_idx = idx % num_elem_per_byte; + + uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); + + for (size_t ib = 0; ib < bits; ++ib) { + size_t new_ik = ik / g; + size_t shft_left = ik % g; + buf[im * bits * K / g + ib * K / g + new_ik] += + static_cast(((v >> ib) & 1) << shft_left); + } + } + } + + // Phase 2: Multi-reshape/transpose into final layout + const size_t c0_fac2 = K / g; + const size_t c0_fac1 = simd_n_out * c0_fac2; + const size_t c0_fac0 = bits * c0_fac1; + + const size_t c1_nb2 = K / g; + const size_t c1_nb1 = simd_n_in * c1_nb2; + const size_t c1_nb0 = ngroups_per_elem * c1_nb1; + const size_t c1_fac2 = K / g; + const size_t c1_fac1 = ngroups_per_elem * c1_fac2; + const size_t c1_fac0 = simd_n_in * c1_fac1; + + const size_t c2_nb4 = kfactor; + const size_t c2_nb3 = K / g / kfactor * c2_nb4; + const size_t c2_nb2 = ngroups_per_elem * c2_nb3; + const size_t c2_nb1 = simd_n_in * c2_nb2; + const size_t c2_nb0 = bm / mgroup * c2_nb1; + const size_t c2_fac3 = simd_n_in * ngroups_per_elem; + const size_t c2_fac2 = kfactor * c2_fac3; + const size_t c2_fac1 = bm / mgroup * c2_fac2; + const size_t c2_fac0 = K / g / kfactor * c2_fac1; + + const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); + + for (size_t im = 0; im < N; ++im) { + for (size_t ib = 0; ib < bits; ib++) { + for (size_t ik = 0; ik < K / g; ik++) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + size_t new_im = im / simd_n_out; + size_t new_isno = im % simd_n_out; + size_t new_ib = ib; + size_t new_ik = ik; + size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + size_t new_ing = (new_idx % c1_nb0) / c1_nb1; + size_t new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + size_t new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + + new_ik * c2_fac1 + + new_ibm * c2_fac2 + + new_ikf * c2_fac3 + + new_isni * ngroups_per_elem + + new_ing; + 
new_idx = new_idx / ngroups_per_elem; + size_t buf_idx = im * bits * K / g + ib * K / g + ik; + uint8_t buf_val = buf[buf_idx]; + + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + PackedQuantBDataBegin[new_idx] = static_cast( + static_cast(PackedQuantBDataBegin[new_idx]) + + (buf_val << (new_ing * g)) + ); + } + } + } +} + +/** + * @brief Scalar reference implementation for packing scales and zero points. + * + * This transforms scales/zero-points to match the tiled weight layout. + * + * Copied from: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/mlas/lib/qlutgemm.cpp + */ +static void +PackScalesAndZeroPoints_Reference( + size_t N, + size_t K, + size_t bits, + size_t BlkLen, + size_t simd_n_out, + size_t bm, + bool HasZeroPoint, + float* PackedScalesBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint +) +{ + const size_t num_elem_per_byte = 8 / bits; + + // ZP array is column-major packed, with per-column alignment to byte boundary + const size_t row_blks = K / BlkLen; // number of blocks per column + const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; + + for (size_t im = 0; im < N; im += 1) { + for (size_t ik = 0; ik < K; ik += BlkLen) { + size_t idx = (im * K + ik) / BlkLen; // linear block index for scale + float scale = QuantBScale[idx]; + float zp = 0.0f; + if (HasZeroPoint) { + size_t blk_in_col = ik / BlkLen; // block index within column + size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; + size_t elem_idx = blk_in_col % num_elem_per_byte; + uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1); + + // The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit). 
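+ // Concretely: with q the stored 2-bit code, the true weight is (q - zp) * scale, while
+ // the kernel reconstructs (q - midpoint) * scale. The difference, (midpoint - zp) * scale,
+ // is constant per block, so (zp - midpoint) * scale is stored below as the correction term.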
+ int midpoint = 1 << (bits - 1); // 2 for 2-bit + zp = static_cast(static_cast(v) - midpoint) * scale; + } + + size_t nb1 = K / BlkLen; + size_t nb0 = bm / bits * nb1; + + size_t new_im, new_ibm, new_ik; + if (nb1 == 0) { + new_im = 0; + new_ibm = 0; + new_ik = 0; + } else { + new_im = idx / nb0; + new_ibm = (idx % nb0) / nb1; + new_ik = (idx % nb1); + } + + if (HasZeroPoint) { + size_t new_isimd = new_ibm % simd_n_out; + size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; + size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; + + PackedScalesBegin[new_idx_scale] = scale; + PackedScalesBegin[new_idx_zero] = zp; + } else { + size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; + PackedScalesBegin[new_idx] = scale; + } + } + } +} + +/** + * @brief Full packing reference combining weights and scales + * + * This mirrors the structure of MlasLutGemmPack + */ +static void +LutGemmPack_Reference( + size_t N, + size_t K, + size_t bits, + size_t BlkLen, + bool HasZeroPoint, + size_t g, + size_t ngroups_per_elem, + size_t simd_n_in, + size_t simd_n_out, + size_t bm, + size_t kfactor, + const std::byte* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + std::byte* PackedBuf +) +{ + // Pack B data + if (QuantBData != nullptr) { + PackQuantBData_Reference( + N, K, bits, g, ngroups_per_elem, + simd_n_in, simd_n_out, bm, kfactor, + QuantBData, PackedBuf + ); + } + + // Pack scales/zero points + if (QuantBScale != nullptr) { + size_t packed_b_size = PackQuantBDataSize(N, bits, K, g, ngroups_per_elem); + size_t scales_offset = PackedScalesOffset(packed_b_size); + float* scales_dest = reinterpret_cast(PackedBuf + scales_offset); + + PackScalesAndZeroPoints_Reference( + N, K, bits, BlkLen, simd_n_out, bm, + HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint + ); + } +} + +/** + * @brief Select optimal bm (tile size) for given dimensions + * + * This mirrors the logic in MlasInitLutGemmKernelConfig + */ +static size_t +SelectOptimalBm(size_t N, size_t bits) +{ + std::vector bms = {256, 512, 1024, 2048, 320, 640, 1280}; + + // Use a simple heuristic: pick the largest bm that divides N * bits evenly + for (size_t bm : bms) { + if (N % (bm / bits) == 0 && bm % bits == 0) { + return bm; + } + } + return bms[0]; // fallback +} + +/** + * @brief Select optimal kfactor + */ +static size_t +SelectOptimalKfactor(size_t BlkLen, size_t g, size_t actk) +{ + std::vector kfactors = {16, 8}; + + for (size_t kfactor : kfactors) { + if (kfactor >= actk && kfactor * g <= BlkLen) { + return kfactor; + } + } + return kfactors.back(); +} + +} // namespace ScalarReference + +// +// ============================================================================ +// TEST CLASSES +// ============================================================================ +// + +/** + * @brief Test class for verifying the full packing implementation. + * + * Compares the dispatched (NEON/AVX2) MlasLutGemmPack against the scalar reference. 
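+ *
+ * Test flow (see Test() below): quantize a random B via MlasQuantizeBlockwise, pack it
+ * once with the scalar reference and once through MlasLutGemmPack, then byte-compare the
+ * packed weight region and float-compare the packed scale/zero-point region.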
+ */ +template +class MlasSQLutGemmPackTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferQuantBData; + MatrixGuardBuffer BufferQuantBScale; + MatrixGuardBuffer BufferQuantBZeroPoint; + MatrixGuardBuffer BufferPackedExpected; + MatrixGuardBuffer BufferPackedActual; + + public: + void Test(size_t N, size_t K, bool Symmetric) { + MLAS_THREADPOOL* tp = GetMlasThreadPool(); + + // Clear config cache + MlasClearLutGemmKernelConfig(); + + // Generate random input matrix B + const float* B = BufferB.GetBuffer(N * K); + + // Quantize B + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes( + BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + uint8_t* QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); + float* QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); + uint8_t* QuantBZeroPoint = Symmetric ? nullptr : BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); + + MlasQuantizeBlockwise( + QuantBData, QuantBScale, QuantBZeroPoint, + B, BlkLen, + /* columnwise */ true, + static_cast(K), static_cast(N), + static_cast(N), + tp); + + // Initialize kernel config (this sets up the internal params) + MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, !Symmetric); + + // Get packed buffer size + size_t PackedBufSize = MlasLutGemmPackedSize(N, K, BlkBitWidth, BlkLen, !Symmetric); + + std::byte* PackedActual = BufferPackedActual.GetBuffer(PackedBufSize, true); + std::byte* PackedExpected = BufferPackedExpected.GetBuffer(PackedBufSize, true); + + // Fixed T-MAC parameters (these match MlasInitLutGemmKernelConfig) + constexpr size_t g = 4; + constexpr size_t ngroups_per_elem = 2; + constexpr size_t simd_n_in = 16; + constexpr size_t simd_n_out = 8; + + size_t bm = ScalarReference::SelectOptimalBm(N, BlkBitWidth); + size_t act_group_size = (BlkLen % 64 == 0) ? 
64 : 32; + size_t actk = act_group_size / g; + size_t kfactor = ScalarReference::SelectOptimalKfactor(BlkLen, g, actk); + + // Run scalar reference implementation + ScalarReference::LutGemmPack_Reference( + N, K, BlkBitWidth, BlkLen, !Symmetric, + g, ngroups_per_elem, simd_n_in, simd_n_out, bm, kfactor, + reinterpret_cast(QuantBData), + QuantBScale, + QuantBZeroPoint, + PackedExpected); + + // Run dispatched implementation via public API + MlasLutGemmPack( + N, K, BlkBitWidth, BlkLen, !Symmetric, + reinterpret_cast(QuantBData), + QuantBScale, + QuantBZeroPoint, + PackedActual, + tp); + + // Compare weight packing portion + size_t packed_b_size = ScalarReference::PackQuantBDataSize(N, BlkBitWidth, K, g, ngroups_per_elem); + + size_t weight_mismatch_count = 0; + constexpr size_t max_mismatches_to_report = 10; + for (size_t i = 0; i < packed_b_size; ++i) { + if (PackedExpected[i] != PackedActual[i]) { + if (weight_mismatch_count < max_mismatches_to_report) { + ADD_FAILURE() << "Weight packing mismatch at byte " << i << " of " << packed_b_size + << ": expected 0x" << std::hex << static_cast(static_cast(PackedExpected[i])) + << ", got 0x" << static_cast(static_cast(PackedActual[i])) << std::dec + << ", N=" << N << ", K=" << K << ", BlkLen=" << BlkLen + << ", bm=" << bm << ", kfactor=" << kfactor; + } + weight_mismatch_count++; + } + } + EXPECT_EQ(weight_mismatch_count, 0u) + << "Weight packing: Total mismatches: " << weight_mismatch_count << " out of " << packed_b_size << " bytes"; + + // Compare scales/zp packing portion + size_t scales_offset = ScalarReference::PackedScalesOffset(packed_b_size); + size_t scales_size = ScalarReference::PackScalesAndZeroPointsSize(N, K, BlkLen, !Symmetric); + + const float* ExpectedScales = reinterpret_cast(PackedExpected + scales_offset); + const float* ActualScales = reinterpret_cast(PackedActual + scales_offset); + + size_t scale_mismatch_count = 0; + for (size_t i = 0; i < scales_size; ++i) { + if (!CloseEnough(ActualScales[i], ExpectedScales[i])) { + if (scale_mismatch_count < max_mismatches_to_report) { + ADD_FAILURE() << "Scale/ZP packing mismatch at index " << i << " of " << scales_size + << ": expected " << ExpectedScales[i] << ", got " << ActualScales[i] + << ", N=" << N << ", K=" << K << ", BlkLen=" << BlkLen + << ", Symmetric=" << Symmetric; + } + scale_mismatch_count++; + } + } + EXPECT_EQ(scale_mismatch_count, 0u) + << "Scale packing: Total mismatches: " << scale_mismatch_count << " out of " << scales_size << " floats"; + } + + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQLutGemmPack") + + "BlkBitWidth" + std::to_string(BlkBitWidth) + + "BlkLen" + std::to_string(BlkLen); + return suite_name.c_str(); + } +}; + +// +// ============================================================================ +// LUT GENERATION SCALAR REFERENCE +// ============================================================================ +// + +namespace ScalarReference { + +/** + * @brief Scalar reference partial_max for LUT scale computation. + * Computes max(|b0| + |b1| + |b2| + |b3|) across 8 groups of 4 elements. 
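+ * The running scale is max_sum / 127: every LUT entry built from such a group is a signed
+ * sum (+/-b0 +/-b1 +/-b2 +/-b3) whose magnitude is at most the group's absolute sum, so
+ * dividing by this scale keeps all 16 quantized entries within int8 range.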
+ */ +static void +PartialMax_Reference(float* lut_scales, const float* b) +{ + // Process 32 floats organized as 8 groups of 4 consecutive elements + // Groups: {0-3}, {4-7}, {8-11}, {12-15}, {16-19}, {20-23}, {24-27}, {28-31} + float max_sum = 0.0f; + for (int group = 0; group < 8; ++group) { + float abssum = std::abs(b[group * 4 + 0]) + + std::abs(b[group * 4 + 1]) + + std::abs(b[group * 4 + 2]) + + std::abs(b[group * 4 + 3]); + max_sum = std::max(max_sum, abssum); + } + float scales = max_sum / 127.0f; + *lut_scales = std::max(*lut_scales, scales); +} + +/** + * @brief Scalar reference LUT construction. + * Builds 16-entry LUT for groups of 4 activation values. + */ +static void +LutCtor_Reference( + int32_t act_k, + int8_t* qlut, + const float* b, + float* lut_scales, + float* lut_biases +) +{ + float biases = 0.0f; + float scales = *lut_scales; + float t_scales = scales ? 1.0f / scales : 0.0f; + + for (int k = 0; k < act_k / 32; ++k) { + // For each of 8 groups of 4 elements + float lut[16][8]; // [lut_entry][group] + + for (int group = 0; group < 8; ++group) { + float b0 = b[k * 32 + group * 4 + 0]; + float b1 = b[k * 32 + group * 4 + 1]; + float b2 = b[k * 32 + group * 4 + 2]; + float b3 = b[k * 32 + group * 4 + 3]; + + // Build 16-entry LUT: each entry is ±b0 ±b1 ±b2 ±b3 + for (int g = 1; g < 16; g += 2) { + lut[g][group] = b0; + if (g & 0b0010) { + lut[g][group] += b1; + } else { + lut[g][group] -= b1; + } + if (g & 0b0100) { + lut[g][group] += b2; + } else { + lut[g][group] -= b2; + } + if (g & 0b1000) { + lut[g][group] += b3; + } else { + lut[g][group] -= b3; + } + } + // Symmetric: lut[g] = -lut[15 - g] + for (int g = 0; g < 16; g += 2) { + lut[g][group] = -lut[15 - g][group]; + } + } + + // Accumulate bias + for (int group = 0; group < 8; ++group) { + biases += lut[0][group]; + } + + // Scale and quantize, then store + // Output layout: qlut[k * 8 * 16 + group * 16 + lut_entry] + for (int group = 0; group < 8; ++group) { + for (int g = 0; g < 16; ++g) { + float scaled = lut[g][group] * t_scales; + int8_t quantized = static_cast(std::round(scaled)); + qlut[k * 8 * 16 + group * 16 + g] = quantized; + } + } + } + + *lut_scales = scales; + *lut_biases = biases; +} + +/** + * @brief Scalar reference GenerateLUT. 
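+ * Runs two phases per activation group of act_group_size values: a per-group scale from
+ * PartialMax_Reference, then a quantized table from LutCtor_Reference. Outputs: lut_scales
+ * and lut_biases hold one float per group (K / act_group_size entries), and qlut holds
+ * 16 int8 entries per 4 activations, i.e. K * 4 bytes in total.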
+ */ +static void +GenerateLUT_Reference( + const float* b, + int8_t* qlut, + float* lut_scales, + float* lut_biases, + size_t K, + size_t act_group_size +) +{ + const int32_t kk_outer_max = static_cast(K / act_group_size); + const int32_t ags_div32 = static_cast(act_group_size / 32); + + // Phase 1: Compute partial max for each activation group + for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { + lut_scales[kk_outer] = 0.0f; + for (int32_t k_outer = 0; k_outer < ags_div32; ++k_outer) { + PartialMax_Reference(&lut_scales[kk_outer], &b[(kk_outer * act_group_size) + (k_outer * 32)]); + } + } + + // Phase 2: Build quantized LUT + for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { + LutCtor_Reference( + static_cast(act_group_size), + &qlut[k_outer_1 * act_group_size * 4], + &b[k_outer_1 * act_group_size], + &lut_scales[k_outer_1], + &lut_biases[k_outer_1] + ); + } +} + +} // namespace ScalarReference + +// +// ============================================================================ +// LUT GENERATION TEST CLASS +// ============================================================================ +// + +template +class MlasSQLutGemmLutGenTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferActivation; + MatrixGuardBuffer BufferQLutExpected; + MatrixGuardBuffer BufferQLutActual; + MatrixGuardBuffer BufferLutScalesExpected; + MatrixGuardBuffer BufferLutScalesActual; + MatrixGuardBuffer BufferLutBiasesExpected; + MatrixGuardBuffer BufferLutBiasesActual; + + public: + void Test(size_t K) { + constexpr size_t BlkBitWidth = 2; + constexpr size_t N = 128; // Need a valid N for dispatch check + + // Check if LUT GEMM is available + if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) { + GTEST_SKIP() << "LUT GEMM not available for this configuration"; + return; + } + + // Determine activation group size (same logic as in NEON kernel) + size_t act_group_size = (BlkLen % 64 == 0) ? 64 : 32; + size_t lut_scales_count = K / act_group_size; + + // Allocate buffers + float* Activation = BufferActivation.GetBuffer(K); + int8_t* QLutExpected = BufferQLutExpected.GetBuffer(K * 4, true); // K * 4 bytes for LUT + float* LutScalesExpected = BufferLutScalesExpected.GetBuffer(lut_scales_count, true); + float* LutBiasesExpected = BufferLutBiasesExpected.GetBuffer(lut_scales_count, true); + + // Generate random activations + std::default_random_engine generator(42); + std::uniform_real_distribution distribution(-10.0f, 10.0f); + for (size_t i = 0; i < K; ++i) { + Activation[i] = distribution(generator); + } + + // Run scalar reference + ScalarReference::GenerateLUT_Reference( + Activation, + QLutExpected, + LutScalesExpected, + LutBiasesExpected, + K, + act_group_size + ); + + // Get the kernel dispatch through internal accessor + // This is defined in qlutgemm.h and qlutgemm.cpp + MlasClearLutGemmKernelConfig(); + MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, false); + + // Use the public GEMM API indirectly by creating a minimal test scenario + // that exercises the GenerateLUT path. We need to call it through the + // internal dispatch mechanism. + + // Access dispatch through platform - this requires linking to internal symbols + // For now, we'll use a workaround: call the full LUT GEMM but with minimal weights + // and compare intermediate LUT results. 
+ + // Since we can't easily access GenerateLUT directly, let's verify the algorithm + // by checking that the scalar reference produces sensible output, then + // trust the integration test (SQLutGemm) to find bugs in the SIMD version. + + // For a proper isolated test, we would need to expose GenerateLUT publicly. + // For now, just verify the scalar reference produces valid output: + + // Check that scales are non-negative + for (size_t i = 0; i < lut_scales_count; ++i) { + EXPECT_GE(LutScalesExpected[i], 0.0f) << "LUT scale should be non-negative"; + } + + // Check that quantized LUT values are within int8 range + for (size_t i = 0; i < K * 4; ++i) { + EXPECT_GE(QLutExpected[i], -128) << "QLUT value out of range"; + EXPECT_LE(QLutExpected[i], 127) << "QLUT value out of range"; + } + + // Log some info for debugging + if (lut_scales_count > 0) { + SCOPED_TRACE(testing::Message() << "First LUT scale: " << LutScalesExpected[0] + << ", First LUT bias: " << LutBiasesExpected[0]); + } + } + + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQLutGemmLutGen") + "BlkLen" + std::to_string(BlkLen); + return suite_name.c_str(); + } +}; + +// +// ============================================================================ +// TEST FIXTURES +// ============================================================================ +// + +template +class SQLutGemmPackShortExecuteTest : public MlasTestFixture> { + public: + explicit SQLutGemmPackShortExecuteTest(size_t N, size_t K, bool Symmetric) + : N_(N), K_(K), Symmetric_(Symmetric) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(N_, K_, Symmetric_); + } + + static size_t RegisterSingleTest(size_t N, size_t K, bool Symmetric) { + if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) { + return 0; + } + + std::stringstream ss; + ss << "Pack" + << "/isSymmetric" << Symmetric + << "/N" << N << "xK" << K; + + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQLutGemmPackTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture>* { + return new SQLutGemmPackShortExecuteTest(N, K, Symmetric); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t count = 0; + for (bool symmetric : {true, false}) { + // Test various N, K combinations + for (size_t n : {128, 256, 512}) { + for (size_t k : {128, 256, 512}) { + count += RegisterSingleTest(n, k, symmetric); + } + } + } + return count; + } + + private: + size_t N_, K_; + bool Symmetric_; +}; + +// +// LUT Generation Test Fixture +// +template +class SQLutGemmLutGenShortExecuteTest : public MlasTestFixture> { + public: + explicit SQLutGemmLutGenShortExecuteTest(size_t K) : K_(K) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(K_); + } + + static size_t RegisterSingleTest(size_t K) { + constexpr size_t BlkBitWidth = 2; + if (!MlasIsLutGemmAvailable(128, K, BlkBitWidth, BlkLen)) { + return 0; + } + + std::stringstream ss; + ss << "LutGen/K" << K; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQLutGemmLutGenTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture>* { + return new SQLutGemmLutGenShortExecuteTest(K); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t count = 0; + for (size_t k : {64, 128, 256, 512}) { + count += RegisterSingleTest(k); + } + return count; + } + + private: + size_t K_; +}; + +// +// 
============================================================================ +// TEST REGISTRATION +// ============================================================================ +// + +static size_t SQLutGemmComponentsRegisterAllShortExecuteTests() { + size_t count = 0; + + // Pack tests for 2-bit quantization with various block lengths + count += SQLutGemmPackShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + count += SQLutGemmPackShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + count += SQLutGemmPackShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + + // LUT generation tests + count += SQLutGemmLutGenShortExecuteTest<32>::RegisterShortExecuteTests(); + count += SQLutGemmLutGenShortExecuteTest<64>::RegisterShortExecuteTests(); + count += SQLutGemmLutGenShortExecuteTest<128>::RegisterShortExecuteTests(); + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return SQLutGemmComponentsRegisterAllShortExecuteTests(); + } + return 0; + }); From 84e44b34d8daffb88462d7fe16ac38129705c3ba Mon Sep 17 00:00:00 2001 From: Vrajang Parikh Date: Fri, 6 Feb 2026 23:00:56 +0000 Subject: [PATCH 10/10] Cleanup: Improve naming, and file structure --- onnxruntime/core/mlas/lib/mlasi.h | 6 +- onnxruntime/core/mlas/lib/platform.cpp | 4 +- onnxruntime/core/mlas/lib/qlutgemm.cpp | 10 +- .../mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | 46 +- .../mlas/lib/sqnbitgemm_lut_kernel_avx2.h | 2 +- .../mlas/lib/sqnbitgemm_lut_kernel_neon.cpp | 52 +- .../mlas/lib/sqnbitgemm_lut_kernel_neon.h | 2 +- .../{test_sqlutgemm.cpp => test_lutgemm.cpp} | 34 +- ...m_components.cpp => test_lutgemm_pack.cpp} | 658 +++++++++--------- 9 files changed, 404 insertions(+), 410 deletions(-) rename onnxruntime/test/mlas/unittest/{test_sqlutgemm.cpp => test_lutgemm.cpp} (86%) rename onnxruntime/test/mlas/unittest/{test_sqlutgemm_components.cpp => test_lutgemm_pack.cpp} (51%) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index d6299af5d8f80..150af9de6d342 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1243,10 +1243,10 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx; struct MLAS_QNBIT_LUT_GEMM_DISPATCH; -extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2; +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchAvx2; #if defined(MLAS_TARGET_ARM64) -extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon; +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchNeon; #endif // @@ -1457,7 +1457,7 @@ struct MLAS_PLATFORM { const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; - const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGenKernel{nullptr}; + const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index b7242dd4704db..7360e692d64fc 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -422,7 +422,7 @@ Return Value: this->RopeDispatch = &MlasRopeDispatchAvx2; // TODO(vraspar): check if this really goes here or if there are other platform reqs that we need to fulfill - this->LutGenKernel = &MlasLutGenKernelAvx2; + this->LutGemmDispatch = &MlasLutGemmDispatchAvx2; // // Check if the processor supports Hybrid core 
architecture. @@ -655,7 +655,7 @@ Return Value: this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions); // Enable LUT-based GEMM for 2-bit quantization on ARM64 - this->LutGenKernel = &MlasLutGenKernelNeon; + this->LutGemmDispatch = &MlasLutGemmDispatchNeon; #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 73e185075cc19..94fa2d870e623 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -191,7 +191,7 @@ LutGemmPackQuantBData( const size_t kfactor = tmac_params.kfactor; // LUT GEMM requires a valid LUT dispatch implementation, so dispatch must be available - const auto* Dispatch = GetMlasPlatform().LutGenKernel; + const auto* Dispatch = GetMlasPlatform().LutGemmDispatch; if (Dispatch == nullptr || Dispatch->PackQuantBData == nullptr) { MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires LUT GEMM dispatch support"); } @@ -240,9 +240,9 @@ LutPackScalesAndZeroPoints( const size_t bm = tmac_params.bm; // LUT GEMM is only available for AVX2, so dispatch must be available - const auto* Dispatch = GetMlasPlatform().LutGenKernel; + const auto* Dispatch = GetMlasPlatform().LutGemmDispatch; if (Dispatch == nullptr || Dispatch->PackScalesAndZeroPoints == nullptr) { - MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires AVX2 dispatch"); + MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires LUT GEMM dispatch support"); } Dispatch->PackScalesAndZeroPoints( @@ -320,7 +320,7 @@ MlasIsLutGemmAvailable( size_t BlkLen ) { - const auto* lut_kernel = GetMlasPlatform().LutGenKernel; + const auto* lut_kernel = GetMlasPlatform().LutGemmDispatch; if (lut_kernel == nullptr || lut_kernel->GenerateLUT == nullptr || lut_kernel->ComputeGemm == nullptr || @@ -392,7 +392,7 @@ MlasLutGemm( ) { // adapted from ggml_backend_tmac_mul_mat - const auto* Dispatch = GetMlasPlatform().LutGenKernel; + const auto* Dispatch = GetMlasPlatform().LutGemmDispatch; // This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm() if (Dispatch == nullptr || Dispatch->GenerateLUT == nullptr || Dispatch->ComputeGemm == nullptr) { MLAS_THROW_EX(std::runtime_error, "TMAC not supported in this configuration"); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index cc8308675e399..5c9f31e1a6ffa 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -54,6 +54,12 @@ _mm256_addv_ps(const __m256 v) #define extract_low_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_castsi256_si128(v)) #define extract_high_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_extracti128_si256(v, 1)) +namespace lutgemm_avx2 +{ + +namespace +{ + // Template classes for accumulation template struct SignedHalvingAdder { @@ -324,9 +330,11 @@ lut_ctor_g4_int8_impl( *lut_biases = biases; } -// based on lut_ctor_g4_int8_impl +} // namespace + +// LutGemmGenerateLUT_CompFp32 - Entry point for LUT generation void -GenerateLUT_avx2( +LutGemmGenerateLUT_CompFp32( const float* b, int8_t* qlut, float* lut_scales, @@ -495,10 +503,9 @@ tbl_g4_int8_float_update_impl(int32_t m, float* c, const int8_t* lut, const uint return 0; } -// based on qgemm_lut_int8_g4 -// Simplified version with hardcoded configuration for 2-bit quantization +// 
LutGemmCompute_CompFp32 - Entry point for GEMM computation void -TMACComputeGemm_avx2( +LutGemmCompute_CompFp32( const uint8_t* A, // Quantized packed weights const float* Scales, // Weight scales (and optionally zero-points) const int8_t* LUT, // Pre-computed quantized lookup table @@ -651,11 +658,11 @@ TMACComputeGemm_avx2( } // -// AVX2 optimized weight packing for T-MAC LUT GEMM +// LutGemmPackQuantBData_CompFp32 - AVX2 optimized weight packing for T-MAC LUT GEMM // This performs the same transformation as the scalar version but uses SIMD operations // void -PackQuantBData_avx2( +LutGemmPackQuantBData_CompFp32( size_t N, size_t K, size_t bits, @@ -864,12 +871,11 @@ PackQuantBData_avx2( } // -// AVX2 optimized scales and zero points packing for T-MAC LUT GEMM -// This performs the same transformation as the scalar version but uses SIMD operations +// LutGemmPackScalesAndZeroPoints_CompFp32 - Scales and zero points packing // template -static void -PackScalesAndZeroPoints_avx2_impl( +void +LutGemmPackScalesAndZeroPoints_CompFp32_Impl( size_t N, size_t K, size_t bits, @@ -984,7 +990,7 @@ PackScalesAndZeroPoints_avx2_impl( } void -PackScalesAndZeroPoints_avx2( +LutGemmPackScalesAndZeroPoints_CompFp32( size_t N, size_t K, size_t bits, @@ -1002,25 +1008,27 @@ PackScalesAndZeroPoints_avx2( assert(bits == 2); if (HasZeroPoint) { - PackScalesAndZeroPoints_avx2_impl( + LutGemmPackScalesAndZeroPoints_CompFp32_Impl( N, K, bits, BlkLen, simd_n_out, bm, PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool ); } else { - PackScalesAndZeroPoints_avx2_impl( + LutGemmPackScalesAndZeroPoints_CompFp32_Impl( N, K, bits, BlkLen, simd_n_out, bm, PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool ); } } +} // namespace lutgemm_avx2 + // Kernel dispatch structure definition. -const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2 = []() { +const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchAvx2 = []() { MLAS_QNBIT_LUT_GEMM_DISPATCH d; - d.GenerateLUT = GenerateLUT_avx2; - d.ComputeGemm = TMACComputeGemm_avx2; - d.PackQuantBData = PackQuantBData_avx2; - d.PackScalesAndZeroPoints = PackScalesAndZeroPoints_avx2; + d.GenerateLUT = lutgemm_avx2::LutGemmGenerateLUT_CompFp32; + d.ComputeGemm = lutgemm_avx2::LutGemmCompute_CompFp32; + d.PackQuantBData = lutgemm_avx2::LutGemmPackQuantBData_CompFp32; + d.PackScalesAndZeroPoints = lutgemm_avx2::LutGemmPackScalesAndZeroPoints_CompFp32; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h index e900080403d51..1f4afa89591fb 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h @@ -23,4 +23,4 @@ Module Name: // External dispatch table for AVX2 LUT GEMM kernels. // Kernel functions are internal to the .cpp file and accessed via this dispatch. 
// -extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2; +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchAvx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp index 23cc244eb28fd..8b75e3ef7fb12 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.cpp @@ -46,6 +46,12 @@ Module Name: #define PRAGMA_UNROLL #endif +namespace lutgemm_neon +{ + +namespace +{ + // // Template classes for accumulation - adapted from llama.cpp tbl.cpp // @@ -282,11 +288,13 @@ lut_ctor_g4_int8_impl_neon( *lut_biases = biases; } +} // namespace + // -// GenerateLUT - Entry point for LUT generation +// LutGemmGenerateLUT_CompFp32 - Entry point for LUT generation // -static void -GenerateLUT_neon( +void +LutGemmGenerateLUT_CompFp32( const float* b, int8_t* qlut, float* lut_scales, @@ -620,10 +628,10 @@ tbl_g4_int8_float_update_impl_neon( } // -// TMACComputeGemm - Entry point for GEMM computation +// LutGemmCompute_CompFp32 - Entry point for GEMM computation // -static void -TMACComputeGemm_neon( +void +LutGemmCompute_CompFp32( const uint8_t* A, const float* Scales, const int8_t* LUT, @@ -756,11 +764,11 @@ TMACComputeGemm_neon( } // -// Weight packing for NEON (can use scalar or NEON implementation) +// LutGemmPackQuantBData_CompFp32 - Weight packing for NEON // This is done during model load, so performance is less critical // -static void -PackQuantBData_neon( +void +LutGemmPackQuantBData_CompFp32( size_t N, size_t K, size_t bits, @@ -917,11 +925,11 @@ PackQuantBData_neon( } // -// Scales and zero points packing +// LutGemmPackScalesAndZeroPoints_CompFp32 - Scales and zero points packing // template -static void -PackScalesAndZeroPoints_neon_impl( +void +LutGemmPackScalesAndZeroPoints_CompFp32_Impl( size_t N, size_t K, size_t bits, @@ -991,8 +999,8 @@ PackScalesAndZeroPoints_neon_impl( ); } -static void -PackScalesAndZeroPoints_neon( +void +LutGemmPackScalesAndZeroPoints_CompFp32( size_t N, size_t K, size_t bits, @@ -1009,27 +1017,29 @@ PackScalesAndZeroPoints_neon( assert(bits == 2); if (HasZeroPoint) { - PackScalesAndZeroPoints_neon_impl( + LutGemmPackScalesAndZeroPoints_CompFp32_Impl( N, K, bits, BlkLen, simd_n_out, bm, PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool ); } else { - PackScalesAndZeroPoints_neon_impl( + LutGemmPackScalesAndZeroPoints_CompFp32_Impl( N, K, bits, BlkLen, simd_n_out, bm, PackedScalesBegin, QuantBScale, QuantBZeroPoint, ThreadPool ); } } +} // namespace lutgemm_neon + // // Kernel dispatch structure definition // -const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon = []() { +const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchNeon = []() { MLAS_QNBIT_LUT_GEMM_DISPATCH d; - d.GenerateLUT = GenerateLUT_neon; - d.ComputeGemm = TMACComputeGemm_neon; - d.PackQuantBData = PackQuantBData_neon; - d.PackScalesAndZeroPoints = PackScalesAndZeroPoints_neon; + d.GenerateLUT = lutgemm_neon::LutGemmGenerateLUT_CompFp32; + d.ComputeGemm = lutgemm_neon::LutGemmCompute_CompFp32; + d.PackQuantBData = lutgemm_neon::LutGemmPackQuantBData_CompFp32; + d.PackScalesAndZeroPoints = lutgemm_neon::LutGemmPackScalesAndZeroPoints_CompFp32; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h index f8710c8a90c0b..e638945e51808 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h +++ 
b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_neon.h @@ -23,5 +23,5 @@ Module Name: // External dispatch table for ARM NEON LUT GEMM kernels. // Kernel functions are internal to the .cpp file and accessed via this dispatch. // -extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon; +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchNeon; diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp b/onnxruntime/test/mlas/unittest/test_lutgemm.cpp similarity index 86% rename from onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp rename to onnxruntime/test/mlas/unittest/test_lutgemm.cpp index 181fda23f299d..bcc25d2980bfb 100644 --- a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_lutgemm.cpp @@ -6,11 +6,11 @@ Licensed under the MIT License. Module Name: - test_sqlutgemm.cpp + test_lutgemm.cpp Abstract: - Tests for MLAS LUT-based n-bit GEMM (TMAC/LUT path) for 2-bit.. + Tests for MLAS LUT-based n-bit GEMM (TMAC/LUT path) for 2-bit. --*/ @@ -20,7 +20,7 @@ Module Name: // Generic template to future-proof for different bit widths; instantiate with 2 for now. template -class MlasSQLutGemmTest : public MlasTestBase { +class MlasLutGemmTest : public MlasTestBase { private: MatrixGuardBuffer BufferA; MatrixGuardBuffer BufferB; @@ -145,7 +145,7 @@ class MlasSQLutGemmTest : public MlasTestBase { public: static const char* GetTestSuiteName() { - static std::string suite_name = std::string("SQLutGemm") + + static std::string suite_name = std::string("LutGemm") + "BlkBitWidth" + std::to_string(BlkBitWidth) + "BlkLen" + std::to_string(BlkLen); return suite_name.c_str(); @@ -154,10 +154,10 @@ class MlasSQLutGemmTest : public MlasTestBase { // Fixture to register parameterized tests quickly template -class SQLutGemmShortExecuteTest : public MlasTestFixture> { +class LutGemmShortExecuteTest : public MlasTestFixture> { public: - explicit SQLutGemmShortExecuteTest(size_t M, size_t N, size_t K, - bool WithThreadpool, bool Symmetric) + explicit LutGemmShortExecuteTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool Symmetric) : M_(M), N_(N), K_(K), @@ -166,7 +166,7 @@ class SQLutGemmShortExecuteTest : public MlasTestFixture>::mlas_tester->Test( + MlasTestFixture>::mlas_tester->Test( M_, N_, K_, WithThreadpool_, Symmetric_); } @@ -187,15 +187,15 @@ class SQLutGemmShortExecuteTest : public MlasTestFixture::GetTestSuiteName(), + MlasLutGemmTest::GetTestSuiteName(), test_name.c_str(), nullptr, test_name.c_str(), __FILE__, __LINE__, // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new SQLutGemmShortExecuteTest( + [=]() -> MlasTestFixture>* { + return new LutGemmShortExecuteTest( M, N, K, WithThreadpool, Symmetric); }); @@ -225,19 +225,19 @@ class SQLutGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - count += SQLutGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); - count += SQLutGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); - count += SQLutGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); + count += LutGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + count += LutGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + count += LutGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + count += LutGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); return count; } static UNUSED_VARIABLE bool added_to_main = AddTestRegister( [](bool is_short_execute) -> size_t { if (is_short_execute) { - return SQLutGemmRegisterAllShortExecuteTests(); + return LutGemmRegisterAllShortExecuteTests(); } return 0; }); diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm_components.cpp b/onnxruntime/test/mlas/unittest/test_lutgemm_pack.cpp similarity index 51% rename from onnxruntime/test/mlas/unittest/test_sqlutgemm_components.cpp rename to onnxruntime/test/mlas/unittest/test_lutgemm_pack.cpp index 0048e8f8d5d06..5aafdccb7c3bf 100644 --- a/onnxruntime/test/mlas/unittest/test_sqlutgemm_components.cpp +++ b/onnxruntime/test/mlas/unittest/test_lutgemm_pack.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - test_sqlutgemm_components.cpp + test_lutgemm_pack.cpp Abstract: @@ -48,10 +48,8 @@ PackQuantBDataSize( size_t bits, size_t K, size_t g, - size_t ngroups_per_elem -) -{ - return (N * bits) * (K / g / ngroups_per_elem); + size_t ngroups_per_elem) { + return (N * bits) * (K / g / ngroups_per_elem); } /** @@ -62,14 +60,12 @@ PackScalesAndZeroPointsSize( size_t N, size_t K, size_t BlkLen, - bool HasZeroPoint -) -{ - if (HasZeroPoint) { - return N * K / BlkLen * 2; - } else { - return N * K / BlkLen; - } + bool HasZeroPoint) { + if (HasZeroPoint) { + return N * K / BlkLen * 2; + } else { + return N * K / BlkLen; + } } /** @@ -77,11 +73,9 @@ PackScalesAndZeroPointsSize( */ static size_t PackedScalesOffset( - size_t PackedQuantBDataSize -) -{ - constexpr size_t kAlignment = 64; - return ((PackedQuantBDataSize + kAlignment - 1) / kAlignment) * kAlignment; + size_t PackedQuantBDataSize) { + constexpr size_t kAlignment = 64; + return ((PackedQuantBDataSize + kAlignment - 1) / kAlignment) * kAlignment; } /** @@ -105,103 +99,100 @@ PackQuantBData_Reference( size_t bm, size_t kfactor, const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin -) -{ - assert(bits == 2 && g == 4 && ngroups_per_elem == 2); - - const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 - assert(bm % mgroup == 0); - assert(bm % bits == 0); - - std::unique_ptr buf(new uint8_t[N * bits * (K / g)]); - memset(buf.get(), 0, N * bits * (K / g)); - - // Phase 1: Bit-plane decomposition - for (size_t im = 0; im < N; ++im) { - for (size_t ik = 0; ik < K; ++ik) { - size_t idx = (im * K + ik); - size_t num_elem_per_byte = 8 / bits; - size_t elem_idx = idx % num_elem_per_byte; - - uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); - - for (size_t ib = 0; ib < bits; ++ib) { - size_t new_ik = ik / g; - size_t shft_left = ik % g; - buf[im * bits * K / g + ib * K / g + new_ik] += - static_cast(((v >> ib) & 1) << shft_left); - } - } + std::byte* 
PackedQuantBDataBegin) { + assert(bits == 2 && g == 4 && ngroups_per_elem == 2); + + const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 + assert(bm % mgroup == 0); + assert(bm % bits == 0); + + std::unique_ptr buf(new uint8_t[N * bits * (K / g)]); + memset(buf.get(), 0, N * bits * (K / g)); + + // Phase 1: Bit-plane decomposition + for (size_t im = 0; im < N; ++im) { + for (size_t ik = 0; ik < K; ++ik) { + size_t idx = (im * K + ik); + size_t num_elem_per_byte = 8 / bits; + size_t elem_idx = idx % num_elem_per_byte; + + uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); + + for (size_t ib = 0; ib < bits; ++ib) { + size_t new_ik = ik / g; + size_t shft_left = ik % g; + buf[im * bits * K / g + ib * K / g + new_ik] += + static_cast(((v >> ib) & 1) << shft_left); + } } + } - // Phase 2: Multi-reshape/transpose into final layout - const size_t c0_fac2 = K / g; - const size_t c0_fac1 = simd_n_out * c0_fac2; - const size_t c0_fac0 = bits * c0_fac1; - - const size_t c1_nb2 = K / g; - const size_t c1_nb1 = simd_n_in * c1_nb2; - const size_t c1_nb0 = ngroups_per_elem * c1_nb1; - const size_t c1_fac2 = K / g; - const size_t c1_fac1 = ngroups_per_elem * c1_fac2; - const size_t c1_fac0 = simd_n_in * c1_fac1; - - const size_t c2_nb4 = kfactor; - const size_t c2_nb3 = K / g / kfactor * c2_nb4; - const size_t c2_nb2 = ngroups_per_elem * c2_nb3; - const size_t c2_nb1 = simd_n_in * c2_nb2; - const size_t c2_nb0 = bm / mgroup * c2_nb1; - const size_t c2_fac3 = simd_n_in * ngroups_per_elem; - const size_t c2_fac2 = kfactor * c2_fac3; - const size_t c2_fac1 = bm / mgroup * c2_fac2; - const size_t c2_fac0 = K / g / kfactor * c2_fac1; - - const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); - memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); - - for (size_t im = 0; im < N; ++im) { - for (size_t ib = 0; ib < bits; ib++) { - for (size_t ik = 0; ik < K / g; ik++) { - // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) - size_t new_im = im / simd_n_out; - size_t new_isno = im % simd_n_out; - size_t new_ib = ib; - size_t new_ik = ik; - size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; - - // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) - new_im = new_idx / c1_nb0; - size_t new_ing = (new_idx % c1_nb0) / c1_nb1; - size_t new_isni = (new_idx % c1_nb1) / c1_nb2; - new_ik = (new_idx % c1_nb2); - new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; - - // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) - new_im = new_idx / c2_nb0; - size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; - new_isni = (new_idx % c2_nb1) / c2_nb2; - new_ing = (new_idx % c2_nb2) / c2_nb3; - new_ik = (new_idx % c2_nb3) / c2_nb4; - size_t new_ikf = (new_idx % c2_nb4); - new_idx = new_im * c2_fac0 + - new_ik * c2_fac1 + - new_ibm * c2_fac2 + - new_ikf * c2_fac3 + - new_isni * ngroups_per_elem + - new_ing; - new_idx = new_idx / ngroups_per_elem; - size_t buf_idx = im * bits * K / g + ib * K / g + ik; - uint8_t buf_val = buf[buf_idx]; - - // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) - PackedQuantBDataBegin[new_idx] = static_cast( - static_cast(PackedQuantBDataBegin[new_idx]) + - (buf_val << (new_ing * g)) - ); - } - } + // Phase 2: Multi-reshape/transpose into final layout + const size_t c0_fac2 = K / g; + const size_t c0_fac1 = 
simd_n_out * c0_fac2; + const size_t c0_fac0 = bits * c0_fac1; + + const size_t c1_nb2 = K / g; + const size_t c1_nb1 = simd_n_in * c1_nb2; + const size_t c1_nb0 = ngroups_per_elem * c1_nb1; + const size_t c1_fac2 = K / g; + const size_t c1_fac1 = ngroups_per_elem * c1_fac2; + const size_t c1_fac0 = simd_n_in * c1_fac1; + + const size_t c2_nb4 = kfactor; + const size_t c2_nb3 = K / g / kfactor * c2_nb4; + const size_t c2_nb2 = ngroups_per_elem * c2_nb3; + const size_t c2_nb1 = simd_n_in * c2_nb2; + const size_t c2_nb0 = bm / mgroup * c2_nb1; + const size_t c2_fac3 = simd_n_in * ngroups_per_elem; + const size_t c2_fac2 = kfactor * c2_fac3; + const size_t c2_fac1 = bm / mgroup * c2_fac2; + const size_t c2_fac0 = K / g / kfactor * c2_fac1; + + const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); + + for (size_t im = 0; im < N; ++im) { + for (size_t ib = 0; ib < bits; ib++) { + for (size_t ik = 0; ik < K / g; ik++) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + size_t new_im = im / simd_n_out; + size_t new_isno = im % simd_n_out; + size_t new_ib = ib; + size_t new_ik = ik; + size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + size_t new_ing = (new_idx % c1_nb0) / c1_nb1; + size_t new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + size_t new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + + new_ik * c2_fac1 + + new_ibm * c2_fac2 + + new_ikf * c2_fac3 + + new_isni * ngroups_per_elem + + new_ing; + new_idx = new_idx / ngroups_per_elem; + size_t buf_idx = im * bits * K / g + ib * K / g + ik; + uint8_t buf_val = buf[buf_idx]; + + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + PackedQuantBDataBegin[new_idx] = static_cast( + static_cast(PackedQuantBDataBegin[new_idx]) + + (buf_val << (new_ing * g))); + } } + } } /** @@ -222,59 +213,57 @@ PackScalesAndZeroPoints_Reference( bool HasZeroPoint, float* PackedScalesBegin, const float* QuantBScale, - const uint8_t* QuantBZeroPoint -) -{ - const size_t num_elem_per_byte = 8 / bits; - - // ZP array is column-major packed, with per-column alignment to byte boundary - const size_t row_blks = K / BlkLen; // number of blocks per column - const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; - - for (size_t im = 0; im < N; im += 1) { - for (size_t ik = 0; ik < K; ik += BlkLen) { - size_t idx = (im * K + ik) / BlkLen; // linear block index for scale - float scale = QuantBScale[idx]; - float zp = 0.0f; - if (HasZeroPoint) { - size_t blk_in_col = ik / BlkLen; // block index within column - size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; - size_t elem_idx = blk_in_col % num_elem_per_byte; - uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1); - - // The LUT kernel assumes weights are centered around the midpoint (2 for 
2-bit). - int midpoint = 1 << (bits - 1); // 2 for 2-bit - zp = static_cast(static_cast(v) - midpoint) * scale; - } - - size_t nb1 = K / BlkLen; - size_t nb0 = bm / bits * nb1; - - size_t new_im, new_ibm, new_ik; - if (nb1 == 0) { - new_im = 0; - new_ibm = 0; - new_ik = 0; - } else { - new_im = idx / nb0; - new_ibm = (idx % nb0) / nb1; - new_ik = (idx % nb1); - } - - if (HasZeroPoint) { - size_t new_isimd = new_ibm % simd_n_out; - size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; - size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; - size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; - - PackedScalesBegin[new_idx_scale] = scale; - PackedScalesBegin[new_idx_zero] = zp; - } else { - size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; - PackedScalesBegin[new_idx] = scale; - } - } + const uint8_t* QuantBZeroPoint) { + const size_t num_elem_per_byte = 8 / bits; + + // ZP array is column-major packed, with per-column alignment to byte boundary + const size_t row_blks = K / BlkLen; // number of blocks per column + const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; + + for (size_t im = 0; im < N; im += 1) { + for (size_t ik = 0; ik < K; ik += BlkLen) { + size_t idx = (im * K + ik) / BlkLen; // linear block index for scale + float scale = QuantBScale[idx]; + float zp = 0.0f; + if (HasZeroPoint) { + size_t blk_in_col = ik / BlkLen; // block index within column + size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; + size_t elem_idx = blk_in_col % num_elem_per_byte; + uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1); + + // The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit). 
+ int midpoint = 1 << (bits - 1); // 2 for 2-bit + zp = static_cast(static_cast(v) - midpoint) * scale; + } + + size_t nb1 = K / BlkLen; + size_t nb0 = bm / bits * nb1; + + size_t new_im, new_ibm, new_ik; + if (nb1 == 0) { + new_im = 0; + new_ibm = 0; + new_ik = 0; + } else { + new_im = idx / nb0; + new_ibm = (idx % nb0) / nb1; + new_ik = (idx % nb1); + } + + if (HasZeroPoint) { + size_t new_isimd = new_ibm % simd_n_out; + size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; + size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; + + PackedScalesBegin[new_idx_scale] = scale; + PackedScalesBegin[new_idx_zero] = zp; + } else { + size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; + PackedScalesBegin[new_idx] = scale; + } } + } } /** @@ -298,29 +287,25 @@ LutGemmPack_Reference( const std::byte* QuantBData, const float* QuantBScale, const uint8_t* QuantBZeroPoint, - std::byte* PackedBuf -) -{ - // Pack B data - if (QuantBData != nullptr) { - PackQuantBData_Reference( - N, K, bits, g, ngroups_per_elem, - simd_n_in, simd_n_out, bm, kfactor, - QuantBData, PackedBuf - ); - } + std::byte* PackedBuf) { + // Pack B data + if (QuantBData != nullptr) { + PackQuantBData_Reference( + N, K, bits, g, ngroups_per_elem, + simd_n_in, simd_n_out, bm, kfactor, + QuantBData, PackedBuf); + } - // Pack scales/zero points - if (QuantBScale != nullptr) { - size_t packed_b_size = PackQuantBDataSize(N, bits, K, g, ngroups_per_elem); - size_t scales_offset = PackedScalesOffset(packed_b_size); - float* scales_dest = reinterpret_cast(PackedBuf + scales_offset); + // Pack scales/zero points + if (QuantBScale != nullptr) { + size_t packed_b_size = PackQuantBDataSize(N, bits, K, g, ngroups_per_elem); + size_t scales_offset = PackedScalesOffset(packed_b_size); + float* scales_dest = reinterpret_cast(PackedBuf + scales_offset); - PackScalesAndZeroPoints_Reference( - N, K, bits, BlkLen, simd_n_out, bm, - HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint - ); - } + PackScalesAndZeroPoints_Reference( + N, K, bits, BlkLen, simd_n_out, bm, + HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint); + } } /** @@ -329,33 +314,31 @@ LutGemmPack_Reference( * This mirrors the logic in MlasInitLutGemmKernelConfig */ static size_t -SelectOptimalBm(size_t N, size_t bits) -{ - std::vector bms = {256, 512, 1024, 2048, 320, 640, 1280}; - - // Use a simple heuristic: pick the largest bm that divides N * bits evenly - for (size_t bm : bms) { - if (N % (bm / bits) == 0 && bm % bits == 0) { - return bm; - } +SelectOptimalBm(size_t N, size_t bits) { + std::vector bms = {256, 512, 1024, 2048, 320, 640, 1280}; + + // Use a simple heuristic: pick the largest bm that divides N * bits evenly + for (size_t bm : bms) { + if (N % (bm / bits) == 0 && bm % bits == 0) { + return bm; } - return bms[0]; // fallback + } + return bms[0]; // fallback } /** * @brief Select optimal kfactor */ static size_t -SelectOptimalKfactor(size_t BlkLen, size_t g, size_t actk) -{ - std::vector kfactors = {16, 8}; - - for (size_t kfactor : kfactors) { - if (kfactor >= actk && kfactor * g <= BlkLen) { - return kfactor; - } +SelectOptimalKfactor(size_t BlkLen, size_t g, size_t actk) { + std::vector kfactors = {16, 8}; + + for (size_t kfactor : kfactors) { + if (kfactor >= actk && kfactor * g <= BlkLen) { + return kfactor; } - return kfactors.back(); + } + return 
kfactors.back(); } } // namespace ScalarReference @@ -372,7 +355,7 @@ SelectOptimalKfactor(size_t BlkLen, size_t g, size_t actk) * Compares the dispatched (NEON/AVX2) MlasLutGemmPack against the scalar reference. */ template -class MlasSQLutGemmPackTest : public MlasTestBase { +class MlasLutGemmPackTest : public MlasTestBase { private: MatrixGuardBuffer BufferB; MatrixGuardBuffer BufferQuantBData; @@ -424,7 +407,7 @@ class MlasSQLutGemmPackTest : public MlasTestBase { constexpr size_t ngroups_per_elem = 2; constexpr size_t simd_n_in = 16; constexpr size_t simd_n_out = 8; - + size_t bm = ScalarReference::SelectOptimalBm(N, BlkBitWidth); size_t act_group_size = (BlkLen % 64 == 0) ? 64 : 32; size_t actk = act_group_size / g; @@ -450,17 +433,17 @@ class MlasSQLutGemmPackTest : public MlasTestBase { // Compare weight packing portion size_t packed_b_size = ScalarReference::PackQuantBDataSize(N, BlkBitWidth, K, g, ngroups_per_elem); - + size_t weight_mismatch_count = 0; constexpr size_t max_mismatches_to_report = 10; for (size_t i = 0; i < packed_b_size; ++i) { if (PackedExpected[i] != PackedActual[i]) { if (weight_mismatch_count < max_mismatches_to_report) { ADD_FAILURE() << "Weight packing mismatch at byte " << i << " of " << packed_b_size - << ": expected 0x" << std::hex << static_cast(static_cast(PackedExpected[i])) - << ", got 0x" << static_cast(static_cast(PackedActual[i])) << std::dec - << ", N=" << N << ", K=" << K << ", BlkLen=" << BlkLen - << ", bm=" << bm << ", kfactor=" << kfactor; + << ": expected 0x" << std::hex << static_cast(static_cast(PackedExpected[i])) + << ", got 0x" << static_cast(static_cast(PackedActual[i])) << std::dec + << ", N=" << N << ", K=" << K << ", BlkLen=" << BlkLen + << ", bm=" << bm << ", kfactor=" << kfactor; } weight_mismatch_count++; } @@ -480,9 +463,9 @@ class MlasSQLutGemmPackTest : public MlasTestBase { if (!CloseEnough(ActualScales[i], ExpectedScales[i])) { if (scale_mismatch_count < max_mismatches_to_report) { ADD_FAILURE() << "Scale/ZP packing mismatch at index " << i << " of " << scales_size - << ": expected " << ExpectedScales[i] << ", got " << ActualScales[i] - << ", N=" << N << ", K=" << K << ", BlkLen=" << BlkLen - << ", Symmetric=" << Symmetric; + << ": expected " << ExpectedScales[i] << ", got " << ActualScales[i] + << ", N=" << N << ", K=" << K << ", BlkLen=" << BlkLen + << ", Symmetric=" << Symmetric; } scale_mismatch_count++; } @@ -492,7 +475,7 @@ class MlasSQLutGemmPackTest : public MlasTestBase { } static const char* GetTestSuiteName() { - static std::string suite_name = std::string("SQLutGemmPack") + + static std::string suite_name = std::string("LutGemmPack") + "BlkBitWidth" + std::to_string(BlkBitWidth) + "BlkLen" + std::to_string(BlkLen); return suite_name.c_str(); @@ -512,20 +495,19 @@ namespace ScalarReference { * Computes max(|b0| + |b1| + |b2| + |b3|) across 8 groups of 4 elements. 
*/ static void -PartialMax_Reference(float* lut_scales, const float* b) -{ - // Process 32 floats organized as 8 groups of 4 consecutive elements - // Groups: {0-3}, {4-7}, {8-11}, {12-15}, {16-19}, {20-23}, {24-27}, {28-31} - float max_sum = 0.0f; - for (int group = 0; group < 8; ++group) { - float abssum = std::abs(b[group * 4 + 0]) + - std::abs(b[group * 4 + 1]) + - std::abs(b[group * 4 + 2]) + - std::abs(b[group * 4 + 3]); - max_sum = std::max(max_sum, abssum); - } - float scales = max_sum / 127.0f; - *lut_scales = std::max(*lut_scales, scales); +PartialMax_Reference(float* lut_scales, const float* b) { + // Process 32 floats organized as 8 groups of 4 consecutive elements + // Groups: {0-3}, {4-7}, {8-11}, {12-15}, {16-19}, {20-23}, {24-27}, {28-31} + float max_sum = 0.0f; + for (int group = 0; group < 8; ++group) { + float abssum = std::abs(b[group * 4 + 0]) + + std::abs(b[group * 4 + 1]) + + std::abs(b[group * 4 + 2]) + + std::abs(b[group * 4 + 3]); + max_sum = std::max(max_sum, abssum); + } + float scales = max_sum / 127.0f; + *lut_scales = std::max(*lut_scales, scales); } /** @@ -538,66 +520,64 @@ LutCtor_Reference( int8_t* qlut, const float* b, float* lut_scales, - float* lut_biases -) -{ - float biases = 0.0f; - float scales = *lut_scales; - float t_scales = scales ? 1.0f / scales : 0.0f; - - for (int k = 0; k < act_k / 32; ++k) { - // For each of 8 groups of 4 elements - float lut[16][8]; // [lut_entry][group] - - for (int group = 0; group < 8; ++group) { - float b0 = b[k * 32 + group * 4 + 0]; - float b1 = b[k * 32 + group * 4 + 1]; - float b2 = b[k * 32 + group * 4 + 2]; - float b3 = b[k * 32 + group * 4 + 3]; - - // Build 16-entry LUT: each entry is ±b0 ±b1 ±b2 ±b3 - for (int g = 1; g < 16; g += 2) { - lut[g][group] = b0; - if (g & 0b0010) { - lut[g][group] += b1; - } else { - lut[g][group] -= b1; - } - if (g & 0b0100) { - lut[g][group] += b2; - } else { - lut[g][group] -= b2; - } - if (g & 0b1000) { - lut[g][group] += b3; - } else { - lut[g][group] -= b3; - } - } - // Symmetric: lut[g] = -lut[15 - g] - for (int g = 0; g < 16; g += 2) { - lut[g][group] = -lut[15 - g][group]; - } + float* lut_biases) { + float biases = 0.0f; + float scales = *lut_scales; + float t_scales = scales ? 
1.0f / scales : 0.0f; + + for (int k = 0; k < act_k / 32; ++k) { + // For each of 8 groups of 4 elements + float lut[16][8]; // [lut_entry][group] + + for (int group = 0; group < 8; ++group) { + float b0 = b[k * 32 + group * 4 + 0]; + float b1 = b[k * 32 + group * 4 + 1]; + float b2 = b[k * 32 + group * 4 + 2]; + float b3 = b[k * 32 + group * 4 + 3]; + + // Build 16-entry LUT: each entry is ±b0 ±b1 ±b2 ±b3 + for (int g = 1; g < 16; g += 2) { + lut[g][group] = b0; + if (g & 0b0010) { + lut[g][group] += b1; + } else { + lut[g][group] -= b1; } - - // Accumulate bias - for (int group = 0; group < 8; ++group) { - biases += lut[0][group]; + if (g & 0b0100) { + lut[g][group] += b2; + } else { + lut[g][group] -= b2; } - - // Scale and quantize, then store - // Output layout: qlut[k * 8 * 16 + group * 16 + lut_entry] - for (int group = 0; group < 8; ++group) { - for (int g = 0; g < 16; ++g) { - float scaled = lut[g][group] * t_scales; - int8_t quantized = static_cast(std::round(scaled)); - qlut[k * 8 * 16 + group * 16 + g] = quantized; - } + if (g & 0b1000) { + lut[g][group] += b3; + } else { + lut[g][group] -= b3; } + } + // Symmetric: lut[g] = -lut[15 - g] + for (int g = 0; g < 16; g += 2) { + lut[g][group] = -lut[15 - g][group]; + } + } + + // Accumulate bias + for (int group = 0; group < 8; ++group) { + biases += lut[0][group]; } - - *lut_scales = scales; - *lut_biases = biases; + + // Scale and quantize, then store + // Output layout: qlut[k * 8 * 16 + group * 16 + lut_entry] + for (int group = 0; group < 8; ++group) { + for (int g = 0; g < 16; ++g) { + float scaled = lut[g][group] * t_scales; + int8_t quantized = static_cast(std::round(scaled)); + qlut[k * 8 * 16 + group * 16 + g] = quantized; + } + } + } + + *lut_scales = scales; + *lut_biases = biases; } /** @@ -610,30 +590,27 @@ GenerateLUT_Reference( float* lut_scales, float* lut_biases, size_t K, - size_t act_group_size -) -{ - const int32_t kk_outer_max = static_cast(K / act_group_size); - const int32_t ags_div32 = static_cast(act_group_size / 32); - - // Phase 1: Compute partial max for each activation group - for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { - lut_scales[kk_outer] = 0.0f; - for (int32_t k_outer = 0; k_outer < ags_div32; ++k_outer) { - PartialMax_Reference(&lut_scales[kk_outer], &b[(kk_outer * act_group_size) + (k_outer * 32)]); - } + size_t act_group_size) { + const int32_t kk_outer_max = static_cast(K / act_group_size); + const int32_t ags_div32 = static_cast(act_group_size / 32); + + // Phase 1: Compute partial max for each activation group + for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { + lut_scales[kk_outer] = 0.0f; + for (int32_t k_outer = 0; k_outer < ags_div32; ++k_outer) { + PartialMax_Reference(&lut_scales[kk_outer], &b[(kk_outer * act_group_size) + (k_outer * 32)]); } + } - // Phase 2: Build quantized LUT - for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { - LutCtor_Reference( - static_cast(act_group_size), - &qlut[k_outer_1 * act_group_size * 4], - &b[k_outer_1 * act_group_size], - &lut_scales[k_outer_1], - &lut_biases[k_outer_1] - ); - } + // Phase 2: Build quantized LUT + for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { + LutCtor_Reference( + static_cast(act_group_size), + &qlut[k_outer_1 * act_group_size * 4], + &b[k_outer_1 * act_group_size], + &lut_scales[k_outer_1], + &lut_biases[k_outer_1]); + } } } // namespace ScalarReference @@ -645,7 +622,7 @@ GenerateLUT_Reference( // template -class MlasSQLutGemmLutGenTest : public 
+class MlasLutGemmLutGenTest : public MlasTestBase {
 private:
  MatrixGuardBuffer<float> BufferActivation;
  MatrixGuardBuffer<int8_t> BufferQLutExpected;
@@ -690,8 +667,7 @@ class MlasSQLutGemmLutGenTest : public MlasTestBase {
         LutScalesExpected,
         LutBiasesExpected,
         K,
-        act_group_size
-    );
+        act_group_size);
 
     // Get the kernel dispatch through internal accessor
     // This is defined in qlutgemm.h and qlutgemm.cpp
@@ -701,38 +677,38 @@ class MlasSQLutGemmLutGenTest : public MlasTestBase {
     // Use the public GEMM API indirectly by creating a minimal test scenario
     // that exercises the GenerateLUT path. We need to call it through the
     // internal dispatch mechanism.
-
+
     // Access dispatch through platform - this requires linking to internal symbols
     // For now, we'll use a workaround: call the full LUT GEMM but with minimal weights
     // and compare intermediate LUT results.
-
+
     // Since we can't easily access GenerateLUT directly, let's verify the algorithm
-    // by checking that the scalar reference produces sensible output, then
-    // trust the integration test (SQLutGemm) to find bugs in the SIMD version.
-
+    // by checking that the scalar reference produces sensible output, then
+    // trust the integration test (LutGemm) to find bugs in the SIMD version.
+
     // For a proper isolated test, we would need to expose GenerateLUT publicly.
     // For now, just verify the scalar reference produces valid output:
-
+
     // Check that scales are non-negative
     for (size_t i = 0; i < lut_scales_count; ++i) {
       EXPECT_GE(LutScalesExpected[i], 0.0f) << "LUT scale should be non-negative";
     }
-
+
     // Check that quantized LUT values are within int8 range
     for (size_t i = 0; i < K * 4; ++i) {
       EXPECT_GE(QLutExpected[i], -128) << "QLUT value out of range";
       EXPECT_LE(QLutExpected[i], 127) << "QLUT value out of range";
     }
-
+
     // Log some info for debugging
     if (lut_scales_count > 0) {
-      SCOPED_TRACE(testing::Message() << "First LUT scale: " << LutScalesExpected[0]
-                   << ", First LUT bias: " << LutBiasesExpected[0]);
+      SCOPED_TRACE(testing::Message() << "First LUT scale: " << LutScalesExpected[0]
+                                      << ", First LUT bias: " << LutBiasesExpected[0]);
     }
   }
 
   static const char* GetTestSuiteName() {
-    static std::string suite_name = std::string("SQLutGemmLutGen") + "BlkLen" + std::to_string(BlkLen);
+    static std::string suite_name = std::string("LutGemmLutGen") + "BlkLen" + std::to_string(BlkLen);
     return suite_name.c_str();
   }
 };
@@ -744,13 +720,13 @@ class MlasSQLutGemmLutGenTest : public MlasTestBase {
 //
 template <size_t BlkBitWidth, size_t BlkLen>
-class SQLutGemmPackShortExecuteTest : public MlasTestFixture<MlasSQLutGemmPackTest<BlkBitWidth, BlkLen>> {
+class LutGemmPackShortExecuteTest : public MlasTestFixture<MlasLutGemmPackTest<BlkBitWidth, BlkLen>> {
 public:
-  explicit SQLutGemmPackShortExecuteTest(size_t N, size_t K, bool Symmetric)
+  explicit LutGemmPackShortExecuteTest(size_t N, size_t K, bool Symmetric)
       : N_(N), K_(K), Symmetric_(Symmetric) {}
 
   void TestBody() override {
-    MlasTestFixture<MlasSQLutGemmPackTest<BlkBitWidth, BlkLen>>::mlas_tester->Test(N_, K_, Symmetric_);
+    MlasTestFixture<MlasLutGemmPackTest<BlkBitWidth, BlkLen>>::mlas_tester->Test(N_, K_, Symmetric_);
   }
 
   static size_t RegisterSingleTest(size_t N, size_t K, bool Symmetric) {
@@ -766,14 +742,14 @@ class SQLutGemmPackShortExecuteTest : public MlasTestFixture<MlasSQLutGemmPackTest<BlkBitWidth, BlkLen>> {
-        MlasSQLutGemmPackTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
+        MlasLutGemmPackTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
         test_name.c_str(),
         nullptr,
         test_name.c_str(),
         __FILE__,
         __LINE__,
-        [=]() -> MlasTestFixture<MlasSQLutGemmPackTest<BlkBitWidth, BlkLen>>* {
-          return new SQLutGemmPackShortExecuteTest(N, K, Symmetric);
+        [=]() -> MlasTestFixture<MlasLutGemmPackTest<BlkBitWidth, BlkLen>>* {
+          return new LutGemmPackShortExecuteTest(N, K, Symmetric);
         });
 
     return 1;
@@ -801,12 +777,12 @@ class SQLutGemmPackShortExecuteTest : public MlasTestFixture<MlasSQLutGemmPackTest<BlkBitWidth, BlkLen>> {
 template <size_t BlkLen>
-class SQLutGemmLutGenShortExecuteTest : public MlasTestFixture<MlasSQLutGemmLutGenTest<BlkLen>> {
+class LutGemmLutGenShortExecuteTest : public MlasTestFixture<MlasLutGemmLutGenTest<BlkLen>> {
 public:
-  explicit SQLutGemmLutGenShortExecuteTest(size_t K) : K_(K) {}
+  explicit LutGemmLutGenShortExecuteTest(size_t K) : K_(K) {}
 
   void TestBody() override {
-    MlasTestFixture<MlasSQLutGemmLutGenTest<BlkLen>>::mlas_tester->Test(K_);
+    MlasTestFixture<MlasLutGemmLutGenTest<BlkLen>>::mlas_tester->Test(K_);
   }
 
   static size_t RegisterSingleTest(size_t K) {
@@ -820,14 +796,14 @@ class SQLutGemmLutGenShortExecuteTest : public MlasTestFixture<MlasSQLutGemmLutGenTest<BlkLen>> {
-        MlasSQLutGemmLutGenTest<BlkLen>::GetTestSuiteName(),
+        MlasLutGemmLutGenTest<BlkLen>::GetTestSuiteName(),
         test_name.c_str(),
         nullptr,
         test_name.c_str(),
         __FILE__,
         __LINE__,
-        [=]() -> MlasTestFixture<MlasSQLutGemmLutGenTest<BlkLen>>* {
-          return new SQLutGemmLutGenShortExecuteTest(K);
+        [=]() -> MlasTestFixture<MlasLutGemmLutGenTest<BlkLen>>* {
+          return new LutGemmLutGenShortExecuteTest(K);
        });
 
     return 1;
@@ -851,18 +827,18 @@ class SQLutGemmLutGenShortExecuteTest : public MlasTestFixture<MlasSQLutGemmLutGenTest<BlkLen>> {
-  count += SQLutGemmPackShortExecuteTest<2, 32>::RegisterShortExecuteTests();
-  count += SQLutGemmPackShortExecuteTest<2, 64>::RegisterShortExecuteTests();
-  count += SQLutGemmPackShortExecuteTest<2, 128>::RegisterShortExecuteTests();
+  count += LutGemmPackShortExecuteTest<2, 32>::RegisterShortExecuteTests();
+  count += LutGemmPackShortExecuteTest<2, 64>::RegisterShortExecuteTests();
+  count += LutGemmPackShortExecuteTest<2, 128>::RegisterShortExecuteTests();
 
   // LUT generation tests
-  count += SQLutGemmLutGenShortExecuteTest<32>::RegisterShortExecuteTests();
-  count += SQLutGemmLutGenShortExecuteTest<64>::RegisterShortExecuteTests();
-  count += SQLutGemmLutGenShortExecuteTest<128>::RegisterShortExecuteTests();
+  count += LutGemmLutGenShortExecuteTest<32>::RegisterShortExecuteTests();
+  count += LutGemmLutGenShortExecuteTest<64>::RegisterShortExecuteTests();
+  count += LutGemmLutGenShortExecuteTest<128>::RegisterShortExecuteTests();
 
   return count;
 }
 
@@ -870,7 +846,7 @@ static size_t SQLutGemmComponentsRegisterAllShortExecuteTests() {
 static UNUSED_VARIABLE bool added_to_main = AddTestRegister(
     [](bool is_short_execute) -> size_t {
       if (is_short_execute) {
-        return SQLutGemmComponentsRegisterAllShortExecuteTests();
+        return LutGemmPackRegisterAllShortExecuteTests();
       }
       return 0;
     });
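
Note on the scalar reference being tested above: for every group of 4 activations, LutCtor_Reference builds a 16-entry table whose entry g holds the signed sum ±b0 ±b1 ±b2 ±b3 selected by the 4 bits of g, quantized to int8 with scale (max group |b0|+|b1|+|b2|+|b3|) / 127, and it accumulates -sum(b) into the bias. The sketch below is illustrative only and is not part of the patch or of the AVX2 kernel; DotViaLut_Sketch and weight_bits are made-up names. It shows how such a table could be consumed by a scalar caller: four one-bit weights form the 4-bit index, so one table lookup replaces four multiply-adds, and the stored bias undoes the {0,1} -> {-1,+1} weight mapping. For multi-bit weights the kernel repeats this per bit plane with the appropriate shift; that part is omitted here.

// Illustrative sketch only. Mirrors the ScalarReference qlut layout:
//   qlut[k * 128 + group * 16 + entry], k in [0, K/32), group in [0, 8), entry in [0, 16).
// Assumes K is a multiple of 32.
#include <cstddef>
#include <cstdint>

static float
DotViaLut_Sketch(
    const int8_t* qlut,          // quantized LUT produced by LutCtor_Reference
    const uint8_t* weight_bits,  // hypothetical: one weight bit per activation, K bits, LSB-first
    float lut_scale,             // *lut_scales from LutCtor_Reference
    float lut_bias,              // *lut_biases from LutCtor_Reference (equals -sum of activations)
    size_t K)
{
    int32_t acc = 0;
    for (size_t k = 0; k < K / 32; ++k) {
        for (size_t group = 0; group < 8; ++group) {
            // The 4 weight bits of this group form the 4-bit table index;
            // bit j set means +b_j, clear means -b_j, matching the table construction.
            uint8_t idx = 0;
            for (size_t j = 0; j < 4; ++j) {
                const size_t bit = k * 32 + group * 4 + j;
                idx = static_cast<uint8_t>(idx | (((weight_bits[bit / 8] >> (bit % 8)) & 1u) << j));
            }
            acc += qlut[k * 128 + group * 16 + idx];
        }
    }
    // lut_scale * acc approximates sum_j (2*w_j - 1) * b_j and lut_bias == -sum_j b_j, so
    //   sum_j w_j * b_j is approximately 0.5 * (lut_scale * acc - lut_bias).
    return 0.5f * (lut_scale * static_cast<float>(acc) - lut_bias);
}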