Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
Expand Down Expand Up @@ -470,6 +472,8 @@ else()
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_neon.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,11 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx;

struct MLAS_QNBIT_LUT_GEMM_DISPATCH;

extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2;
extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchAvx2;

#if defined(MLAS_TARGET_ARM64)
extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGemmDispatchNeon;
#endif

//
// Rotary embedding dispatch structure.
Expand Down Expand Up @@ -1453,7 +1457,7 @@ struct MLAS_PLATFORM {
const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};

const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr};
const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGenKernel{nullptr};
const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGemmDispatch{nullptr};

MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ Return Value:
this->RopeDispatch = &MlasRopeDispatchAvx2;

// TODO(vraspar): check if this really goes here or if there are other platform reqs that we need to fulfill
this->LutGenKernel = &MlasLutGenKernelAvx2;
this->LutGemmDispatch = &MlasLutGemmDispatchAvx2;

//
// Check if the processor supports Hybrid core architecture.
Expand Down Expand Up @@ -654,6 +654,9 @@ Return Value:
this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? false : true;
this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions);

// Enable LUT-based GEMM for 2-bit quantization on ARM64
this->LutGemmDispatch = &MlasLutGemmDispatchNeon;

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon;
Expand Down
238 changes: 66 additions & 172 deletions onnxruntime/core/mlas/lib/qlutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,47 @@ Module Name:

#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>

/** T-MAC GEMM kernel config key - struct-based for type safety and performance */
/** T-MAC GEMM kernel config key - struct-based for type safety and performance.
 *
 * Identifies one kernel configuration by problem shape (M, N), weight bit width,
 * quantization block size, and whether zero points are present. Used as the key
 * of the kernel-params cache, so it must be equality-comparable and hashable.
 */
struct TMACConfigKey {
    size_t M;
    size_t N;
    size_t nbits;
    size_t block_size;
    bool has_zero_point;

    // Two keys are equal only when every field matches.
    bool operator==(const TMACConfigKey& rhs) const {
        return M == rhs.M &&
               N == rhs.N &&
               nbits == rhs.nbits &&
               block_size == rhs.block_size &&
               has_zero_point == rhs.has_zero_point;
    }
};

/** Hasher for TMACConfigKey, suitable for std::unordered_map. */
struct TMACConfigKeyHash {
    size_t operator()(const TMACConfigKey& key) const {
        // boost::hash_combine-style mixing; 0x9e3779b9 (golden ratio fraction)
        // plus the shifts decorrelate the per-field hashes.
        const std::hash<size_t> hash_sz{};
        size_t seed = hash_sz(key.M);
        seed ^= hash_sz(key.N) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed ^= hash_sz(key.nbits) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed ^= hash_sz(key.block_size) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed ^= std::hash<bool>{}(key.has_zero_point) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};

/** T-MAC GEMM kernel Config */
static std::unordered_map<std::string, struct MlasTMACKernelParams> tmac_kernel_configs;
static std::unordered_map<TMACConfigKey, MlasTMACKernelParams, TMACConfigKeyHash> tmac_kernel_configs;

const MlasTMACKernelParams&
MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
if (tmac_kernel_configs.count(key)) {
return tmac_kernel_configs[key];
TMACConfigKey key{M, N, nbits, block_size, has_zero_point};
auto it = tmac_kernel_configs.find(key);
if (it != tmac_kernel_configs.end()) {
return it->second;
}
MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized");
}
Expand All @@ -49,7 +76,7 @@ MlasClearLutGemmKernelConfig()
void MLASCALL
MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
TMACConfigKey key{M, N, nbits, block_size, has_zero_point};
if (tmac_kernel_configs.count(key)) {
return;
}
Expand Down Expand Up @@ -163,111 +190,16 @@ LutGemmPackQuantBData(
const size_t bm = tmac_params.bm;
const size_t kfactor = tmac_params.kfactor;

assert(BlkLen % g == 0);
assert((BlkLen / g) % kfactor == 0);

const size_t mgroup = ngroups_per_elem * simd_n_in; // 32
assert(bm % mgroup == 0);
assert(bm % bits == 0);

std::unique_ptr<uint8_t[]> buf(new uint8_t[N * bits * (K / g)]);
memset(buf.get(), 0, N * bits * (K / g));

const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed

MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
size_t im = static_cast<size_t>(tid);
for (size_t ik = 0; ik < K; ++ik) {
size_t idx = (im * K + ik);
size_t num_elem_per_byte = 8 / bits;
size_t elem_idx = idx % num_elem_per_byte;

uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits);

for (size_t ib = 0; ib < bits; ++ib) {
size_t new_ik = ik / g;
size_t shft_left = ik % g;
buf[im * bits * K / g + ib * K / g + new_ik] += static_cast<uint8_t>(((v >> ib) & 1) << shft_left);
}
}
}
);

