QMoE CPU Performance Update (Up to 4x on 4-bit)#27364
Conversation
Pull request overview
This PR introduces prepack-time optimization for QMoE (Quantized Mixture of Experts) CPU operations, achieving up to 4x performance improvement for 4-bit quantization by moving weight preprocessing and DirectQ4 GEMM cache building from runtime to initialization time. The implementation adds environment variable controls for A/B testing between fast-path (DirectQ4) and fallback (dequantize + MlasGemm) execution modes.
Changes:
- Implements PrePack and UseSharedPrePackedBuffers for weight unpacking and cache building at initialization
- Adds ORT_USE_MLAS_Q4_GEMM_MOE environment variable for runtime path selection with smart defaults
- Extends test coverage to validate all execution modes (env=0, env=1, no env) for 4-bit configurations
- Updates attribute naming from swiglu_interleaved to swiglu_fusion for consistency with operator conventions
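The prepack-vs-runtime split described in the changes above can be sketched generically. The following is a minimal NumPy illustration, not ONNX Runtime code; the names `QMoEKernel` and `unpack_4bit`, and the 4-bit packing convention, are hypothetical:

```python
# Hypothetical sketch of the prepack pattern this PR applies: the expensive
# weight transformation runs once at kernel initialization, not on every call.
import numpy as np

def unpack_4bit(packed: np.ndarray) -> np.ndarray:
    """Expand uint8 values holding two 4-bit weights each into separate nibbles.

    The low-nibble-first order used here is illustrative only.
    """
    low = (packed & 0x0F).astype(np.int8)
    high = (packed >> 4).astype(np.int8)
    return np.stack([low, high], axis=-1).reshape(packed.shape[0], -1)

class QMoEKernel:
    def __init__(self, packed_weights: np.ndarray, scales: np.ndarray):
        # PrePack-time work: unpack and dequantize once into a cache.
        self.cache = unpack_4bit(packed_weights).astype(np.float32) * scales[:, None]

    def compute(self, x: np.ndarray) -> np.ndarray:
        # Runtime path only performs the GEMM against the prepacked cache.
        return x @ self.cache.T
```

The same total work is done either way; moving the unpacking into `__init__` means a model that runs `compute` many times pays the cost once.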
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| test_qmoe_cpu.py | Adds test expansion for env var modes, updates swiglu_fusion attribute, adds bias support, removes debug code |
| benchmark_qmoe.py | New benchmark file for measuring QMoE throughput across configurations |
| mlas_q4.h | Clarifies documentation for MlasQ4GemmPackB expected data layout ([K, N] format) |
| debug_node_inputs_outputs_utils.cc | Updates debug message to indicate pre-packed tensors may have missing type info |
| moe_quantization_cpu.h | Adds PrePack/UseSharedPrePackedBuffers methods, cache storage, and env var control flags |
| moe_quantization_cpu.cc | Implements weight prepacking, DirectQ4 cache building, environment variable handling, and dual execution paths |
| moe_helper.h | Adds nullable weight tensor handling for prepacked weights with fallback logic |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
```cpp
// Fallback for non-quantized MoE without weights (should not happen in current code paths)
// or if only bias is provided?
local_num_experts = num_experts;
```
The fallback at lines 76-78 sets `local_num_experts = num_experts` when both `fc1_experts_weights` and `fc1_experts_scales` are null. This represents an invalid configuration (no weights or scales provided), and should return an error status rather than silently falling back to a potentially incorrect value. Consider adding a validation check that returns an error if both are null.
Suggested change:

```diff
-// Fallback for non-quantized MoE without weights (should not happen in current code paths)
-// or if only bias is provided?
-local_num_experts = num_experts;
+return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                       "Invalid MoE configuration: both fc1_experts_weights and fc1_experts_scales are null. "
+                       "At least one must be provided.");
```
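The fail-fast pattern the reviewer suggests can be illustrated generically. This is a hypothetical Python sketch, not the actual `moe_helper.h` logic; the function name and the assumption that the expert count is axis 0 of either tensor are illustrative:

```python
def resolve_num_experts(weights, scales, declared_num_experts: int) -> int:
    """Derive the expert count from whichever tensor is present.

    Rejecting the all-null case up front surfaces a bad configuration at the
    call site, instead of silently substituting a possibly-wrong value.
    """
    if weights is not None:
        return weights.shape[0]
    if scales is not None:
        return scales.shape[0]
    # A silent fallback would be `return declared_num_experts` here.
    raise ValueError(
        "Invalid MoE configuration: both weights and scales are null; "
        "at least one must be provided.")
```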
```cpp
// Should not happen for valid QMoE calls
inter_size = 0;
```
The fallback at lines 91-92 sets `inter_size = 0` when `fc2_experts_weights`, `fc3_experts_scales`, and `fc1_experts_scales` are all null. As noted in the comment, "Should not happen for valid QMoE calls". Instead of silently setting `inter_size = 0`, which could lead to undefined behavior or silent failures downstream, this should return an error status indicating an invalid configuration.
Suggested change:

```diff
-// Should not happen for valid QMoE calls
-inter_size = 0;
+ORT_THROW("Invalid MoE configuration: unable to infer inter_size because "
+          "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null.");
```
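The inference-with-validation pattern behind this suggestion can be sketched in the same generic style. The function below is hypothetical, not the real `moe_helper.h` code, and the choice of the last axis for `inter_size` is an assumption about layout:

```python
def infer_inter_size(fc2_experts_weights, fc3_experts_scales, fc1_experts_scales) -> int:
    """Take inter_size from the first non-null tensor; raise instead of
    silently returning 0 when no tensor is available."""
    for tensor in (fc2_experts_weights, fc3_experts_scales, fc1_experts_scales):
        if tensor is not None:
            # Illustrative: assume inter_size is the trailing dimension.
            return tensor.shape[-1]
    raise ValueError(
        "Invalid MoE configuration: unable to infer inter_size because "
        "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null.")
```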
```cpp
#include "core/framework/op_kernel.h"
#include "core/mlas/inc/mlas_q4.h"
#include "contrib_ops/cpu/moe/moe_base_cpu.h"
#include <mutex>
```
The `<mutex>` header is included but not used in the implementation. Since the PR moved cache building to `PrePack` (which is called once during initialization) and removed the mutable runtime cache, synchronization primitives are no longer needed. Consider removing this unused include.
Suggested change:

```diff
-#include <mutex>
```
```cpp
bool has_prepacked_fc1_scales_{false};
bool has_prepacked_fc2_scales_{false};
```
The member variables `has_prepacked_fc1_scales_` and `has_prepacked_fc2_scales_` are set in `PrePack` but never read elsewhere in the code. Consider removing them if they're not needed, or add a comment explaining their intended future use.
Suggested change:

```diff
-bool has_prepacked_fc1_scales_{false};
-bool has_prepacked_fc2_scales_{false};
```
Summary
This change improves QMoE CPU performance by moving more work to prepack time and enabling the DirectQ4 GEMM fast path where appropriate, while preserving an env-var switch for performance/accuracy A/B testing.
This PR introduces:

- Weight preprocessing and DirectQ4 GEMM cache building moved out of runtime (`Compute()`) into initialization time.
- A runtime path switch controlled by the `ORT_USE_MLAS_Q4_GEMM_MOE` environment variable.

Key Implementation Changes
1. Prepack-time cache build
- The DirectQ4 cache is now built at prepack time instead of during `Compute()`.

2. Fast path vs fallback

- Fast path: `MlasQ4GemmPackB` + DirectQ4 GEMM (prepacked cache usage).
- Fallback: `DequantizePrePacked` + `MlasGemm`.

3. Environment variable behavior

- `ORT_USE_MLAS_Q4_GEMM_MOE=1`: force the fast path when supported.
- `ORT_USE_MLAS_Q4_GEMM_MOE=0`: force the fallback path.
- Unset: smart defaults select the path.

4. Test updates
Benchmark Results (1000 inferences, `benchmark_qmoe.py`)

Note: PyTorch latency fluctuates across runs and is excluded from the conclusions below.
ORT results comparison
ORT speedup vs baseline
Accuracy Notes
- `env=1` (forced fast path) provides the best 4-bit performance but may show a non-zero max diff in known cases.
- `env=0` (fallback) maintains parity behavior, with zero observed max diff in the reported benchmark table.
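The env=0 vs env=1 A/B comparison described above can be sketched as a small harness. This is a hypothetical Python sketch: `run_qmoe` is a stand-in for a real ONNX Runtime session invocation, and the simulated fast-path noise exists only to make the example self-contained; only the variable name `ORT_USE_MLAS_Q4_GEMM_MOE` comes from this PR.

```python
# Hypothetical A/B harness: run the same input through both execution paths
# and report the maximum elementwise difference.
import os
import numpy as np

def run_qmoe(x: np.ndarray) -> np.ndarray:
    # Placeholder for a real inference call; a small offset emulates the
    # fast path's rounding behavior so the harness has something to measure.
    mode = os.environ.get("ORT_USE_MLAS_Q4_GEMM_MOE", "unset")
    noise = 1e-4 if mode == "1" else 0.0
    return x * 2.0 + noise

def max_diff_between_paths(x: np.ndarray) -> float:
    os.environ["ORT_USE_MLAS_Q4_GEMM_MOE"] = "0"   # fallback path
    baseline = run_qmoe(x)
    os.environ["ORT_USE_MLAS_Q4_GEMM_MOE"] = "1"   # forced fast path
    fast = run_qmoe(x)
    return float(np.max(np.abs(fast - baseline)))
```

A harness like this makes the "non-zero max diff in known cases" claim directly measurable per configuration.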