4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -101,6 +101,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
@@ -470,6 +472,8 @@ else()
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
4 changes: 4 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1245,6 +1245,10 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH;

extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2;

#if defined(MLAS_TARGET_ARM64)
extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon;
#endif

//
// Rotary embedding dispatch structure.
//
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -654,6 +654,9 @@ Return Value:
this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? false : true;
this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions);

// Enable LUT-based GEMM for 2-bit quantization on ARM64
this->LutGenKernel = &MlasLutGenKernelNeon;

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon;
234 changes: 64 additions & 170 deletions onnxruntime/core/mlas/lib/qlutgemm.cpp
@@ -22,20 +22,47 @@ Module Name:

#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>

/** T-MAC GEMM kernel config key - struct-based for type safety and performance */
struct TMACConfigKey {
size_t M;
size_t N;
size_t nbits;
size_t block_size;
bool has_zero_point;

bool operator==(const TMACConfigKey& other) const {
return M == other.M && N == other.N && nbits == other.nbits &&
block_size == other.block_size && has_zero_point == other.has_zero_point;
}
};

struct TMACConfigKeyHash {
size_t operator()(const TMACConfigKey& k) const {
// Combine hash values using a simple mixing function
size_t h = std::hash<size_t>{}(k.M);
h ^= std::hash<size_t>{}(k.N) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<size_t>{}(k.nbits) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<size_t>{}(k.block_size) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<bool>{}(k.has_zero_point) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};

/** T-MAC GEMM kernel Config */
static std::unordered_map<std::string, struct MlasTMACKernelParams> tmac_kernel_configs;
static std::unordered_map<TMACConfigKey, MlasTMACKernelParams, TMACConfigKeyHash> tmac_kernel_configs;

const MlasTMACKernelParams&
MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
if (tmac_kernel_configs.count(key)) {
return tmac_kernel_configs[key];
TMACConfigKey key{M, N, nbits, block_size, has_zero_point};
auto it = tmac_kernel_configs.find(key);
if (it != tmac_kernel_configs.end()) {
return it->second;
}
MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized");
}
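
Replacing the string key with TMACConfigKey avoids per-lookup std::to_string formatting and makes the lookup type-safe; the 0x9e3779b9 constant in TMACConfigKeyHash is the familiar boost-style hash_combine mixing constant. A minimal usage sketch of the new cache (the numeric values are illustrative only, not taken from this change):

TMACConfigKey key{/*M=*/4096, /*N=*/4096, /*nbits=*/2, /*block_size=*/64, /*has_zero_point=*/false};
auto it = tmac_kernel_configs.find(key);  // O(1) average, no string concatenation
if (it != tmac_kernel_configs.end()) {
    const MlasTMACKernelParams& params = it->second;  // e.g. params.bm, params.kfactor
}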
@@ -49,7 +76,7 @@ MlasClearLutGemmKernelConfig()
void MLASCALL
MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
TMACConfigKey key{M, N, nbits, block_size, has_zero_point};
if (tmac_kernel_configs.count(key)) {
return;
}
@@ -163,111 +190,16 @@ LutGemmPackQuantBData(
const size_t bm = tmac_params.bm;
const size_t kfactor = tmac_params.kfactor;

assert(BlkLen % g == 0);
assert((BlkLen / g) % kfactor == 0);

const size_t mgroup = ngroups_per_elem * simd_n_in; // 32
assert(bm % mgroup == 0);
assert(bm % bits == 0);

std::unique_ptr<uint8_t[]> buf(new uint8_t[N * bits * (K / g)]);
memset(buf.get(), 0, N * bits * (K / g));

const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed

MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
size_t im = static_cast<size_t>(tid);
for (size_t ik = 0; ik < K; ++ik) {
size_t idx = (im * K + ik);
size_t num_elem_per_byte = 8 / bits;
size_t elem_idx = idx % num_elem_per_byte;

uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits);

for (size_t ib = 0; ib < bits; ++ib) {
size_t new_ik = ik / g;
size_t shft_left = ik % g;
buf[im * bits * K / g + ib * K / g + new_ik] += static_cast<uint8_t>(((v >> ib) & 1) << shft_left);
}
}
}
);

// Now buf contains the bit planes grouped by g along K
// Next, we need to do a multi-reshape/transpose into the final layout

const size_t c0_fac2 = K / g;
const size_t c0_fac1 = simd_n_out * c0_fac2;
const size_t c0_fac0 = bits * c0_fac1;

const size_t c1_nb2 = K / g;
const size_t c1_nb1 = simd_n_in * c1_nb2;
const size_t c1_nb0 = ngroups_per_elem * c1_nb1;
const size_t c1_fac2 = K / g;
const size_t c1_fac1 = ngroups_per_elem * c1_fac2;
const size_t c1_fac0 = simd_n_in * c1_fac1;

const size_t c2_nb4 = kfactor;
const size_t c2_nb3 = K / g / kfactor * c2_nb4;
const size_t c2_nb2 = ngroups_per_elem * c2_nb3;
const size_t c2_nb1 = simd_n_in * c2_nb2;
const size_t c2_nb0 = bm / mgroup * c2_nb1;
const size_t c2_fac3 = simd_n_in * ngroups_per_elem;
const size_t c2_fac2 = kfactor * c2_fac3;
const size_t c2_fac1 = bm / mgroup * c2_fac2;
const size_t c2_fac0 = K / g / kfactor * c2_fac1;

const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem);
memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed?
// LUT GEMM packing requires a dispatch implementation that provides PackQuantBData
const auto* Dispatch = GetMlasPlatform().LutGenKernel;
if (Dispatch == nullptr || Dispatch->PackQuantBData == nullptr) {
MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires LUT GEMM dispatch support");
}

MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
size_t im = static_cast<size_t>(tid);
for (size_t ib = 0; ib < bits; ib++) {
for (size_t ik = 0; ik < K / g; ik++) {
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
size_t new_im = im / simd_n_out;
size_t new_isno = im % simd_n_out;
size_t new_ib = ib;
size_t new_ik = ik;
size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;

// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
new_im = new_idx / c1_nb0;
size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
new_ik = (new_idx % c1_nb2);
new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;

// # 0 1 2 3 4 5
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
new_im = new_idx / c2_nb0;
size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
new_isni = (new_idx % c2_nb1) / c2_nb2;
new_ing = (new_idx % c2_nb2) / c2_nb3;
new_ik = (new_idx % c2_nb3) / c2_nb4;
size_t new_ikf = (new_idx % c2_nb4);
new_idx = new_im * c2_fac0 +
new_ik * c2_fac1 +
new_ibm * c2_fac2 +
new_ikf * c2_fac3 +
new_isni * ngroups_per_elem +
new_ing;
new_idx = new_idx / ngroups_per_elem;
size_t buf_idx = im * bits * K / g + ib * K / g + ik;
uint8_t buf_val = buf[buf_idx];

// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
(buf_val << (new_ing * g))
);
}
}
}
Dispatch->PackQuantBData(
N, K, bits, g, ngroups_per_elem,
simd_n_in, simd_n_out, bm, kfactor,
QuantBDataBegin, PackedQuantBDataBegin, ThreadPool
);
}
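
The bit-plane extraction and layout transpose that used to be inlined here now live behind Dispatch->PackQuantBData. As a reference for what an implementation is expected to do in its first stage, a scalar sketch of the bit-plane split (ExtractBitPlanes is an illustrative name, not a symbol added by this change; it assumes buf holds N * bits * (K / g) bytes and that K is a multiple of g):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Split packed low-bit weights into per-bit planes, regrouping K into
// groups of g bits per output byte (mirrors the removed inline loop).
static void ExtractBitPlanes(const uint8_t* packed, uint8_t* buf,
                             size_t N, size_t K, size_t bits, size_t g)
{
    const size_t num_elem_per_byte = 8 / bits;  // e.g. 4 weights per byte for 2-bit
    std::memset(buf, 0, N * bits * (K / g));
    for (size_t im = 0; im < N; ++im) {
        for (size_t ik = 0; ik < K; ++ik) {
            const size_t idx = im * K + ik;
            const uint8_t v = static_cast<uint8_t>(
                packed[idx / num_elem_per_byte] >> ((idx % num_elem_per_byte) * bits));
            for (size_t ib = 0; ib < bits; ++ib) {
                // bit plane ib of column im; bit position (ik % g) within group ik / g
                buf[im * bits * (K / g) + ib * (K / g) + ik / g] |=
                    static_cast<uint8_t>(((v >> ib) & 1) << (ik % g));
            }
        }
    }
}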

@@ -298,67 +230,25 @@ LutPackScalesAndZeroPoints(
bool HasZeroPoint,
float* PackedQuantBZPBegin,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint
const uint8_t* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
)
{
const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint);
const size_t bits = tmac_params.bits;
const size_t simd_n_out = tmac_params.simd_n_out;
const size_t bm = tmac_params.bm;
const size_t num_elem_per_byte = 8 / bits;

// ZP array is column-major packed, with per-column alignment to byte boundary
const size_t row_blks = K / BlkLen; // number of blocks per column
const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte;

for (size_t im = 0; im < N; im += 1) {
for (size_t ik = 0; ik < K; ik += BlkLen) {
size_t idx = (im * K + ik) / BlkLen; // linear block index for scale (scale is NOT packed)
float scale = QuantBScale[idx];
float zp = 0.0f;
if (HasZeroPoint) {
size_t blk_in_col = ik / BlkLen; // block index within column
size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte;
size_t elem_idx = blk_in_col % num_elem_per_byte;
uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1);

// The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit).
// Thus, need to correct for the actual ZP relative to the midpoint.

int midpoint = 1 << (bits - 1); // 2 for 2-bit
zp = static_cast<float>(static_cast<int>(v) - midpoint) * scale;
}

// TODO(vraspar): fix when k < BlkLen and nb1 is 0
size_t nb1 = K / BlkLen;
size_t nb0 = bm / bits * nb1;

size_t new_im, new_ibm, new_ik;
if (nb1 == 0) {
new_im = 0;
new_ibm = 0;
new_ik = 0;

} else {
new_im = idx / nb0;
new_ibm = (idx % nb0) / nb1;
new_ik = (idx % nb1);
}

if (HasZeroPoint) {
size_t new_isimd = new_ibm % simd_n_out;
size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out;
size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd;
size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd;

PackedQuantBZPBegin[new_idx_scale] = scale;
PackedQuantBZPBegin[new_idx_zero] = zp;
} else {
size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm;
PackedQuantBZPBegin[new_idx] = scale;
}
}
// LUT GEMM packing requires a dispatch implementation that provides PackScalesAndZeroPoints
const auto* Dispatch = GetMlasPlatform().LutGenKernel;
if (Dispatch == nullptr || Dispatch->PackScalesAndZeroPoints == nullptr) {
MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires LUT GEMM dispatch support");
}

Dispatch->PackScalesAndZeroPoints(
N, K, bits, BlkLen, simd_n_out, bm, HasZeroPoint,
PackedQuantBZPBegin, QuantBScale, QuantBZeroPoint, ThreadPool
);
}
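
The scale/zero-point interleaving itself is now the dispatch kernel's responsibility, but the zero-point recentering that the previous inline code applied is worth keeping in view: the LUT kernel assumes weights centered on the midpoint (2 for 2-bit), so the stored offset is relative to it. A sketch, with RecenterZeroPoint as an illustrative name and the same little-end-first packing of sub-byte zero points assumed:

#include <cstddef>
#include <cstdint>

// Extract one sub-byte zero point and express it as an offset relative to
// the midpoint, scaled into float (mirrors the removed inline computation).
static float RecenterZeroPoint(uint8_t zp_byte, size_t elem_idx, size_t bits, float scale)
{
    const uint8_t v = static_cast<uint8_t>((zp_byte >> (elem_idx * bits)) & ((1u << bits) - 1));
    const int midpoint = 1 << (bits - 1);  // 2 for 2-bit
    return static_cast<float>(static_cast<int>(v) - midpoint) * scale;
}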

// Internal helper: calculates the offset to scales in the packed buffer
@@ -418,7 +308,7 @@ MlasLutGemmPack(
if (QuantBScale != nullptr) {
size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint);
float* scales_dest = reinterpret_cast<float*>(PackedBuf + scales_offset);
LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint);
LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint, ThreadPool);
}
}

@@ -431,7 +321,11 @@ MlasIsLutGemmAvailable(
)
{
const auto* lut_kernel = GetMlasPlatform().LutGenKernel;
if (lut_kernel == nullptr || lut_kernel->GenerateLUT == nullptr || lut_kernel->ComputeGemm == nullptr) {
if (lut_kernel == nullptr ||
lut_kernel->GenerateLUT == nullptr ||
lut_kernel->ComputeGemm == nullptr ||
lut_kernel->PackQuantBData == nullptr ||
lut_kernel->PackScalesAndZeroPoints == nullptr) {
return false;
}

@@ -500,7 +394,9 @@ MlasLutGemm(
// adapted from ggml_backend_tmac_mul_mat
const auto* Dispatch = GetMlasPlatform().LutGenKernel;
// This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm()
assert(Dispatch && Dispatch->GenerateLUT && "TMAC not supported in this configuration.");
if (Dispatch == nullptr || Dispatch->GenerateLUT == nullptr || Dispatch->ComputeGemm == nullptr) {
MLAS_THROW_EX(std::runtime_error, "TMAC not supported in this configuration");
}

// Calculate scales offset from packed buffer
// TODO(vraspar): support other bitwidths
@@ -620,10 +516,8 @@ MlasLutGemm(
size_t scales_size_per_tile = 0;

if (scales_size_total % n_tiles_num != 0) {
// Sanity: scales should partition evenly across tiles. If they don't, choose floor division
// and document that callers must layout scales accordingly.
// Prefer to error loudly in debug builds.
fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tiles_num=%zu; using floor division.\n", scales_size_total, n_tiles_num);
// Scales must partition evenly across tiles. Callers must ensure proper layout.
MLAS_THROW_EX(std::runtime_error, "scales_size_total must be divisible by n_tiles_num");
}
scales_size_per_tile = scales_size_total / n_tiles_num;

39 changes: 39 additions & 0 deletions onnxruntime/core/mlas/lib/qlutgemm.h
@@ -70,6 +70,41 @@ typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)(
bool HasZeroPoint
);

//
// Function signature for packing quantized B data
//
typedef void(MLAS_QNBIT_LUT_PACK_QUANTB_DATA)(
size_t N,
size_t K,
size_t bits,
size_t g,
size_t ngroups_per_elem,
size_t simd_n_in,
size_t simd_n_out,
size_t bm,
size_t kfactor,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
);

//
// Function signature for packing scales and zero points
//
typedef void(MLAS_QNBIT_LUT_PACK_SCALES_AND_ZP)(
size_t N,
size_t K,
size_t bits,
size_t BlkLen,
size_t simd_n_out,
size_t bm,
bool HasZeroPoint,
float* PackedScalesBegin,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
);

//
// Kernel dispatch structure.
//
Expand All @@ -81,4 +116,8 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH {
MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr;

MLAS_QNBIT_LUT_GEMM_COMPUTE* ComputeGemm = nullptr;

MLAS_QNBIT_LUT_PACK_QUANTB_DATA* PackQuantBData = nullptr;

MLAS_QNBIT_LUT_PACK_SCALES_AND_ZP* PackScalesAndZeroPoints = nullptr;
};
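
With the two new members in place, a backend provides LUT GEMM support by filling every field of the dispatch struct; MlasIsLutGemmAvailable() now checks all four function pointers. A sketch of the wiring using one common MLAS initialization pattern, with hypothetical NEON entry-point names (the real definitions belong to sqnbitgemm_lut_kernel_neon.cpp and may be named differently):

// Hypothetical entry points, declared here only to illustrate the wiring.
MLAS_QNBIT_GEMM_LUT_GEN GenerateLutNeon;
MLAS_QNBIT_LUT_GEMM_COMPUTE ComputeLutGemmNeon;
MLAS_QNBIT_LUT_PACK_QUANTB_DATA PackQuantBDataNeon;
MLAS_QNBIT_LUT_PACK_SCALES_AND_ZP PackScalesAndZeroPointsNeon;

const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelNeon = []() {
    MLAS_QNBIT_LUT_GEMM_DISPATCH d;
    d.GenerateLUT = GenerateLutNeon;
    d.ComputeGemm = ComputeLutGemmNeon;
    d.PackQuantBData = PackQuantBDataNeon;
    d.PackScalesAndZeroPoints = PackScalesAndZeroPointsNeon;
    return d;
}();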