// Now buf contains the bit planes grouped by g along K
// Next, we need to do a multi-reshape/transpose into the final layout

const size_t c0_fac2 = K / g;
const size_t c0_fac1 = simd_n_out * c0_fac2;
const size_t c0_fac0 = bits * c0_fac1;

const size_t c1_nb2 = K / g;
const size_t c1_nb1 = simd_n_in * c1_nb2;
const size_t c1_nb0 = ngroups_per_elem * c1_nb1;
const size_t c1_fac2 = K / g;
const size_t c1_fac1 = ngroups_per_elem * c1_fac2;
const size_t c1_fac0 = simd_n_in * c1_fac1;

const size_t c2_nb4 = kfactor;
const size_t c2_nb3 = K / g / kfactor * c2_nb4;
const size_t c2_nb2 = ngroups_per_elem * c2_nb3;
const size_t c2_nb1 = simd_n_in * c2_nb2;
const size_t c2_nb0 = bm / mgroup * c2_nb1;
const size_t c2_fac3 = simd_n_in * ngroups_per_elem;
const size_t c2_fac2 = kfactor * c2_fac3;
const size_t c2_fac1 = bm / mgroup * c2_fac2;
const size_t c2_fac0 = K / g / kfactor * c2_fac1;

const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem);
memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed?
// LUT GEMM requires a valid LUT dispatch implementation, so dispatch must be available
const auto* Dispatch = GetMlasPlatform().LutGemmDispatch;
if (Dispatch == nullptr || Dispatch->PackQuantBData == nullptr) {
MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires LUT GEMM dispatch support");
}

MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
size_t im = static_cast<size_t>(tid);
for (size_t ib = 0; ib < bits; ib++) {
for (size_t ik = 0; ik < K / g; ik++) {
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
size_t new_im = im / simd_n_out;
size_t new_isno = im % simd_n_out;
size_t new_ib = ib;
size_t new_ik = ik;
size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;

// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
new_im = new_idx / c1_nb0;
size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
new_ik = (new_idx % c1_nb2);
new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;

// # 0 1 2 3 4 5
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
new_im = new_idx / c2_nb0;
size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
new_isni = (new_idx % c2_nb1) / c2_nb2;
new_ing = (new_idx % c2_nb2) / c2_nb3;
new_ik = (new_idx % c2_nb3) / c2_nb4;
size_t new_ikf = (new_idx % c2_nb4);
new_idx = new_im * c2_fac0 +
new_ik * c2_fac1 +
new_ibm * c2_fac2 +
new_ikf * c2_fac3 +
new_isni * ngroups_per_elem +
new_ing;
new_idx = new_idx / ngroups_per_elem;
size_t buf_idx = im * bits * K / g + ib * K / g + ik;
uint8_t buf_val = buf[buf_idx];

// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
(buf_val << (new_ing * g))
);
}
}
}
Dispatch->PackQuantBData(
N, K, bits, g, ngroups_per_elem,
simd_n_in, simd_n_out, bm, kfactor,
QuantBDataBegin, PackedQuantBDataBegin, ThreadPool
);
}

Expand Down Expand Up @@ -298,67 +230,25 @@ LutPackScalesAndZeroPoints(
bool HasZeroPoint,
float* PackedQuantBZPBegin,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint
const uint8_t* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
)
{
const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint);
const size_t bits = tmac_params.bits;
const size_t simd_n_out = tmac_params.simd_n_out;
const size_t bm = tmac_params.bm;
const size_t num_elem_per_byte = 8 / bits;

// ZP array is column-major packed, with per-column alignment to byte boundary
const size_t row_blks = K / BlkLen; // number of blocks per column
const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte;

