Status: Closed

Commits (29):
555e951  matmul nbits to optimize memory layout for avx instructions (liqunfu, Sep 24, 2024)
076998c  Merge branch 'main' into liqun/avx-layout (liqunfu, Nov 7, 2024)
99aec95  intermediate push (liqunfu, Nov 18, 2024)
8ce1a2a  pass mlas and utest for blklen32 avx512 (liqunfu, Nov 27, 2024)
f016555  Merge branch 'main' into liqun/avx-layout (liqunfu, Nov 27, 2024)
d371c59  pass avx512/vnni-blklen32 (liqunfu, Nov 29, 2024)
790b03f  pass avx512vnni-blklen128. plan to compute blksum in different loop t… (liqunfu, Nov 29, 2024)
557fbb0  attmpt to make blklen256 work. failed because blksum computation need… (liqunfu, Nov 29, 2024)
6b28657  avx512 blklen64 to compute blksum in a separate loop (liqunfu, Nov 30, 2024)
0b867f8  avx512 scaled_zp compute in a separate loop except blklen16 (liqunfu, Nov 30, 2024)
2e74f56  avx512, all blklens, scaled_zp compute in a separate loop (liqunfu, Nov 30, 2024)
0bf47f7  Merge branch 'main' into liqun/avx-layout (liqunfu, Dec 12, 2024)
c19ae9e  avx2 passes (liqunfu, Dec 13, 2024)
b26b075  avxvnni, matmul_nbit kernel (liqunfu, Dec 15, 2024)
7e99d50  mlas nbit print correct compType (liqunfu, Jan 8, 2025)
f36ec96  clean up a bit (liqunfu, Jan 8, 2025)
6d0404f  Merge branch 'main' into liqun/avx-layout (liqunfu, Jan 8, 2025)
5901b52  lint (liqunfu, Jan 9, 2025)
eba1908  remove unused __m512 load_1blksum_512(const float* BlksumPtr) (liqunfu, Jan 10, 2025)
e8484eb  Merge branch 'main' into liqun/avx-layout (liqunfu, Jan 10, 2025)
6dac6ad  sqnbitgemm_kernel_avx512.cpp to apply -mavx512f (liqunfu, Jan 10, 2025)
429054a  undo sqnbitgemm_kernel_avx512.cpp to apply -mavx512f (liqunfu, Jan 11, 2025)
b1d7474  restore avx512 blklen32 from use special layout because related code … (liqunfu, Jan 11, 2025)
5647598  Merge branch 'main' into liqun/avx-layout (liqunfu, Mar 17, 2025)
af15c91  merge main (liqunfu, Apr 2, 2025)
42fc7e3  const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tens… (liqunfu, Apr 3, 2025)
419822b  scales_are_packed_ set to ture in x64 (liqunfu, Apr 3, 2025)
534befe  use scales_are_packed_ (liqunfu, Apr 3, 2025)
e3f1b29  check scales against nullptr (liqunfu, Apr 3, 2025)

75 changes: 48 additions & 27 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -201,12 +201,13 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
auto sptr = tensor.Data<float>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr,
has_zp_input_, nullptr, nullptr);
is_packed = false;
is_packed = true;
scales_are_packed_ = true;
} else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr,
has_zp_input_, zptr, nullptr);
is_packed = false;
is_packed = true;
}
#elif defined(MLAS_TARGET_ARM64)
if (input_idx == InputIndex::scales && packed_b_ != nullptr &&
@@ -273,12 +274,12 @@ Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*ou
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(),
scales_fp32_.get(), has_zp_input_, nullptr, nullptr);
is_packed = false;
is_packed = true;
} else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(),
nullptr, has_zp_input_, zptr, nullptr);
is_packed = false;
is_packed = true;
}
#endif // MLAS_TARGET_AMD64_IX86
}
@@ -310,8 +311,6 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const {
const auto* a_data = a->Data<T1>();
const auto* scales_data = scales == nullptr ? nullptr : scales->Data<T1>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* bias_data = bias == nullptr ? nullptr : bias->Data<T1>();
auto* y_data = y->MutableData<T1>();

@@ -329,16 +328,23 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size, true);
}

bool bpacked_with_scale_zp = scales == nullptr;
#ifdef MLAS_TARGET_AMD64_IX86
if (compute_type_ == SQNBIT_CompInt8) {
bpacked_with_scale_zp = true;
}
#endif
InlinedVector<MLAS_QNBIT_GEMM_DATA_PARAMS<T1>> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
if (compute_type_ == SQNBIT_CompInt8) {
if (bpacked_with_scale_zp) {
data[i].QuantBDataWorkspace = packed_b_.get();
} else {
data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
data[i].QuantBScale = scales->Data<T1>();
data[i].QuantBZeroPoint = zero_points == nullptr ? nullptr : zero_points->DataRaw();
}
data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].Bias = bias_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
@@ -359,8 +365,6 @@ Status MatMulNBits<MLFloat16>::ComputeBPacked(const Tensor* a,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const {
const auto* a_data = a->Data<MLFloat16>();
const auto* scales_data = scales->Data<MLFloat16>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* bias_data = bias == nullptr ? nullptr : bias->Data<MLFloat16>();
auto* y_data = y->MutableData<MLFloat16>();

@@ -383,12 +387,15 @@ Status MatMulNBits<MLFloat16>::ComputeBPacked(const Tensor* a,
MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), a_size);

float* scales_ptr = nullptr;
if (!scales_fp32_) {
auto scales_temp = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(scales->Shape().Size()), true);
MlasConvertHalfToFloatBuffer(scales_data, scales_temp.get(), static_cast<size_t>(scales->Shape().Size()));
scales_ptr = scales_temp.get();
} else {
scales_ptr = scales_fp32_.get();
if (scales) {
const auto* scales_data = scales->Data<MLFloat16>();
if (!scales_fp32_) {
auto scales_temp = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(scales->Shape().Size()), true);
MlasConvertHalfToFloatBuffer(scales_data, scales_temp.get(), static_cast<size_t>(scales->Shape().Size()));
scales_ptr = scales_temp.get();
} else {
scales_ptr = scales_fp32_.get();
}
}

float* bias_ptr = nullptr;
@@ -405,18 +412,24 @@ Status MatMulNBits<MLFloat16>::ComputeBPacked(const Tensor* a,
size_t c_size = static_cast<size_t>(y->Shape().Size());
std::vector<float> c_v(c_size);

bool bpacked_with_scale_zp = scales == nullptr;
#ifdef MLAS_TARGET_AMD64_IX86
if (compute_type_ == SQNBIT_CompInt8) {
bpacked_with_scale_zp = true;
}
#endif
InlinedVector<MLAS_QNBIT_GEMM_DATA_PARAMS<float>> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i];
data[i].lda = lda;
#ifdef MLAS_TARGET_AMD64_IX86
if (compute_type_ == SQNBIT_CompInt8) {

if (bpacked_with_scale_zp) {
data[i].QuantBDataWorkspace = packed_b_.get();
} else {
data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
data[i].QuantBScale = scales_ptr;
data[i].QuantBZeroPoint = zero_points == nullptr ? nullptr : zero_points->DataRaw();
}
#endif
data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
data[i].QuantBScale = scales_ptr;
data[i].QuantBZeroPoint = zero_points_data;
data[i].Bias = bias ? bias_ptr : nullptr;
data[i].C = c_v.data() + helper.OutputOffsets()[i];
data[i].ldc = N;
@@ -674,8 +687,6 @@ template <typename T1>
Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const Tensor* a = ctx->Input<Tensor>(InputIndex::A);
const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
const Tensor* zero_points = ctx->Input<Tensor>(InputIndex::zero_points);
const Tensor* reorder_idx = ctx->Input<Tensor>(InputIndex::g_idx);
const Tensor* bias = ctx->Input<Tensor>(InputIndex::bias);

@@ -706,12 +717,22 @@ Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
// MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch()
// with B directly too.
if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
#ifdef MLAS_TARGET_AMD64_IX86
if (compute_type_ == SQNBIT_CompInt8) {
// scale and zp are prepacked; they have been removed from the context
return ComputeBPacked(a, nullptr, nullptr, bias, y, allocator, thread_pool, helper);
}
#endif
const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
const Tensor* zero_points = ctx->Input<Tensor>(InputIndex::zero_points);
return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper);
}
}

// If B is prepacked, B would have been removed from the context
// If B, scale, and zp are prepacked, they would have been removed from the context
const Tensor* b = ctx->Input<Tensor>(InputIndex::B);
const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
const Tensor* zero_points = ctx->Input<Tensor>(InputIndex::zero_points);
return ComputeBUnpacked(a, b, scales, zero_points, reorder_idx, bias, y, allocator, thread_pool, helper);
}

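Taken together, the matmul_nbits.cc changes fold the scales and zero points into the prepacked B buffer on x64 when compute_type_ is SQNBIT_CompInt8, so PrePack now reports is_packed = true for those inputs and Compute passes nullptr instead of fetching them from the kernel context. The sketch below is a simplified, hypothetical rendering of that dispatch decision; the struct and function names are stand-ins, not the real onnxruntime types.

```cpp
#include <cstdio>

// Hypothetical stand-ins for the operator state touched by this diff; not the real ORT classes.
struct NBitsState {
    bool x64_int8_path;       // MLAS_TARGET_AMD64_IX86 build && compute_type_ == SQNBIT_CompInt8
    bool scales_are_packed;   // set by PrePack after scales were folded into packed_b_
};

// Mirrors the new "bpacked_with_scale_zp" logic: rely entirely on the packed B buffer
// whenever the scale/zp tensors were consumed at prepack time (or are simply absent).
bool UsePackedScaleZp(const NBitsState& s, bool scales_input_present) {
    if (s.x64_int8_path) {
        return true;  // scale and zp already baked into the packed B workspace
    }
    return !scales_input_present;
}

int main() {
    NBitsState s{true, true};
    std::printf("use packed scale/zp: %d\n", UsePackedScaleZp(s, /*scales_input_present=*/false));
    return 0;
}
```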
60 changes: 27 additions & 33 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
@@ -568,19 +568,14 @@ SQ4BitGemm_CompInt8(
const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen));
const size_t ldc = DataParams->ldc;
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);

const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda;
const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks;
const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks;

assert(RangeStartN % 4 == 0);
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->PackedQuantBData) + RangeStartN * ldb;
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
const std::byte* QuantBZeroPoint =
(DataParams->QuantBZeroPoint == nullptr)
? nullptr
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks;
const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;

@@ -608,42 +603,17 @@

const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
#endif

size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});

const std::byte* a_row = QuantA;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) {
size_t RowsRemaining = RangeCountM;
while (RowsRemaining > 0) {
const auto RowsHandled = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
RowsHandled, CountN, ldc
);
}

c_blk += RowsHandled * ldc;
a_row += RowsHandled * lda;

RowsRemaining -= RowsHandled;
}
}
#ifdef MLAS_TARGET_AMD64_IX86
else if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr)
if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr)
{
const float* b_blk_sum = QuantBBlkSum + n * k_blks;
GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8(
@@ -652,7 +622,6 @@ SQ4BitGemm_CompInt8(
QuantAScale,
b_col,
b_col_scale,
b_col_zp,
c_blk,
RangeCountM,
CountN,
@@ -671,6 +640,31 @@
);
}
}
#else
const std::byte* a_row = QuantA;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) {
size_t RowsRemaining = RangeCountM;
while (RowsRemaining > 0) {
const auto RowsHandled = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
RowsHandled, CountN, ldc
);
}

c_blk += RowsHandled * ldc;
a_row += RowsHandled * lda;

RowsRemaining -= RowsHandled;
}
}
#endif
}
}
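The qnbitgemm.cpp change makes the x64 CompInt8 path go through SQ4BitGemmKernel_BlkSum_CompInt8 exclusively and drops the zero-point pointer from that call. The reason the zero point can disappear is the usual block-sum identity, sketched below in generic notation; the exact sign and storage conventions inside MLAS may differ.

```latex
% A is dynamically quantized per block as a_i \approx s_A\, qa_i, and
% B is dequantized per block as b_i = s_B\,(qb_i - z_B).  For one block:
\sum_i (s_A\, qa_i)\, s_B\,(qb_i - z_B)
  \;=\; s_A s_B \sum_i qa_i\, qb_i \;-\; (s_B z_B)\Bigl(s_A \sum_i qa_i\Bigr).
```

The first term is the raw int8 dot product scaled by the two block scales; the correction term needs only the per-block sum of the quantized A values (the ABlockSum workspace) and the scaled zero point of B (the "scaled_zp" the commits compute at pack time and store alongside the packed data as QuantBBlkSum), so the packed zero-point bytes never have to reach the kernel.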
17 changes: 1 addition & 16 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -38,14 +38,6 @@ MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen)
return BlkLen * BlkBitWidth / 8;
}

MLAS_FORCEINLINE void*
MlasAlignAddress(void* addr, const size_t alignment)
{
const uintptr_t QuantBBlkSumAddr = reinterpret_cast<uintptr_t>(addr);
addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1)));
return addr;
}

template <typename T>
struct PackedQuantBDataStruct {
PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen)
@@ -54,15 +46,9 @@ struct PackedQuantBDataStruct {
// TODO: duplicate code from Q4BitGemmPackQuantBDataSize
constexpr size_t BlkBitWidth = 4;
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T);
#if defined(MLAS_TARGET_AMD64_IX86)
// _mm256_load_si256 requires alignment on a 32-byte boundary
PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32);
#else
size_t BlkSumSize = N * BlockCountK * sizeof(T);
PackedQuantBData = (std::byte*)PackedQuantBWorkspace;
#endif
QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize);
QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment());
PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize);
}
std::byte* PackedQuantBData;
@@ -331,7 +317,6 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
const float* QuantAScale,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountM,
size_t CountN,
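With the alignment helper gone, PackedQuantBDataStruct in qnbitgemm.h simply lays the packed data, the block sums, and the scales out back to back, and the block-sum region shrinks from a 16-column-padded size to N * BlockCountK elements. The snippet below is a rough, illustrative computation of the resulting offsets for the 4-bit case; the sizes are example values, not taken from the PR.

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative layout of the packed B workspace after this change:
//   [ PackedQuantBData | QuantBBlkSum | PackedQuantBScale ]
// BlkBitWidth = 4, so each block of BlkLen weights takes BlkLen / 2 bytes.
int main() {
    const std::size_t N = 4096, K = 4096, BlkLen = 32;     // example shapes
    const std::size_t BlockCountK = (K + BlkLen - 1) / BlkLen;

    const std::size_t packed_data_bytes = N * BlockCountK * (BlkLen * 4 / 8);
    const std::size_t blk_sum_bytes     = N * BlockCountK * sizeof(float);  // one block-sum / scaled-zp value per block
    const std::size_t scale_bytes       = N * BlockCountK * sizeof(float);  // one scale per block

    std::printf("data @ 0, blksum @ %zu, scales @ %zu, total %zu bytes\n",
                packed_data_bytes,
                packed_data_bytes + blk_sum_bytes,
                packed_data_bytes + blk_sum_bytes + scale_bytes);
    return 0;
}
```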