Skip to content

Commit ed0478a

Browse files
committed
Fix QMoE CPU
1 parent 0df5dbc commit ed0478a

File tree

3 files changed

+126
-94
lines changed

3 files changed

+126
-94
lines changed

onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,23 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data,
118118

119119
DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, temp_float, nullptr);
120120

121-
size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(cols), static_cast<size_t>(rows));
121+
// Transpose from N x K (weights) to K x N.
122+
// DirectQ4Gemm expects weights to be packed in a specific layout ([K, N] logically)
123+
auto transposed_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(rows * cols));
124+
float* transposed_float = transposed_float_buffer.get();
125+
for (int64_t r = 0; r < rows; ++r) {
126+
for (int64_t c = 0; c < cols; ++c) {
127+
transposed_float[c * rows + r] = temp_float[r * cols + c];
128+
}
129+
}
130+
131+
size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(rows), static_cast<size_t>(cols));
122132
if (packed_size == 0) {
123133
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration");
124134
}
125135

126136
mlas_packed_buffer = IAllocator::MakeUniquePtr<uint8_t>(allocator, packed_size);
127-
MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast<size_t>(cols), static_cast<size_t>(rows), static_cast<size_t>(cols));
137+
MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), transposed_float, static_cast<size_t>(rows), static_cast<size_t>(cols), static_cast<size_t>(rows));
128138

129139
return Status::OK();
130140
}
@@ -634,6 +644,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
634644
float* thread_bias2_buffer = thread_bias1_buffer + static_cast<size_t>(fc1_out_features);
635645

636646
for (int64_t expert_idx : expert_batch) {
647+
bool fc2_bias_added_by_mlas = false;
637648
const auto& routes = expert_token_map[static_cast<size_t>(expert_idx)];
638649
if (routes.empty()) {
639650
continue;
@@ -711,8 +722,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
711722
bool use_direct_q4_gemm = (fc1_zp_data == nullptr) &&
712723
CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0,
713724
fc1_out_features, hidden_size, q_type);
714-
bool fc1_used_direct_q4 = false;
715-
bool fc1_bias_handled_by_q4_gemm = false;
716725

717726
if (use_direct_q4_gemm) {
718727
IAllocatorUniquePtr<uint8_t> mlas_packed_fc1;
@@ -750,7 +759,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
750759
num_expert_tokens, fc1_out_features, hidden_size, q_type, tp);
751760

752761
if (gemm_status.IsOK()) {
753-
fc1_used_direct_q4 = true;
754762
goto fc1_gemm_done;
755763
}
756764
}
@@ -797,8 +805,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
797805
0.0f, C1, n,
798806
tp);
799807

800-
fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias;
801-
if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) {
808+
if (has_fc1_bias) {
802809
const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features;
803810
if constexpr (std::is_same_v<T, MLFloat16>) {
804811
MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B1_bias), thread_bias1_buffer, static_cast<size_t>(fc1_out_features));
@@ -891,7 +898,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
891898
bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) &&
892899
CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0,
893900
hidden_size, inter_size, q_type2);
894-
bool fc2_used_direct_q4 = false;
895901

896902
if (use_direct_q4_gemm_fc2) {
897903
IAllocatorUniquePtr<uint8_t> mlas_packed_fc2;
@@ -929,7 +935,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
929935
num_expert_tokens, hidden_size, inter_size, q_type2, tp);
930936

931937
if (gemm_status.IsOK()) {
932-
fc2_used_direct_q4 = true;
938+
fc2_bias_added_by_mlas = true;
933939
goto fc2_gemm_done;
934940
}
935941
}
@@ -979,8 +985,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
979985

980986
fc2_gemm_done:
981987

982-
bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias;
983-
if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
988+
if (has_fc2_bias && !fc2_bias_added_by_mlas) {
984989
const T* B2_bias = fc2_bias_data + expert_idx * hidden_size;
985990
if constexpr (std::is_same_v<T, MLFloat16>) {
986991
MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B2_bias), thread_bias2_buffer, static_cast<size_t>(hidden_size));
@@ -1015,7 +1020,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
10151020
float* dest = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
10161021
const float* src = C2 + i * hidden_size;
10171022

1018-
if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
1023+
if (has_fc2_bias && !fc2_bias_added_by_mlas) {
10191024
const size_t unroll_factor = narrow<size_t>(GetUnrollFactor(hidden_size));
10201025
size_t j = 0;
10211026
for (; j + unroll_factor <= narrow<size_t>(hidden_size); j += unroll_factor) {

onnxruntime/core/mlas/inc/mlas_q4.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ MlasQ4GemmPackBSize(
5757
*
5858
* @param QType type of block quantization
5959
* @param PackedBuf destination buffer
60-
* @param FpData the pointer to fp32 matrix
61-
* @param N the number of columns of matrix B.
62-
* @param K the number of rows of matrix B.
63-
* @param ldb leading dimension of B
60+
* @param FpData the pointer to fp32 matrix, with shape [K, N].
61+
* @param N the number of columns of matrix B (Output Channels).
62+
* @param K the number of rows of matrix B (Input Channels).
63+
* @param ldb leading dimension of FpData (usually N)
6464
*/
6565
void
6666
MLASCALL

0 commit comments

Comments
 (0)