for (size_t im = 0; im < N; im += 1) {
for (size_t ik = 0; ik < K; ik += BlkLen) {
size_t idx = (im * K + ik) / BlkLen; // linear block index for scale (scale is NOT packed)
float scale = QuantBScale[idx];
float zp = 0.0f;
if (HasZeroPoint) {
size_t blk_in_col = ik / BlkLen; // block index within column
size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte;
size_t elem_idx = blk_in_col % num_elem_per_byte;
uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1);

// The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit).
// Thus, need to correct for the actual ZP relative to the midpoint.

int midpoint = 1 << (bits - 1); // 2 for 2-bit
zp = static_cast<float>(static_cast<int>(v) - midpoint) * scale;
}

// TODO(vraspar): fix when k < BlkLen and nb1 is 0
size_t nb1 = K / BlkLen;
size_t nb0 = bm / bits * nb1;

size_t new_im, new_ibm, new_ik;
if (nb1 == 0) {
new_im = 0;
new_ibm = 0;
new_ik = 0;

} else {
new_im = idx / nb0;
new_ibm = (idx % nb0) / nb1;
new_ik = (idx % nb1);
}

if (HasZeroPoint) {
size_t new_isimd = new_ibm % simd_n_out;
size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out;
size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd;
size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd;

PackedQuantBZPBegin[new_idx_scale] = scale;
PackedQuantBZPBegin[new_idx_zero] = zp;
} else {
size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm;
PackedQuantBZPBegin[new_idx] = scale;
}
}
// LUT GEMM requires a valid LUT dispatch implementation (e.g. AVX2 or NEON), so dispatch must be available
const auto* Dispatch = GetMlasPlatform().LutGemmDispatch;
if (Dispatch == nullptr || Dispatch->PackScalesAndZeroPoints == nullptr) {
MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires LUT GEMM dispatch support");
}

Dispatch->PackScalesAndZeroPoints(
N, K, bits, BlkLen, simd_n_out, bm, HasZeroPoint,
PackedQuantBZPBegin, QuantBScale, QuantBZeroPoint, ThreadPool
);
}

// Internal helper: calculates the offset to scales in the packed buffer
Expand Down Expand Up @@ -418,7 +308,7 @@ MlasLutGemmPack(
if (QuantBScale != nullptr) {
size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint);
float* scales_dest = reinterpret_cast<float*>(PackedBuf + scales_offset);
LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint);
LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint, ThreadPool);
}
}

Expand All @@ -430,8 +320,12 @@ MlasIsLutGemmAvailable(
size_t BlkLen
)
{
const auto* lut_kernel = GetMlasPlatform().LutGenKernel;
if (lut_kernel == nullptr || lut_kernel->GenerateLUT == nullptr || lut_kernel->ComputeGemm == nullptr) {
const auto* lut_kernel = GetMlasPlatform().LutGemmDispatch;
if (lut_kernel == nullptr ||
lut_kernel->GenerateLUT == nullptr ||
lut_kernel->ComputeGemm == nullptr ||
lut_kernel->PackQuantBData == nullptr ||
lut_kernel->PackScalesAndZeroPoints == nullptr) {
return false;
}

Expand Down Expand Up @@ -498,9 +392,11 @@ MlasLutGemm(
)
{
// adapted from ggml_backend_tmac_mul_mat
const auto* Dispatch = GetMlasPlatform().LutGenKernel;
const auto* Dispatch = GetMlasPlatform().LutGemmDispatch;
// This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm()
assert(Dispatch && Dispatch->GenerateLUT && "TMAC not supported in this configuration.");
if (Dispatch == nullptr || Dispatch->GenerateLUT == nullptr || Dispatch->ComputeGemm == nullptr) {
MLAS_THROW_EX(std::runtime_error, "TMAC not supported in this configuration");
}

// Calculate scales offset from packed buffer
// TODO(vraspar): support other bitwidths
Expand Down Expand Up @@ -620,10 +516,8 @@ MlasLutGemm(
size_t scales_size_per_tile = 0;

if (scales_size_total % n_tiles_num != 0) {
// Sanity: scales should partition evenly across tiles. If they don't, choose floor division
// and document that callers must layout scales accordingly.
// Prefer to error loudly in debug builds.
fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tiles_num=%zu; using floor division.\n", scales_size_total, n_tiles_num);
// Scales must partition evenly across tiles. Callers must ensure proper layout.
MLAS_THROW_EX(std::runtime_error, "scales_size_total must be divisible by n_tiles_num");
}
scales_size_per_tile = scales_size_total / n_tiles_num;

Expand Down
Loading
Loading