diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index 257c5a189b3bd..f9122406be633 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -35,44 +35,108 @@ struct MoEParameters { }; namespace moe_helper { +// Helper to check shape dimensions +#define ASSERT_SHAPE_DIMENSION(shape_ptr, dim, name) \ + if (shape_ptr != nullptr) { \ + if (shape_ptr->NumDimensions() != dim) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have ", dim, " dimensions, got ", \ + shape_ptr->NumDimensions()); \ + } \ + } + +#define ASSERT_SHAPE_3D(shape_ptr, name) ASSERT_SHAPE_DIMENSION(shape_ptr, 3, name) + +#define CHECK_SHAPE(shape_ptr, name, ...) \ + if (shape_ptr != nullptr) { \ + const TensorShape& expected_shape = make_shape(__VA_ARGS__); \ + if (*shape_ptr != expected_shape) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have shape ", expected_shape, \ + ", got ", *shape_ptr); \ + } \ + } + template Status CheckInputs(MoEParameters& parameters, - const Tensor* input, // required - const Tensor* router_probs, // required - const Tensor* fc1_experts_weights, // required - const Tensor* fc1_experts_bias, // optional - const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc1_zero_points, // optional, for qMoE - const Tensor* fc2_experts_weights, // required - const Tensor* fc2_experts_bias, // optional - const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc2_zero_points, // optional, for qMoE - const Tensor* fc3_experts_weights, // optional - const Tensor* fc3_experts_bias, // optional - const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc3_zero_points, // optional, for qMoE - const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const Tensor* input, // required + const Tensor* router_probs, // required + const TensorShape* fc1_experts_weights_shape, + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const TensorShape* fc2_experts_weights_shape, + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const TensorShape* fc3_experts_weights_shape, + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) const bool is_fused_swiglu, const int64_t block_size = 0) { // block size for block-wise quantization // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. 
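+  // Note: shape validation below works on TensorShape pointers rather than Tensor pointers so that
+  // it still applies when a weight initializer has been pre-packed and removed from the graph; in
+  // that case the kernel passes the shape it captured during PrePack instead of the tensor itself.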
+ if (input == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required."); + } ASSERT_TENSOR_2D_OR_3D(input); - ASSERT_TENSOR_3D(fc1_experts_weights); - ASSERT_TENSOR_3D(fc2_experts_weights); + + if (router_probs == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'router_probs' is required."); + } ASSERT_TENSOR_2D(router_probs); + if (fc1_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc1_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc1_experts_weights_shape, "fc1_experts_weights"); + + if (fc2_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc2_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc2_experts_weights_shape, "fc2_experts_weights"); + const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; int64_t hidden_size = input_dims[input_dims.size() - 1]; - int64_t local_num_experts = fc1_experts_weights_dims[0]; int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size; - const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || - (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + int64_t local_num_experts; + if (fc1_experts_weights_shape != nullptr) { + local_num_experts = fc1_experts_weights_shape->GetDims()[0]; + } else if (fc1_experts_scales != nullptr) { + local_num_experts = fc1_experts_scales->Shape().GetDims()[0]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid MoE configuration: both fc1_experts_weights and fc1_experts_scales are null. " + "At least one must be provided."); + } + + int64_t inter_size; + if (fc2_experts_weights_shape != nullptr) { + const auto& dims = fc2_experts_weights_shape->GetDims(); + inter_size = (dims[1] * dims[2] * pack_size) / hidden_size; + } else if (fc3_experts_scales != nullptr) { + inter_size = fc3_experts_scales->Shape().GetDims()[1]; + } else if (fc1_experts_scales != nullptr) { + int64_t fc1_inter_size = fc1_experts_scales->Shape().GetDims()[1]; + inter_size = is_fused_swiglu ? fc1_inter_size / 2 : fc1_inter_size; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid MoE configuration: unable to infer inter_size because " + "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null."); + } + + bool legacy_shape = false; + if (fc2_experts_weights_shape != nullptr && fc1_experts_weights_shape != nullptr) { + const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims(); + legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || + (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + } // Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one. const int64_t fc1_inter_size = is_fused_swiglu ? 
(inter_size + inter_size) : inter_size; @@ -80,13 +144,13 @@ Status CheckInputs(MoEParameters& parameters, if (legacy_shape) { // legacy shape does not match column major memory layout. This is for backward compatibility. - CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, hidden_size, fc1_inter_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, hidden_size, inter_size / pack_size); } else { - CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, fc1_inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, inter_size, hidden_size / pack_size); } CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts); @@ -168,9 +232,11 @@ Status CheckInputs(MoEParameters& parameters, } } - if (fc3_experts_weights == nullptr) { + if (fc3_experts_weights_shape == nullptr) { + // If fc3 weights are not provided, ensure no other fc3 parameters are provided ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr && fc3_zero_points == nullptr); } else { + // If fc3 weights are provided, ensure scales logic is consistent ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales } @@ -200,6 +266,36 @@ Status CheckInputs(MoEParameters& parameters, return Status::OK(); } +template +Status CheckInputs(MoEParameters& parameters, + const Tensor* input, // required + const Tensor* router_probs, // required + const Tensor* fc1_experts_weights, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const Tensor* fc2_experts_weights, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const Tensor* fc3_experts_weights, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const bool is_fused_swiglu, + const int64_t block_size = 0) { // block size for block-wise quantization + + const TensorShape* fc1_shape = (fc1_experts_weights != nullptr) ? &fc1_experts_weights->Shape() : nullptr; + const TensorShape* fc2_shape = (fc2_experts_weights != nullptr) ? &fc2_experts_weights->Shape() : nullptr; + const TensorShape* fc3_shape = (fc3_experts_weights != nullptr) ? 
&fc3_experts_weights->Shape() : nullptr; + + return CheckInputs(parameters, input, router_probs, fc1_shape, fc1_experts_bias, fc1_experts_scales, fc1_zero_points, + fc2_shape, fc2_experts_bias, fc2_experts_scales, fc2_zero_points, + fc3_shape, fc3_experts_bias, fc3_experts_scales, fc3_zero_points, + pack_size, is_fused_swiglu, block_size); +} + } // namespace moe_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 6d1d191689466..483e5184f63ac 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -13,6 +13,7 @@ #include "core/common/narrow.h" #include "core/framework/tensor_type_and_shape.h" #include "core/util/math.h" +#include "core/platform/env_var_utils.h" #include "contrib_ops/cpu/moe/moe_utils.h" #include "contrib_ops/cpu/moe/moe_helper.h" @@ -69,7 +70,7 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, out_qtype = BlkQ4Sym64; } else if (block_size == 128) { out_qtype = BlkQ4Sym128; - } else if (block_size == 0) { + } else if (block_size == 0 || block_size == 32) { out_qtype = BlkQ4Sym; } else { return false; @@ -84,6 +85,8 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, namespace onnxruntime { namespace contrib { +constexpr const char* kUseMlasQ4GemmMoe = "ORT_USE_MLAS_Q4_GEMM_MOE"; + template void DequantizeBlockWithMlas(const uint8_t* quantized_data, const TScale* scales, @@ -118,13 +121,23 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, temp_float, nullptr); - size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); + // Transpose from N x K (weights) to K x N. 
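+  // MlasQ4GemmPackB consumes FpData as a [K, N] row-major matrix with ldb = N, while the
+  // dequantized expert weight above is laid out as [N, K], hence the explicit transpose below.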
+ // DirectQ4Gemm expects weights to be packed in a specific layout ([K, N] logically) + auto transposed_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + float* transposed_float = transposed_float_buffer.get(); + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + transposed_float[c * rows + r] = temp_float[r * cols + c]; + } + } + + size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); if (packed_size == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); } mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), transposed_float, static_cast(rows), static_cast(cols), static_cast(rows)); return Status::OK(); } @@ -354,6 +367,273 @@ void DequantizeBlock(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, dequantized_data, thread_pool); } +template +void DequantizePrePacked(const uint8_t* prepacked_data, + const TScale* scales, + const uint8_t* zero_points, + int64_t block_size, + int64_t rows, + int64_t cols, + float* dequantized_data, + const gsl::span& scale_dims) { + // prepacked_data is [cols, rows] (transposed, unpacked) + // dequantized_data is [cols, rows] (transposed) + // scales, zero_points correspond to original [rows, cols] layout + + const float default_zp_4bit = 8.0f; + const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1; + const int64_t zp_pack_size = 2; // Always 2 for 4-bit + + // Iterate over Columns (K) then Rows (N) because prepacked_data is [K, N] + for (int64_t c = 0; c < cols; ++c) { + for (int64_t r = 0; r < rows; ++r) { + uint8_t val = prepacked_data[c * rows + r]; + + int64_t block_idx = (block_size > 0) ? (c / block_size) : 0; + if (block_size > 0) block_idx = std::min(block_idx, blocks_per_row - 1); + + int64_t scale_idx; + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + scale_idx = r * blocks_per_row + block_idx; + } else { // per-channel + scale_idx = r; + } + + float scale = static_cast(scales[scale_idx]); + float zp = default_zp_4bit; + + if (zero_points != nullptr) { + int64_t zp_idx; + bool is_lower_nibble; + + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + int64_t zp_blocks_packed = (blocks_per_row + zp_pack_size - 1) / zp_pack_size; + zp_idx = r * zp_blocks_packed + block_idx / 2; + is_lower_nibble = (block_idx % 2 == 0); + } else { + zp_idx = r / 2; + is_lower_nibble = (r % 2 == 0); + } + + uint8_t packed_zp = zero_points[zp_idx]; + zp = is_lower_nibble ? 
static_cast(packed_zp & 0x0F) : static_cast(packed_zp >> 4); + } + + dequantized_data[c * rows + r] = scale * (static_cast(val) - zp); + } + } +} + +template +Status BuildDirectQ4PackedBCache(const uint8_t* prepacked_weights, + const TScale* scales_data, + int64_t num_experts, + int64_t rows, + int64_t cols, + int64_t block_size, + const gsl::span& scales_dims, + MLAS_BLK_QUANT_TYPE qtype, + AllocatorPtr allocator, + IAllocatorUniquePtr& packed_b) { + const size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); + if (packed_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to compute MLAS Q4 packed size for cache"); + } + + const bool is_block_wise = (scales_dims.size() == 3 && scales_dims[2] > 1); + const int64_t scales_expert_stride = is_block_wise ? (rows * scales_dims[2]) : rows; + const size_t prepacked_expert_stride = static_cast(rows * cols); + const size_t total_packed_size = packed_size * static_cast(num_experts); + + packed_b = IAllocator::MakeUniquePtr(allocator, total_packed_size, true); + uint8_t* packed_b_ptr = static_cast(packed_b.get()); + + std::vector dequantized_transposed(static_cast(rows * cols)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const uint8_t* expert_prepacked = prepacked_weights + static_cast(expert_idx) * prepacked_expert_stride; + const TScale* expert_scales = scales_data + expert_idx * scales_expert_stride; + + DequantizePrePacked(expert_prepacked, expert_scales, nullptr, block_size, rows, cols, + dequantized_transposed.data(), scales_dims); + + MlasQ4GemmPackB(qtype, packed_b_ptr + expert_idx * packed_size, dequantized_transposed.data(), + static_cast(rows), static_cast(cols), static_cast(rows)); + } + + return Status::OK(); +} + +template +Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + + // If scales are prepacked, they are constant initializers. + if (input_idx == 3) { + return Status::OK(); + } + if (input_idx == 6) { + return Status::OK(); + } + + // Only support PrePack for FC1 (2), FC2 (5), and FC3 (8) weights + // and only if expert_weight_bits_ == 4 (since we unpack to uint8) + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if (input_idx == 2 || input_idx == 5 || input_idx == 8) { + const auto& shape = tensor.Shape(); + const int64_t num_experts = shape[0]; + const int64_t rows = shape[1]; + const int64_t cols_packed = shape[2]; + const int64_t cols = cols_packed * 2; + + size_t packed_size = static_cast(num_experts * rows * cols); + auto packed_buffer = IAllocator::MakeUniquePtr(alloc, packed_size, true); + uint8_t* dst_base = static_cast(packed_buffer.get()); + const uint8_t* src_base = static_cast(tensor.DataRaw()); + + for (int64_t i = 0; i < num_experts; ++i) { + const uint8_t* src = src_base + i * rows * cols_packed; + uint8_t* dst = dst_base + i * rows * cols; + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + uint8_t packed_val = src[r * cols_packed + (c / 2)]; + uint8_t val = (c % 2 == 0) ? 
(packed_val & 0x0F) : (packed_val >> 4); + + dst[c * rows + r] = val; + } + } + } + + if (input_idx == 2) { + fc1_shape_ = shape; + } else if (input_idx == 5) { + fc2_shape_ = shape; + } else if (input_idx == 8) { + fc3_shape_ = shape; + } + + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_buffer)); + prepacked_weights->buffer_sizes_.push_back(packed_size); + is_packed = true; + + // Pack Shape (Buffer 1) + auto dims = shape.GetDims(); + size_t rank_bytes = sizeof(int64_t); + size_t dims_bytes = dims.size() * sizeof(int64_t); + size_t shape_size = rank_bytes + dims_bytes; + + auto shape_buffer = IAllocator::MakeUniquePtr(alloc, shape_size); + int64_t* buffer_data = static_cast(shape_buffer.get()); + *buffer_data = static_cast(dims.size()); + memcpy(buffer_data + 1, dims.data(), dims_bytes); + + prepacked_weights->buffers_.push_back(std::move(shape_buffer)); + prepacked_weights->buffer_sizes_.push_back(shape_size); + + // Try build MLAS Q4 cache if scales are available + if (use_mlas_q4_gemm_) { + const Tensor* scales_tensor = nullptr; + MLAS_BLK_QUANT_TYPE qtype = BlkQ4Sym; + int scales_idx = -1; + int zp_idx = -1; + + if (input_idx == 2) { // FC1 + scales_idx = 3; + zp_idx = 11; + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype); + } else if (input_idx == 5) { // FC2 + scales_idx = 6; + zp_idx = 12; + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype); + } + // FC3 (8) not supported for now + + if (scales_idx != -1 && + !Info().node().InputDefs()[zp_idx]->Exists() && + Info().TryGetConstantInput(scales_idx, &scales_tensor) && + scales_tensor != nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_ > 0 ? block_size_ : 0, rows, cols, qtype)) { + IAllocatorUniquePtr cache_buffer; + const auto& scales_dims = scales_tensor->Shape().GetDims(); + const T* scales_data = scales_tensor->Data(); + // Use the simple packed buffer we just created (buffer 0) as input + const uint8_t* simple_packed = dst_base; + + if (BuildDirectQ4PackedBCache(simple_packed, scales_data, num_experts, rows, cols, + block_size_ > 0 ? block_size_ : 0, scales_dims, qtype, + alloc, cache_buffer) + .IsOK()) { + // Store the size so we can verify later? Container holds size. + // We push it as a THIRD buffer (Buffer 2) now. 
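+            // Layout of the pre-packed buffers (mirrored in UseSharedPrePackedBuffers):
+            //   buffer 0: per-expert weights unpacked to one uint8 per 4-bit value, stored transposed
+            //   buffer 1: serialized weight shape (int64 rank followed by the dims)
+            //   buffer 2: optional MLAS Q4 packed-B cache, present only when the fast path applies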
+ size_t cache_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)) * static_cast(num_experts); + prepacked_weights->buffers_.push_back(std::move(cache_buffer)); + prepacked_weights->buffer_sizes_.push_back(cache_size); + } + } + } + } + } + + return Status::OK(); +} + +template +Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if (input_idx == 2 && !prepacked_buffers.empty()) { + packed_fc1_ = std::move(prepacked_buffers[0]); + if (prepacked_buffers.size() > 1) { + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(rank); + memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + fc1_shape_ = TensorShape(dims); + } + if (prepacked_buffers.size() > 2) { + packed_fc1_mlas_cache_ = std::move(prepacked_buffers[2]); + } + used_shared_buffers = true; + } else if (input_idx == 5 && !prepacked_buffers.empty()) { + packed_fc2_ = std::move(prepacked_buffers[0]); + if (prepacked_buffers.size() > 1) { + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(rank); + memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + fc2_shape_ = TensorShape(dims); + } + if (prepacked_buffers.size() > 2) { + packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]); + } + used_shared_buffers = true; + } else if (input_idx == 8 && !prepacked_buffers.empty()) { + packed_fc3_ = std::move(prepacked_buffers[0]); + if (prepacked_buffers.size() > 1) { + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(rank); + memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + fc3_shape_ = TensorShape(dims); + } + used_shared_buffers = true; + } + + return Status::OK(); +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), @@ -367,36 +647,50 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } + + const auto use_mlas_q4_gemm = ParseEnvironmentVariable(kUseMlasQ4GemmMoe); + if (use_mlas_q4_gemm.has_value()) { + use_mlas_q4_gemm_ = *use_mlas_q4_gemm; + use_mlas_q4_gemm_overridden_ = true; + } else { + // Default policy: enable fast path unless this run hits a known accuracy-loss configuration. + use_mlas_q4_gemm_ = true; + use_mlas_q4_gemm_overridden_ = false; + } } template Status QMoECPU::Compute(OpKernelContext* context) const { const auto* input = context->Input(0); const auto* router_probs = context->Input(1); - const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_experts_weights = packed_fc1_ ? nullptr : context->Input(2); const auto* fc1_scales = context->Input(3); const auto* fc1_experts_bias = context->Input(4); - const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_experts_weights = packed_fc2_ ? nullptr : context->Input(5); const auto* fc2_scales = context->Input(6); const auto* fc2_experts_bias = context->Input(7); - const auto* fc3_experts_weights = context->Input(8); + const auto* fc3_experts_weights = packed_fc3_ ? 
nullptr : context->Input(8); const auto* fc3_scales = context->Input(9); const auto* fc3_experts_bias = context->Input(10); const auto* fc1_zero_points = context->Input(11); const auto* fc2_zero_points = context->Input(12); const auto* fc3_zero_points = context->Input(13); + const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr); + const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr); + const TensorShape* fc3_shape_ptr = packed_fc3_ ? &fc3_shape_ : (fc3_experts_weights ? &fc3_experts_weights->Shape() : nullptr); + MoEParameters moe_params; ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias, fc1_scales, fc1_zero_points, - fc2_experts_weights, fc2_experts_bias, fc2_scales, fc2_zero_points, - fc3_experts_weights, fc3_experts_bias, fc3_scales, fc3_zero_points, + fc1_shape_ptr, fc1_experts_bias, fc1_scales, fc1_zero_points, + fc2_shape_ptr, fc2_experts_bias, fc2_scales, fc2_zero_points, + fc3_shape_ptr, fc3_experts_bias, fc3_scales, fc3_zero_points, expert_weight_bits_ == 4 ? 2 : 1, true, block_size_)); - if (fc3_experts_weights || fc3_experts_bias || fc3_scales || fc3_zero_points) { + if (fc3_shape_ptr || fc3_experts_bias || fc3_scales || fc3_zero_points) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); } @@ -559,8 +853,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); - const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); - const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->template Data(); + const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->template Data(); const T* fc1_scales_data = fc1_scales->Data(); const T* fc2_scales_data = fc2_scales->Data(); const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; @@ -568,6 +862,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const uint8_t* fc1_zp_data = fc1_zero_points ? fc1_zero_points->Data() : nullptr; const uint8_t* fc2_zp_data = fc2_zero_points ? fc2_zero_points->Data() : nullptr; + // Known loss-prone case from parity testing: 4-bit symmetric path (row-wise and block-wise). + const bool known_accuracy_loss_case = (expert_weight_bits_ == 4) && + (fc1_zp_data == nullptr) && (fc2_zp_data == nullptr); + const bool use_mlas_q4_gemm_effective = use_mlas_q4_gemm_overridden_ + ? 
use_mlas_q4_gemm_ + : (use_mlas_q4_gemm_ && !known_accuracy_loss_case); + const int64_t pack_unit = (8 / expert_weight_bits_); const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; @@ -595,6 +896,22 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_zp_expert_stride = (hidden_size + zp_pack_size - 1) / zp_pack_size; } + MLAS_BLK_QUANT_TYPE fc1_direct_qtype = BlkQ4Sym; + MLAS_BLK_QUANT_TYPE fc2_direct_qtype = BlkQ4Sym; + + // Use pre-packed MLAS cache if available + const void* fc1_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_effective && packed_fc1_mlas_cache_ && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, fc1_direct_qtype)) { + fc1_direct_q4_cache_ptr = packed_fc1_mlas_cache_.get(); + } + + const void* fc2_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_effective && packed_fc2_mlas_cache_ && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, fc2_direct_qtype)) { + fc2_direct_q4_cache_ptr = packed_fc2_mlas_cache_.get(); + } + std::vector> expert_workload; size_t total_work = 0; @@ -634,6 +951,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); for (int64_t expert_idx : expert_batch) { + bool fc2_bias_added_by_mlas = false; const auto& routes = expert_token_map[static_cast(expert_idx)]; if (routes.empty()) { continue; @@ -707,12 +1025,57 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k = static_cast(hidden_size); MLAS_BLK_QUANT_TYPE q_type = BlkQ4Sym; // Initialize to default - // Direct Q4 GEMM only supports symmetric quantization, so we disable it if zero_points are provided. - bool use_direct_q4_gemm = (fc1_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, q_type); - bool fc1_used_direct_q4 = false; - bool fc1_bias_handled_by_q4_gemm = false; + bool use_direct_q4_gemm = use_mlas_q4_gemm_effective && + ((fc1_direct_q4_cache_ptr != nullptr) || + ((packed_fc1_ == nullptr) && (fc1_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type))); + + if (packed_fc1_ != nullptr) { + if (use_mlas_q4_gemm_effective && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? 
block_size_ : 0, + fc1_out_features, hidden_size, q_type)) { + if (fc1_direct_q4_cache_ptr != nullptr) { + float* fc1_bias_float = nullptr; + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); + } else { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + fc1_bias_float = thread_bias1_buffer; + } + + size_t packed_size = MlasQ4GemmPackBSize(q_type, static_cast(fc1_out_features), static_cast(hidden_size)); + const uint8_t* packed_b = static_cast(fc1_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A1, packed_b, fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, fc1_direct_qtype, tp); + if (gemm_status.IsOK()) { + goto fc1_gemm_done; + } + } + } + + // Fallback: Dequantize from PrePacked (transposed, unpacked) -> MlasGemm + const uint8_t* current_packed_ptr = static_cast(packed_fc1_.get()) + expert_idx * fc1_out_features * hidden_size; + + DequantizePrePacked(current_packed_ptr, fc1_scales_ptr, fc1_zp_ptr, + is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, + B1_dequant, fc1_scales_dims); + + // Use MlasGemm with B1_dequant (which is already float transposed) + MlasGemm(CblasNoTrans, CblasNoTrans, + m, n, k, + 1.0f, A1, k, + B1_dequant, n, + 0.0f, C1, n, + tp); + + goto fc1_bias_handling; + } if (use_direct_q4_gemm) { IAllocatorUniquePtr mlas_packed_fc1; @@ -750,7 +1113,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); if (gemm_status.IsOK()) { - fc1_used_direct_q4 = true; goto fc1_gemm_done; } } @@ -797,8 +1159,9 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); - fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; - if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + fc1_bias_handling: + + if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); @@ -888,10 +1251,58 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k2 = static_cast(inter_size); MLAS_BLK_QUANT_TYPE q_type2 = BlkQ4Sym; // Initialize to default - bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, q_type2); - bool fc2_used_direct_q4 = false; + bool use_direct_q4_gemm_fc2 = use_mlas_q4_gemm_effective && + ((fc2_direct_q4_cache_ptr != nullptr) || + ((packed_fc2_ == nullptr) && (fc2_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2))); + + if (packed_fc2_ != nullptr) { + if (use_mlas_q4_gemm_effective && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? 
block_size_ : 0, + hidden_size, inter_size, q_type2)) { + if (fc2_direct_q4_cache_ptr != nullptr) { + float* fc2_bias_float = nullptr; + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); + } else { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + fc2_bias_float = thread_bias2_buffer; + } + + size_t packed_size = MlasQ4GemmPackBSize(q_type2, static_cast(hidden_size), static_cast(inter_size)); + const uint8_t* packed_b = static_cast(fc2_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A2, packed_b, fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, fc2_direct_qtype, tp); + if (gemm_status.IsOK()) { + fc2_bias_added_by_mlas = true; + goto fc2_gemm_done; + } + } + } + + // Dequantize from PrePacked (transposed, unpacked) + const uint8_t* current_packed_ptr = static_cast(packed_fc2_.get()) + expert_idx * hidden_size * inter_size; + + DequantizePrePacked(current_packed_ptr, fc2_scales_ptr, fc2_zp_ptr, + is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, + B2_dequant, fc2_scales_dims); + + // Fallback + MlasGemm(CblasNoTrans, CblasNoTrans, + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, n2, + 0.0f, C2, n2, + tp); + + goto fc2_gemm_done; + } if (use_direct_q4_gemm_fc2) { IAllocatorUniquePtr mlas_packed_fc2; @@ -929,7 +1340,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { num_expert_tokens, hidden_size, inter_size, q_type2, tp); if (gemm_status.IsOK()) { - fc2_used_direct_q4 = true; + fc2_bias_added_by_mlas = true; goto fc2_gemm_done; } } @@ -979,8 +1390,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_gemm_done: - bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias && !fc2_bias_added_by_mlas) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); @@ -1015,7 +1425,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias && !fc2_bias_added_by_mlas) { const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); size_t j = 0; for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { @@ -1110,9 +1520,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); + template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, bool& used_shared_buffers); template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, bool& 
used_shared_buffers); // Kernel Registration ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 890580e051a8e..94105a4661ec1 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -5,7 +5,9 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas_q4.h" #include "contrib_ops/cpu/moe/moe_base_cpu.h" +#include namespace onnxruntime { namespace contrib { @@ -25,9 +27,31 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { explicit QMoECPU(const OpKernelInfo& op_kernel_info); Status Compute(OpKernelContext* context) const override; + private: + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + private: int64_t expert_weight_bits_; int64_t block_size_; + bool use_mlas_q4_gemm_{false}; + bool use_mlas_q4_gemm_overridden_{false}; + + IAllocatorUniquePtr packed_fc1_; + IAllocatorUniquePtr packed_fc2_; + IAllocatorUniquePtr packed_fc3_; + + TensorShape fc1_shape_; + TensorShape fc2_shape_; + TensorShape fc3_shape_; + + IAllocatorUniquePtr packed_fc1_mlas_cache_; + IAllocatorUniquePtr packed_fc2_mlas_cache_; }; } // namespace contrib diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index 38dd8de01147c..5137c22d6cf61 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -621,8 +621,8 @@ void DumpNodeInputs( std::cout << " is non-tensor type.\n"; } } else { - // this could happen with an empty Optional input - std::cout << " was missing data type\n"; + // this could happen with an empty Optional input or the tensor is removed after pre-packing. + std::cout << " was missing data type (maybe pre-packed).\n"; } } else { std::cout << "Input " << i << " is optional and was not provided.\n"; diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 69f0435615079..d60e5b0164fe8 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -57,10 +57,10 @@ MlasQ4GemmPackBSize( * * @param QType type of block quantization * @param PackedBuf destination buffer - * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B + * @param FpData the pointer to fp32 matrix, with shape [K, N]. + * @param N the number of columns of matrix B (Output Channels). + * @param K the number of rows of matrix B (Input Channels). + * @param ldb leading dimension of FpData (usually N) */ void MLASCALL diff --git a/onnxruntime/test/python/transformers/benchmark_qmoe.py b/onnxruntime/test/python/transformers/benchmark_qmoe.py new file mode 100644 index 0000000000000..53854e053ef93 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_qmoe.py @@ -0,0 +1,188 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import os +import sys +import time +import unittest + +import numpy +import torch + +# Add current directory to path to allow importing from test_qmoe_cpu +current_dir = os.path.dirname(os.path.abspath(__file__)) +if current_dir not in sys.path: + sys.path.append(current_dir) + +from test_qmoe_cpu import PhiMoEConfig, PhiMoESparseMoeBlock, TensorProto # noqa: E402 + +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + + +@unittest.skipIf(pipeline_mode, "Skip benchmark in CI pipeline.") +class TestQMoESwiGLUBenchmark(unittest.TestCase): + """Benchmark tests for QMoE SwiGLU performance measurement.""" + + def test_qmoe_swiglu_throughput_benchmark(self): + """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" + print("\n=== QMoE SwiGLU Throughput Benchmark ===") + + # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) + configs = [ + ("Medium-4bit", 2880, 2880, 32, 4, 4), + ("Medium-8bit", 2880, 2880, 32, 4, 8), + ] + + batch_size = 1 + sequence_length = 512 + num_runs = 1000 + + results = [] + + for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs: + torch.manual_seed(42) + numpy.random.seed(42) + + print(f"\nTesting {config_name}:") + print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit") + + try: + # Create config and model + config = PhiMoEConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + ) + + qmoe_swiglu = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + # Create test input with fixed sequence length to match ONNX model + full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32) + + # For TTFT simulation, we'll measure single forward pass time + # This represents the time to process one token in autoregressive generation + + # Warm up with full context + for _ in range(3): + _ = qmoe_swiglu.forward(full_hidden_states) + + # Benchmark PyTorch TTFT (Time to First Token) + # Measure time for a single forward pass (represents token generation time) + torch.manual_seed(42) + + start_time = time.time() + for _ in range(num_runs): + torch_output = qmoe_swiglu.forward(full_hidden_states) + end_time = time.time() + torch_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second (throughput) + # For sequence generation, this represents the rate at which we can generate tokens + torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000) + + print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)") + print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec") + + # Benchmark ONNX Runtime + ort_ttft_ms = 0 + ort_tokens_per_sec = 0 + speedup = 0 + throughput_ratio = 0 + max_diff = 0 + + model_updated = qmoe_swiglu.recreate_onnx_model() + if model_updated and qmoe_swiglu.ort_sess is not None: + # Warm up ORT with full context + for _ in range(3): + _ = qmoe_swiglu.ort_forward(full_hidden_states) + + torch.manual_seed(42) + + # Measure ONNX Runtime TTFT (Time to First Token) + start_time = time.time() + for _ in range(num_runs): + ort_output = 
qmoe_swiglu.ort_forward(full_hidden_states) + end_time = time.time() + ort_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second for ONNX Runtime + ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000) + + speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0 + throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 0 + + print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)") + print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec") + print(f" TTFT Speedup: {speedup:.2f}x") + print(f" Throughput Gain: {throughput_ratio:.2f}x") + else: + print(" ONNX RT: Not available") + + # Calculate max difference if both outputs available + if torch_output is not None and ort_output is not None: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item() + print(f" Max diff: {max_diff:.6f}") + + results.append( + { + "config": config_name, + "torch_ttft_ms": torch_ttft_ms, + "torch_tokens_per_sec": torch_tokens_per_sec, + "ort_ttft_ms": ort_ttft_ms, + "ort_tokens_per_sec": ort_tokens_per_sec, + "speedup": speedup, + "throughput_ratio": throughput_ratio, + "max_diff": max_diff, + } + ) + + except Exception as e: + print(f" Error: {e}") + continue + + # Summary + print("\n=== Token Generation Time & Throughput Summary ===") + print( + f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}" + ) + print("-" * 105) + for result in results: + config = result["config"] + torch_ttft = result["torch_ttft_ms"] + torch_tps = result["torch_tokens_per_sec"] + ort_ttft = result["ort_ttft_ms"] + ort_tps = result["ort_tokens_per_sec"] + speedup = result["speedup"] + throughput_ratio = result["throughput_ratio"] + max_diff = result["max_diff"] + + ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A" + ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A" + speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A" + throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A" + + print( + f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}" + ) + + print("\nNotes:") + print("- Time: Token generation time in ms (lower is better)") + print("- tok/s: Tokens per second throughput (higher is better)") + print("- Time Gain: ORT speedup for latency (higher is better)") + print("- Throughput: ORT throughput improvement (higher is better)") + + +if __name__ == "__main__": + benchmark = TestQMoESwiGLUBenchmark() + benchmark.test_qmoe_swiglu_throughput_benchmark() diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 90ebb148a26a5..8415c7b08b77c 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -23,9 +23,11 @@ # normalization on the selected experts. This provides proper weight distribution # while maintaining computational efficiency. 
# -------------------------------------------------------------------------- +import os import time import unittest from collections import OrderedDict +from contextlib import contextmanager import numpy import torch @@ -76,6 +78,8 @@ class TensorProtoPlaceholder: ort_provider = ["CPUExecutionProvider"] +ORT_USE_MLAS_Q4_GEMM_MOE = "ORT_USE_MLAS_Q4_GEMM_MOE" + torch.manual_seed(42) numpy.random.seed(42) @@ -364,7 +368,7 @@ def create_cpu_moe_onnx_graph( use_swiglu=False, use_quant=False, quant_bits=4, - swiglu_interleaved=False, + swiglu_fusion=0, block_size=0, ): if not has_onnx: @@ -400,10 +404,10 @@ def create_cpu_moe_onnx_graph( "router_probs", # 1 "fc1_experts_weights", # 2 "fc1_scales", # 3 - "", # 4: fc1_bias + "fc1_experts_bias" if fc1_bias is not None else "", # 4 "fc2_experts_weights", # 5 "fc2_scales", # 6 - "", # 7: fc2_bias + "fc2_experts_bias" if fc2_bias is not None else "", # 7 "", # 8: fc3_weights "", # 9: fc3_scales "", # 10: fc3_bias @@ -442,11 +446,10 @@ def create_cpu_moe_onnx_graph( normalize_routing_weights=normalize_routing, activation_type=activation, # Add new attributes with backwards-compatible default values - swiglu_fusion=1 if use_swiglu else 0, # 1 if using SwiGLU activation + swiglu_fusion=swiglu_fusion, swiglu_limit=7.0, activation_alpha=1.702, activation_beta=1.0, - swiglu_interleaved=1 if swiglu_interleaved else 0, # Enable this attribute domain="com.microsoft", ), ] @@ -559,6 +562,30 @@ def create_cpu_moe_onnx_graph( ) ) + if fc1_bias is not None: + fc1_bias_np = fc1_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]) + initializers.append( + helper.make_tensor( + "fc1_experts_bias", + onnx_dtype, + list(fc1_bias.shape), + fc1_bias_np.flatten().tolist(), + raw=False, + ) + ) + + if fc2_bias is not None: + fc2_bias_np = fc2_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]) + initializers.append( + helper.make_tensor( + "fc2_experts_bias", + onnx_dtype, + list(fc2_bias.shape), + fc2_bias_np.flatten().tolist(), + raw=False, + ) + ) + graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] @@ -626,7 +653,7 @@ def __init__( self.num_experts_per_token = num_experts_per_token -def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): +def swiglu(x: torch.Tensor, alpha: float = 1.702, beta: float = 1.0, limit: float = 7.0): dim = x.shape[-1] x = x.view(-1, dim // 2, 2) x_glu, x_linear = x[..., 0], x[..., 1] @@ -635,8 +662,8 @@ def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): x_glu = x_glu.clamp(max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) - y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) - return y + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta) + return y.view(-1, dim // 2) class MoEBlockSparseTop2MLP(nn.Module): @@ -855,7 +882,7 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False e = time.time() time_ms = (e - s) / repeat * 1000 is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu - is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + is_interleaved = hasattr(self, "swiglu_fusion") and self.swiglu_fusion == 1 act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference") @@ -868,62 +895,80 @@ def recreate_onnx_model(self): """Recreate the ONNX model with the current weights to reflect any changes to the quantization code.""" w1_list, 
w2_list = [], [] + w1_bias_list, w2_bias_list = [], [] w1_scale_list, w2_scale_list = [], [] w1_zp_list, w2_zp_list = [], [] is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - if self.block_size > 0: - # Use block-wise quantization - w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( - self.experts[i].w1.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( - self.experts[i].w2.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) + if hasattr(self.experts[i], "w3"): + w1, w3 = self.experts[i].w1.weight, self.experts[i].w3.weight + w2 = self.experts[i].w2.weight + w1_bias = self.experts[i].w1.bias + w3_bias = getattr(self.experts[i].w3, "bias", None) + + # Combine and interleave w1 and w3 for the fused kernel + w1_combined = torch.cat([w1, w3], dim=0) # [2*inter, hidden] + if getattr(self, "swiglu_fusion", 0) == 1: + w1_combined = w1_combined.view(2, -1, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + + if self.block_size > 0: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( + w1_combined, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( + w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + else: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( + w1_combined, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( + w2, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + + if w1_bias is not None and w3_bias is not None: + b1_combined = torch.cat([w1_bias, w3_bias], dim=0) + if getattr(self, "swiglu_fusion", 0) == 1: + b1_combined = b1_combined.view(2, -1).transpose(0, 1).reshape(-1) + w1_bias_list.append(b1_combined.detach().cpu()) + elif w1_bias is not None: + w1_bias_list.append(w1_bias.detach().cpu()) else: - # Use row-wise quantization - w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( - self.experts[i].w1.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( - self.experts[i].w2.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) + # PhiMoESwiGLUMLP already has interleaved weights in w1 + w1 = self.experts[i].w1.weight + w2 = self.experts[i].w2.weight + w1_bias = self.experts[i].w1.bias - if self.use_swiglu: - if self.swiglu_interleaved: - pass + if self.block_size > 0: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( + w1, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( + w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) else: - if self.block_size > 0: - w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant_blockwise( - self.experts[i].w3.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - else: - w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant( - self.experts[i].w3.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - - gate_weights = pre_qweight1 - value_weights = pre_qweight3 - gate_scales = w1_scale - value_scales = w3_scale - gate_zp = w1_zp - value_zp = w3_zp - - pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) - w1_scale = torch.cat([gate_scales, value_scales], dim=0) - if w1_zp is not None and w3_zp is not None: - w1_zp = torch.cat([gate_zp, value_zp], dim=0) - - if 
self.swiglu_interleaved:
-                    self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone())
+            w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant(
+                w1, is_4_bit, asymmetric=self.use_asymmetric_quant
+            )
+            w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant(
+                w2, is_4_bit, asymmetric=self.use_asymmetric_quant
+            )
+            if w1_bias is not None:
+                w1_bias_list.append(w1_bias.detach().cpu())
+            if self.use_swiglu:
+                if getattr(self, "swiglu_fusion", 0) == 1:
+                    self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone())
                 else:
                     intermediate_size = self.experts[i].w1.weight.shape[0]
                     gate_dequant = w1_qdq[:intermediate_size].contiguous().clone()
                     value_dequant = w1_qdq[intermediate_size:].contiguous().clone()
-                    self.experts[i].w1.weight.data = gate_dequant
-                    self.experts[i].w3.weight.data = value_dequant
+                    if hasattr(self.experts[i], "w3"):
+                        self.experts[i].w1.weight.data = gate_dequant
+                        self.experts[i].w3.weight.data = value_dequant
+                    else:
+                        self.experts[i].w1.weight.data = w1_qdq.contiguous().clone()
             else:
                 self.experts[i].w1.weight.data = w1_qdq.contiguous().clone()
@@ -931,6 +976,9 @@ def recreate_onnx_model(self):
             w1_list.append(pre_qweight1)
             w2_list.append(pre_qweight2)
+
+            if self.experts[i].w2.bias is not None:
+                w2_bias_list.append(self.experts[i].w2.bias)
             w1_scale_list.append(w1_scale)
             w2_scale_list.append(w2_scale)
             if w1_zp is not None:
@@ -963,9 +1011,9 @@ def recreate_onnx_model(self):
                 onnx_dtype=self.onnx_dtype,
                 fc1_experts_weights=self.moe_experts_weight1,
                 fc2_experts_weights=self.moe_experts_weight2,
-                # Biases are not used in QMoE
-                fc1_bias=None,
-                fc2_bias=None,
+                # Pass collected biases
+                fc1_bias=torch.stack(w1_bias_list, dim=0) if w1_bias_list else None,
+                fc2_bias=torch.stack(w2_bias_list, dim=0) if w2_bias_list else None,
                 # Scales are used for dequantization
                 fc1_scales=moe_experts_weight_scale1,
                 fc2_scales=moe_experts_weight_scale2,
@@ -975,7 +1023,7 @@ def recreate_onnx_model(self):
                 use_swiglu=self.use_swiglu,
                 use_quant=True,  # Always use QMoE
                 quant_bits=self.quant_bits,
-                swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False,
+                swiglu_fusion=getattr(self, "swiglu_fusion", 0),
                 block_size=self.block_size,  # Add block_size for block-wise quantization
             )
         except Exception:
@@ -1020,7 +1068,7 @@ def parity_check(self):
         max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max()
 
         is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu
-        is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved
+        is_interleaved = getattr(self, "swiglu_fusion", 0) == 1
        act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU"
         quant_type = "Asymmetric" if self.use_asymmetric_quant else "Symmetric"
         block_type = f"Block({self.block_size})" if self.block_size > 0 else "Row"
@@ -1047,24 +1095,6 @@ def parity_check(self):
             )
             print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item())
             print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item())
-            # Print routing and per-expert contributions for this token from the PyTorch reference
-            try:
-                hidden_states_flat = hidden_state.view(-1, hidden_dim)
-                token_vec = hidden_states_flat[i : i + 1]
-                gate_logits = self.gate(token_vec)
-                topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1)
-                topk_soft = F.softmax(topk_vals, dim=1)
-                print("Gate logits:", gate_logits.detach().cpu().numpy())
-                print("Selected experts:", topk_experts.detach().cpu().numpy())
-                print("Routing weights:", topk_soft.detach().cpu().numpy())
-                # Compute per-expert contributions for selected experts
-                for idx_e, e in enumerate(topk_experts[0].tolist()):
-                    expert_layer = self.experts[e]
-                    expert_out = expert_layer(token_vec)
-                    contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item()
-                    print(f"Expert {e} contrib at hidden {k}: {contrib}")
-            except Exception as _:
-                pass
 
 ort_dtype_quant_bits_tolerance_map = {
     "FP32:0": (5e-3, 1e-3),
@@ -1111,6 +1141,43 @@ def small_test_cases():
             yield batch_size, sequence_length
 
 
+def with_mlas_q4_mode(test_cases):
+    expanded_cases = []
+    for case in test_cases:
+        quant_bits = case[2]
+        if quant_bits == 4:
+            expanded_cases.append((*case, None))
+            expanded_cases.append((*case, False))
+            expanded_cases.append((*case, True))
+        else:
+            expanded_cases.append((*case, None))
+    return expanded_cases
+
+
+@contextmanager
+def scoped_env_var(name: str, value: str):
+    previous = os.environ.get(name)
+    os.environ[name] = value
+    try:
+        yield
+    finally:
+        if previous is None:
+            os.environ.pop(name, None)
+        else:
+            os.environ[name] = previous
+
+
+def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool | None):
+    if enable_mlas_q4_gemm is None:  # No env var
+        test_runner()
+    else:
+        env_value = "1" if enable_mlas_q4_gemm else "0"
+        mode = "enabled" if enable_mlas_q4_gemm else "disabled"
+        print(f"DirectQ4 mode ({ORT_USE_MLAS_Q4_GEMM_MOE}) is {mode}")
+        with scoped_env_var(ORT_USE_MLAS_Q4_GEMM_MOE, env_value):
+            test_runner()
+
+
 class SwigluMoEBlock(SparseMoeBlockORTHelper):
     def __init__(
         self,
@@ -1128,7 +1195,7 @@ def __init__(
         self.num_experts = config.num_local_experts
         self.top_k = config.num_experts_per_token
         self.use_swiglu = True
-        self.swiglu_interleaved = True
+        self.swiglu_fusion = 1
         self.block_size = block_size
 
         use_quant = self.quant_bits > 0
@@ -1232,7 +1299,7 @@ def __init__(
         self.top_k = config.num_experts_per_tok
         self.router_jitter_noise = config.router_jitter_noise
         self.use_swiglu = True
-        self.swiglu_interleaved = True
+        self.swiglu_fusion = 1
        self.block_size = block_size
 
         use_quant = self.quant_bits > 0
@@ -1314,7 +1381,8 @@ def __init__(
            use_swiglu=self.use_swiglu,
            use_quant=use_quant,
            quant_bits=self.quant_bits,
-            swiglu_interleaved=self.swiglu_interleaved,
+            # swiglu_fusion=1 means fused and interleaved, which is the standard for QMoE.
+            swiglu_fusion=getattr(self, "swiglu_fusion", 0),
            block_size=self.block_size,
        )
 
@@ -1354,8 +1422,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return final_hidden_states
 
 
-disable_cpu_qmoe_tests = False
-
 # Define test cases for different MoE types
 phi3_test_cases = [
     (1, 32, 4),
@@ -1373,10 +1439,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 ]
 
 
-@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests")
 class TestPhiQMoECPU(unittest.TestCase):
-    @parameterized.expand(phi3_test_cases)
-    def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
+    @parameterized.expand(with_mlas_q4_mode(phi3_test_cases))
+    def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm):
         # Create unique seed based on test parameters to ensure different inputs for each test
         base_seed = 2000  # Different base seed from other tests
         param_hash = hash((batch_size, sequence_length, quant_bits))
@@ -1411,10 +1476,10 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
         self.assertFalse(torch.isnan(torch_result).any())
         self.assertFalse(torch.isinf(torch_result).any())
 
-        phi3_moe.parity_check()
+        run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm)
 
-    @parameterized.expand(phi3_test_cases)
-    def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits):
+    @parameterized.expand(with_mlas_q4_mode(phi3_test_cases))
+    def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm):
         base_seed = 3000
         param_hash = hash((batch_size, sequence_length, quant_bits))
         unique_seed = base_seed + abs(param_hash) % 1000
@@ -1436,10 +1501,12 @@ def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quan
             onnx_dtype=TensorProto.FLOAT,
             use_asymmetric_quant=True,
         )
-        phi3_moe.parity_check()
+        run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm)
 
-    @parameterized.expand(phi3_blockwise_test_cases)
-    def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size):
+    @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases))
+    def test_phi3_qmoe_blockwise_parity_cpu(
+        self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm
+    ):
         torch.manual_seed(42)
         numpy.random.seed(42)
 
@@ -1468,10 +1535,12 @@ def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant
         self.assertFalse(torch.isnan(torch_result).any())
         self.assertFalse(torch.isinf(torch_result).any())
 
-        phi3_moe.parity_check()
+        run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm)
 
-    @parameterized.expand(phi3_blockwise_test_cases)
-    def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size):
+    @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases))
+    def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(
+        self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm
+    ):
         torch.manual_seed(43)
         numpy.random.seed(43)
 
@@ -1489,10 +1558,8 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le
             block_size=block_size,
             use_asymmetric_quant=True,
         )
-        phi3_moe.parity_check()
-
+        run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm)
 
-disable_cpu_qmoe_tests = False
 
 swiglu_test_cases = [
     (1, 32, 4),
@@ -1510,10 +1577,9 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le
 ]
 
 
-@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests")
 class TestSwigluQMoECPU(unittest.TestCase):
-    @parameterized.expand(swiglu_test_cases)
-    def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
+    @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases))
+    def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm):
         # Create unique seed based on test parameters to ensure different inputs for each test
         base_seed = 1000  # Different base seed from regular MoE tests
         param_hash = hash((batch_size, sequence_length, quant_bits))
@@ -1547,10 +1613,10 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
         self.assertFalse(torch.isnan(torch_result).any())
         self.assertFalse(torch.isinf(torch_result).any())
 
-        swiglu_moe.parity_check()
+        run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm)
 
-    @parameterized.expand(swiglu_test_cases)
-    def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits):
+    @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases))
+    def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm):
         base_seed = 1100
         param_hash = hash((batch_size, sequence_length, quant_bits))
         unique_seed = base_seed + abs(param_hash) % 1000
@@ -1572,10 +1638,12 @@ def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, qu
            onnx_dtype=TensorProto.FLOAT,
            use_asymmetric_quant=True,
        )
-        swiglu_moe.parity_check()
+        run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm)
 
-    @parameterized.expand(swiglu_blockwise_test_cases)
-    def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size):
+    @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases))
+    def test_swiglu_qmoe_blockwise_parity_cpu(
+        self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm
+    ):
         torch.manual_seed(42)
         numpy.random.seed(42)
 
@@ -1603,10 +1671,12 @@ def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, qua
         self.assertFalse(torch.isnan(torch_result).any())
         self.assertFalse(torch.isinf(torch_result).any())
 
-        swiglu_moe.parity_check()
+        run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm)
 
-    @parameterized.expand(swiglu_blockwise_test_cases)
-    def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size):
+    @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases))
+    def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(
+        self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm
+    ):
         torch.manual_seed(43)
         numpy.random.seed(43)
 
@@ -1624,7 +1694,7 @@ def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_
            block_size=block_size,
            use_asymmetric_quant=True,
        )
-        swiglu_moe.parity_check()
+        run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm)
 
 
 @unittest.skipIf(True, "Skipping QMoE CPU benchmark tests")
@@ -1633,9 +1703,6 @@ class TestQMoESwiGLUBenchmark(unittest.TestCase):
     def test_qmoe_swiglu_throughput_benchmark(self):
         """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations."""
-        if disable_cpu_qmoe_tests:
-            self.skipTest("QMoE CPU tests disabled")
-
         print("\n=== QMoE SwiGLU Throughput Benchmark ===")
 
         # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits)