29 changes: 17 additions & 12 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -118,13 +118,23 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data,

DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, temp_float, nullptr);

- size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(cols), static_cast<size_t>(rows));
+ // Transpose the weights from [N, K] to [K, N]:
+ // DirectQ4Gemm expects B packed from a [K, N] (row-major) float matrix.
+ auto transposed_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(rows * cols));
+ float* transposed_float = transposed_float_buffer.get();
+ for (int64_t r = 0; r < rows; ++r) {
+   for (int64_t c = 0; c < cols; ++c) {
+     transposed_float[c * rows + r] = temp_float[r * cols + c];
+   }
+ }
+
+ size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(rows), static_cast<size_t>(cols));
if (packed_size == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration");
}

mlas_packed_buffer = IAllocator::MakeUniquePtr<uint8_t>(allocator, packed_size);
- MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast<size_t>(cols), static_cast<size_t>(rows), static_cast<size_t>(cols));
+ MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), transposed_float, static_cast<size_t>(rows), static_cast<size_t>(cols), static_cast<size_t>(rows));

return Status::OK();
}
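
The conversion above boils down to a plain row-major transpose of the dequantized weights from [N, K] to [K, N] before packing. A minimal standalone sketch of that step, using std::vector buffers instead of the kernel's IAllocator (illustrative only; TransposeNxKToKxN is a made-up helper name):

#include <cstdint>
#include <vector>

// Hypothetical helper, not part of the kernel.
// Row-major transpose: input has shape [N, K] (leading dimension K),
// output has shape [K, N] (leading dimension N), matching what the packer expects.
std::vector<float> TransposeNxKToKxN(const std::vector<float>& nk, int64_t rows /*N*/, int64_t cols /*K*/) {
  std::vector<float> kn(static_cast<size_t>(rows * cols));
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      // Element (r, c) of the [N, K] matrix lands at (c, r) in the [K, N] matrix.
      kn[static_cast<size_t>(c * rows + r)] = nk[static_cast<size_t>(r * cols + c)];
    }
  }
  return kn;
}
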
@@ -634,6 +644,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
float* thread_bias2_buffer = thread_bias1_buffer + static_cast<size_t>(fc1_out_features);

for (int64_t expert_idx : expert_batch) {
+ bool fc2_bias_added_by_mlas = false;
const auto& routes = expert_token_map[static_cast<size_t>(expert_idx)];
if (routes.empty()) {
continue;
@@ -711,8 +722,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
bool use_direct_q4_gemm = (fc1_zp_data == nullptr) &&
CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0,
fc1_out_features, hidden_size, q_type);
- bool fc1_used_direct_q4 = false;
- bool fc1_bias_handled_by_q4_gemm = false;

if (use_direct_q4_gemm) {
IAllocatorUniquePtr<uint8_t> mlas_packed_fc1;
@@ -750,7 +759,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
num_expert_tokens, fc1_out_features, hidden_size, q_type, tp);

if (gemm_status.IsOK()) {
- fc1_used_direct_q4 = true;
goto fc1_gemm_done;
}
}
@@ -797,8 +805,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
0.0f, C1, n,
tp);

- fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias;
- if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) {
+ if (has_fc1_bias) {
const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features;
if constexpr (std::is_same_v<T, MLFloat16>) {
MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B1_bias), thread_bias1_buffer, static_cast<size_t>(fc1_out_features));
@@ -891,7 +898,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) &&
CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0,
hidden_size, inter_size, q_type2);
- bool fc2_used_direct_q4 = false;

if (use_direct_q4_gemm_fc2) {
IAllocatorUniquePtr<uint8_t> mlas_packed_fc2;
@@ -929,7 +935,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
num_expert_tokens, hidden_size, inter_size, q_type2, tp);

if (gemm_status.IsOK()) {
- fc2_used_direct_q4 = true;
+ fc2_bias_added_by_mlas = true;
goto fc2_gemm_done;
}
}
@@ -979,8 +985,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {

fc2_gemm_done:

- bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias;
- if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
+ if (has_fc2_bias && !fc2_bias_added_by_mlas) {
const T* B2_bias = fc2_bias_data + expert_idx * hidden_size;
if constexpr (std::is_same_v<T, MLFloat16>) {
MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B2_bias), thread_bias2_buffer, static_cast<size_t>(hidden_size));
@@ -1015,7 +1020,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
float* dest = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
const float* src = C2 + i * hidden_size;

- if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
+ if (has_fc2_bias && !fc2_bias_added_by_mlas) {
const size_t unroll_factor = narrow<size_t>(GetUnrollFactor(hidden_size));
size_t j = 0;
for (; j + unroll_factor <= narrow<size_t>(hidden_size); j += unroll_factor) {
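
The bias handling above reduces to one flag per expert iteration: fc2_bias_added_by_mlas is set only when the direct MLAS Q4 GEMM path succeeds (and has therefore already applied the FC2 bias), and the element-wise bias add later runs only when the flag is still false. A minimal sketch of that control flow, with made-up helper names standing in for the MLAS calls and per-thread buffers:

#include <iostream>

// Hypothetical stand-ins for the fast (bias-fusing) GEMM path and the fallback path.
static bool TryFastGemmWithFusedBias() { return true; }
static void FallbackGemm() {}
static void AddBiasRowwise() { std::cout << "bias added by fallback path\n"; }

int main() {
  const bool has_bias = true;
  // The flag is declared before the goto so the jump does not cross its initialization.
  bool bias_added_by_fast_path = false;

  if (TryFastGemmWithFusedBias()) {
    bias_added_by_fast_path = true;  // the fast path already folded the bias into the GEMM
    goto gemm_done;
  }
  FallbackGemm();

gemm_done:
  if (has_bias && !bias_added_by_fast_path) {
    AddBiasRowwise();  // only needed when the fast path did not run
  }
  return 0;
}
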
8 changes: 4 additions & 4 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -57,10 +57,10 @@ MlasQ4GemmPackBSize(
*
* @param QType type of block quantization
* @param PackedBuf destination buffer
- * @param FpData the pointer to fp32 matrix
- * @param N the number of columns of matrix B.
- * @param K the number of rows of matrix B.
- * @param ldb leading dimension of B
+ * @param FpData pointer to the fp32 matrix, with shape [K, N].
+ * @param N the number of columns of matrix B (output channels).
+ * @param K the number of rows of matrix B (input channels).
+ * @param ldb leading dimension of FpData (usually N).
*/
void
MLASCALL
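
Given the documented layout, the typical size-then-pack sequence looks like the sketch below. This is illustrative only: the include path and the BlkQ4Sym quantization type are assumptions, and the PackB helper name is made up.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "mlas_q4.h"  // assumed include path for the MLAS Q4 declarations

// Pack a row-major [K, N] float matrix B for Q4 GEMM; returns an empty buffer
// if the configuration is not supported (MlasQ4GemmPackBSize returns 0).
std::vector<uint8_t> PackB(const float* b_kxn, size_t N, size_t K) {
  const size_t packed_size = MlasQ4GemmPackBSize(BlkQ4Sym, N, K);
  if (packed_size == 0) {
    return {};
  }
  std::vector<uint8_t> packed(packed_size);
  // FpData has shape [K, N], so the leading dimension ldb is N.
  MlasQ4GemmPackB(BlkQ4Sym, packed.data(), b_kxn, N, K, /*ldb=*/N);
  return packed;
}
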