@@ -118,13 +118,23 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data,
 
   DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, temp_float, nullptr);
 
-  size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(cols), static_cast<size_t>(rows));
+  // Transpose from N x K (weights) to K x N.
+  // DirectQ4Gemm expects weights to be packed in a specific layout ([K, N] logically).
+  auto transposed_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(rows * cols));
+  float* transposed_float = transposed_float_buffer.get();
+  for (int64_t r = 0; r < rows; ++r) {
+    for (int64_t c = 0; c < cols; ++c) {
+      transposed_float[c * rows + r] = temp_float[r * cols + c];
+    }
+  }
+
+  size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(rows), static_cast<size_t>(cols));
   if (packed_size == 0) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration");
   }
 
   mlas_packed_buffer = IAllocator::MakeUniquePtr<uint8_t>(allocator, packed_size);
-  MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast<size_t>(cols), static_cast<size_t>(rows), static_cast<size_t>(cols));
+  MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), transposed_float, static_cast<size_t>(rows), static_cast<size_t>(cols), static_cast<size_t>(rows));
 
   return Status::OK();
 }
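For reference, MlasQ4GemmPackB reads its float input as a K x N row-major matrix (hence the transpose above), and the resulting blob is consumed through MLAS's Q4 GEMM entry point. The sketch below is a minimal illustration of that consuming call, assuming the MlasQ4GemmBatch / MLAS_Q4_GEMM_DATA_PARAMS declarations from onnxruntime's core/mlas/inc/mlas_q4.h; the PR's actual DirectQ4Gemm helper is not shown in this hunk, and the wrapper function name here is hypothetical.

// Sketch only: how a buffer packed by MlasQ4GemmPackB above would typically be consumed.
// RunQ4Gemm is a hypothetical wrapper, not the PR's DirectQ4Gemm helper.
#include "core/mlas/inc/mlas_q4.h"

void RunQ4Gemm(MLAS_BLK_QUANT_TYPE qtype,
               const float* A,        // activations, [M, K] row-major
               const void* packed_b,  // output of MlasQ4GemmPackB (logically [K, N])
               const float* bias,     // optional, length N; may be nullptr
               float* C,              // output, [M, N] row-major
               size_t M, size_t N, size_t K,
               MLAS_THREADPOOL* thread_pool) {
  MLAS_Q4_GEMM_DATA_PARAMS params{};
  params.A = A;
  params.lda = K;       // contiguous rows of A
  params.B = packed_b;  // pre-packed int4 weight blob
  params.Bias = bias;   // folded into the GEMM epilogue when non-null
  params.C = C;
  params.ldc = N;
  MlasQ4GemmBatch(qtype, M, N, K, /*BatchN=*/1, &params, thread_pool);
}

Because the bias is applied inside the MLAS epilogue when this path is taken, the kernel tracks that fact (fc2_bias_added_by_mlas below) so the float fallback epilogue does not add it a second time.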
@@ -634,6 +644,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
       float* thread_bias2_buffer = thread_bias1_buffer + static_cast<size_t>(fc1_out_features);
 
       for (int64_t expert_idx : expert_batch) {
+        bool fc2_bias_added_by_mlas = false;
         const auto& routes = expert_token_map[static_cast<size_t>(expert_idx)];
         if (routes.empty()) {
           continue;
@@ -711,8 +722,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
         bool use_direct_q4_gemm = (fc1_zp_data == nullptr) &&
                                   CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0,
                                                    fc1_out_features, hidden_size, q_type);
-        bool fc1_used_direct_q4 = false;
-        bool fc1_bias_handled_by_q4_gemm = false;
 
         if (use_direct_q4_gemm) {
           IAllocatorUniquePtr<uint8_t> mlas_packed_fc1;
@@ -750,7 +759,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
               num_expert_tokens, fc1_out_features, hidden_size, q_type, tp);
 
           if (gemm_status.IsOK()) {
-            fc1_used_direct_q4 = true;
            goto fc1_gemm_done;
          }
        }
@@ -797,8 +805,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
            0.0f, C1, n,
            tp);
 
-        fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias;
-        if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) {
+        if (has_fc1_bias) {
          const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features;
          if constexpr (std::is_same_v<T, MLFloat16>) {
            MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B1_bias), thread_bias1_buffer, static_cast<size_t>(fc1_out_features));
@@ -891,7 +898,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
        bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) &&
                                      CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0,
                                                       hidden_size, inter_size, q_type2);
-        bool fc2_used_direct_q4 = false;
 
        if (use_direct_q4_gemm_fc2) {
          IAllocatorUniquePtr<uint8_t> mlas_packed_fc2;
@@ -929,7 +935,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
              num_expert_tokens, hidden_size, inter_size, q_type2, tp);
 
          if (gemm_status.IsOK()) {
-            fc2_used_direct_q4 = true;
+            fc2_bias_added_by_mlas = true;
            goto fc2_gemm_done;
          }
        }
@@ -979,8 +985,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
 
      fc2_gemm_done:
 
-        bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias;
-        if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
+        if (has_fc2_bias && !fc2_bias_added_by_mlas) {
          const T* B2_bias = fc2_bias_data + expert_idx * hidden_size;
          if constexpr (std::is_same_v<T, MLFloat16>) {
            MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B2_bias), thread_bias2_buffer, static_cast<size_t>(hidden_size));
@@ -1015,7 +1020,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
          float* dest = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
          const float* src = C2 + i * hidden_size;
 
-          if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
+          if (has_fc2_bias && !fc2_bias_added_by_mlas) {
            const size_t unroll_factor = narrow<size_t>(GetUnrollFactor(hidden_size));
            size_t j = 0;
            for (; j + unroll_factor <= narrow<size_t>(hidden_size); j += unroll_factor) {