
QMoE CPU Performance Update (Up to 4x on 4-bit) #27364

Open
tianleiwu wants to merge 5 commits into main from tlwu/20260216/qmoe_cpu_q4_perf

Conversation

@tianleiwu
Contributor

Summary

This change improves QMoE CPU performance by moving more work to prepack time and enabling the DirectQ4 GEMM fast path where appropriate, while preserving an env-var switch for performance/accuracy A/B testing.

This PR introduces:

  • Prepack and cache infrastructure for QMoE expert weights.
  • DirectQ4 packed-B cache built during prepack (instead of mutable runtime cache in Compute()).
  • Fast-path support for block-wise cases (including block size 32 where supported by MLAS Q4 type).
  • Runtime toggle via ORT_USE_MLAS_Q4_GEMM_MOE.
  • Default fast-path policy refined to avoid known accuracy-loss scenarios unless explicitly overridden by env var.
  • Test and benchmark refinements for QMoE CPU validation.

Key Implementation Changes

1. Prepack-time cache build

  • Moves DirectQ4 packed-B cache construction to prepack stage.
  • Removes mutable runtime cache maintenance from Compute().
  • Reduces per-inference overhead and avoids mutable shared cache complexity.
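To make the prepack flow concrete, here is a minimal sketch of a prepack-time cache build. It follows ONNX Runtime's OpKernel::PrePack interface and the MlasQ4GemmPackBSize/MlasQ4GemmPackB entry points from mlas_q4.h, but the member names (num_experts_, K, N, packed_fc1_), the input index constant kFc1WeightIdx, and the DequantizeExpertToFloat helper are hypothetical stand-ins; the actual implementation in moe_quantization_cpu.cc differs in detail:

```cpp
// Hypothetical sketch: build the DirectQ4 packed-B cache once, at prepack time.
// num_experts_, K, N, kFc1WeightIdx, packed_fc1_, and DequantizeExpertToFloat
// are illustrative stand-ins, not the PR's actual members.
Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                        bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
  is_packed = false;
  if (input_idx != kFc1WeightIdx) {
    return Status::OK();  // only the expert weight tensors are prepacked here
  }
  const size_t pack_size = MlasQ4GemmPackBSize(BlkQ4Sym, N, K);
  for (int64_t e = 0; e < num_experts_; ++e) {
    // Expand one expert's quantized weight into a float [K, N] buffer (the
    // layout MlasQ4GemmPackB expects), then re-pack it into the MLAS Q4 format.
    std::vector<float> fp_weight(static_cast<size_t>(K) * N);
    DequantizeExpertToFloat(tensor, e, fp_weight.data());  // hypothetical helper
    auto packed = IAllocator::MakeUniquePtr<void>(alloc, pack_size);
    MlasQ4GemmPackB(BlkQ4Sym, packed.get(), fp_weight.data(), N, K, /*ldb=*/N);
    packed_fc1_.push_back(std::move(packed));  // consumed by Compute() later
  }
  is_packed = true;
  return Status::OK();
}
```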

2. Fast path vs fallback

  • Keeps two execution modes:
    • DirectQ4 GEMM fast path (MlasQ4GemmPackB + DirectQ4Gemm cache usage).
    • Fallback path (DequantizePrePacked + MlasGemm).
  • Allows controlled fallback for accuracy-sensitive configurations.
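A rough sketch of how the per-expert GEMM dispatch between the two modes can look. It assumes the MLAS_Q4_GEMM_DATA_PARAMS struct, MlasQ4GemmBatch, and the float MlasGemm overload as declared in mlas_q4.h/mlas.h; use_direct_q4_path_, packed_fc1_, and DequantizePrePacked are illustrative names, not the PR's exact code:

```cpp
// Illustrative dispatch between the DirectQ4 fast path and the
// dequantize + SGEMM fallback; names are stand-ins for the PR's members.
void ExpertGemm(const float* A, float* C, size_t M, size_t N, size_t K,
                size_t expert, MLAS_THREADPOOL* tp) {
  if (use_direct_q4_path_) {
    // Fast path: B was packed once in PrePack via MlasQ4GemmPackB.
    MLAS_Q4_GEMM_DATA_PARAMS params{};
    params.A = A;
    params.lda = K;
    params.B = packed_fc1_[expert].get();
    params.C = C;
    params.ldc = N;
    MlasQ4GemmBatch(BlkQ4Sym, M, N, K, /*BatchN=*/1, &params, tp);
  } else {
    // Fallback: expand the prepacked weight to float, then run a plain SGEMM.
    std::vector<float> b_fp(K * N);
    DequantizePrePacked(expert, b_fp.data());  // hypothetical helper
    MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f,
             A, K, b_fp.data(), N, 0.0f, C, N, tp);
  }
}
```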

3. Environment variable behavior

  • ORT_USE_MLAS_Q4_GEMM_MOE=1: force fast path when supported.
  • ORT_USE_MLAS_Q4_GEMM_MOE=0: force fallback path.
  • Unset: use default policy that enables fast path unless a known accuracy-loss pattern is detected.
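In code, this tri-state policy can be evaluated once (e.g., in the kernel constructor) along the following lines; IsKnownAccuracyLossPattern is a hypothetical placeholder for the PR's default-policy heuristics:

```cpp
#include <cstdlib>
#include <cstring>

// Tri-state policy sketch: an explicit 1/0 setting wins, otherwise the
// default heuristic applies. IsKnownAccuracyLossPattern() is a stand-in.
bool UseDirectQ4FastPath() {
  const char* v = std::getenv("ORT_USE_MLAS_Q4_GEMM_MOE");
  if (v != nullptr) {
    return std::strcmp(v, "0") != 0;  // "1" forces fast path, "0" forces fallback
  }
  // Unset: enable the fast path unless a known accuracy-loss pattern is detected.
  return !IsKnownAccuracyLossPattern();
}
```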

4. Test updates

  • QMoE CPU tests were refined to validate behavior with the env var set to 1, set to 0, and left unset.
  • Coverage includes parity checks for symmetric/asymmetric, row-wise/block-wise settings.

Benchmark Results (1000 inferences, benchmark_qmoe.py)

Note: PyTorch latency fluctuates across runs and is excluded from conclusions below.

ORT results comparison

| Config | Baseline ORT time (ms) | Baseline ORT tok/s | New ORT time, env=0 (ms) | New ORT tok/s, env=0 | New ORT time, env=1 (ms) | New ORT tok/s, env=1 |
| --- | --- | --- | --- | --- | --- | --- |
| Medium-4bit | 748.594 | 1.3 | 237.219 | 4.2 | 178.943 | 5.6 |
| Medium-8bit | 209.277 | 4.8 | 212.074 | 4.7 | 203.882 | 4.9 |

ORT speedup vs baseline

| Config | env=0 speedup vs baseline (time) | env=1 speedup vs baseline (time) |
| --- | --- | --- |
| Medium-4bit | 3.16x faster | 4.18x faster |
| Medium-8bit | 0.99x (about flat) | 1.03x faster |
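(For reference, the time-based speedups above are simply baseline time divided by new time, e.g. 748.594 / 178.943 ≈ 4.18 for Medium-4bit with env=1, which is the headline "up to 4x" figure.)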

Accuracy Notes

  • env=1 (forced fast path) provides the best 4-bit performance but may show non-zero max diff in known cases.
  • env=0 (fallback) maintains parity behavior with zero observed max diff in the reported benchmark table.
  • Default no-env policy is designed to avoid known accuracy-loss cases while still enabling fast path where safe.

Copilot AI left a comment

Pull request overview

This PR introduces prepack-time optimization for QMoE (Quantized Mixture of Experts) CPU operations, achieving up to 4x performance improvement for 4-bit quantization by moving weight preprocessing and DirectQ4 GEMM cache building from runtime to initialization time. The implementation adds environment variable controls for A/B testing between fast-path (DirectQ4) and fallback (dequantize + MlasGemm) execution modes.

Changes:

  • Implements PrePack and UseSharedPrePackedBuffers for weight unpacking and cache building at initialization
  • Adds ORT_USE_MLAS_Q4_GEMM_MOE environment variable for runtime path selection with smart defaults
  • Extends test coverage to validate all execution modes (env=0, env=1, no env) for 4-bit configurations
  • Updates attribute naming from swiglu_interleaved to swiglu_fusion for consistency with operator conventions

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| test_qmoe_cpu.py | Adds test expansion for env var modes, updates the swiglu_fusion attribute, adds bias support, removes debug code |
| benchmark_qmoe.py | New benchmark file for measuring QMoE throughput across configurations |
| mlas_q4.h | Clarifies documentation for the MlasQ4GemmPackB expected data layout ([K, N] format) |
| debug_node_inputs_outputs_utils.cc | Updates debug message to indicate pre-packed tensors may have missing type info |
| moe_quantization_cpu.h | Adds PrePack/UseSharedPrePackedBuffers methods, cache storage, and env var control flags |
| moe_quantization_cpu.cc | Implements weight prepacking, DirectQ4 cache building, environment variable handling, and dual execution paths |
| moe_helper.h | Adds nullable weight tensor handling for prepacked weights with fallback logic |


Comment on lines +76 to +78
```cpp
// Fallback for non-quantized MoE without weights (should not happen in current code paths)
// or if only bias is provided?
local_num_experts = num_experts;
```
Copilot AI Feb 17, 2026
The fallback at lines 76-78 sets local_num_experts = num_experts when both fc1_experts_weights and fc1_experts_scales are null. This represents an invalid configuration (no weights or scales provided), and should return an error status rather than silently falling back to a potentially incorrect value. Consider adding a validation check that returns an error if both are null.

Suggested change

```diff
-// Fallback for non-quantized MoE without weights (should not happen in current code paths)
-// or if only bias is provided?
-local_num_experts = num_experts;
+return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                       "Invalid MoE configuration: both fc1_experts_weights and fc1_experts_scales are null. "
+                       "At least one must be provided.");
```

Comment on lines +91 to +92
```cpp
// Should not happen for valid QMoE calls
inter_size = 0;
```
Copilot AI Feb 17, 2026
The fallback at lines 91-92 sets inter_size = 0 when fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null. As noted in the comment, "Should not happen for valid QMoE calls". Instead of silently setting inter_size = 0, which could lead to undefined behavior or silent failures downstream, this should return an error status indicating an invalid configuration.

Suggested change

```diff
-// Should not happen for valid QMoE calls
-inter_size = 0;
+ORT_THROW("Invalid MoE configuration: unable to infer inter_size because "
+          "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null.");
```

#include "core/framework/op_kernel.h"
#include "core/mlas/inc/mlas_q4.h"
#include "contrib_ops/cpu/moe/moe_base_cpu.h"
#include <mutex>
Copilot AI Feb 17, 2026

The mutex header is included but not used in the implementation. Since the PR moved cache building to PrePack (which is called once during initialization) and removed the mutable runtime cache, synchronization primitives are no longer needed. Consider removing this unused include.

Suggested change

```diff
-#include <mutex>
```

Comment on lines +45 to +46
```cpp
bool has_prepacked_fc1_scales_{false};
bool has_prepacked_fc2_scales_{false};
```
Copilot AI Feb 17, 2026
The member variables has_prepacked_fc1_scales_ and has_prepacked_fc2_scales_ are set in PrePack but never read elsewhere in the code. Consider removing them if they're not needed, or add a comment explaining their intended future use.

Suggested change

```diff
-bool has_prepacked_fc1_scales_{false};
-bool has_prepacked_fc2_scales_{false};
```
