Skip to content

Add MoE calibration wrapper for GLM-4.7-Flash (Glm4MoeLiteMoE)#2547

Draft
Nottlespike wants to merge 1 commit intovllm-project:mainfrom
Nottlespike:add-glm4-moe-lite-calibration
Draft

Add MoE calibration wrapper for GLM-4.7-Flash (Glm4MoeLiteMoE)#2547
Nottlespike wants to merge 1 commit intovllm-project:mainfrom
Nottlespike:add-glm4-moe-lite-calibration

Conversation

@Nottlespike
Copy link
Copy Markdown

Summary

GLM-4.7-Flash uses a separate MoE class (Glm4MoeLiteMoE) that is not covered by the existing Glm4MoeMoE wrapper. Without this fix, MoE calibration is silently skipped for GLM-4.7-Flash models, resulting in quantization that doesn't properly calibrate expert weights.

Problem

The model zai-org/GLM-4.7-Flash (31B MoE) uses Glm4MoeLiteMoE with a different architecture than Glm4MoeMoE:

  • Has shared_experts attribute (not shared_expert)
  • Uses Glm4MoeLiteNaiveMoe experts interface: (hidden_states, topk_indices, topk_weights)
  • Has group-based routing with n_group, topk_group parameters
  • 64 routed experts + shared experts

When quantizing with NVFP4, the MoE calibration context manager checks for registered wrappers but Glm4MoeLiteMoE doesn't match Glm4MoeMoE, so calibration silently falls back to standard forward passes without collecting expert activation statistics.

Solution

Add CalibrationGlm4MoeLiteMoE wrapper class that:

  • Registers for Glm4MoeLiteMoE class specifically
  • Implements route_tokens_to_experts() for proper group-based routing
  • Collects activation statistics for all 64 experts during calibration
  • Provides proper forward() that routes through all experts

Changes

  • src/llmcompressor/modeling/glm4_moe_lite.py - New 119-line wrapper
  • src/llmcompressor/modeling/__init__.py - Import the new wrapper

Testing

Verified that quantization now shows:

Found 46 MoE modules to replace
Replaced 46 MoE modules for calibration

Instead of the previous behavior where no MoE modules were detected for Glm4MoeLiteMoE.

Hardware Tested

  • 2× RTX PRO 6000 Blackwell (96GB each)
  • NVFP4 quantization with sequential offload

Copilot AI review requested due to automatic review settings March 30, 2026 22:06
@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a calibration wrapper for GLM-4.7-Flash models to ensure all experts are properly calibrated during quantization, preventing suboptimal expert weight quantization. The feedback identifies a potential shape mismatch in the routing logic that could occur if input logits are not flattened and suggests a refactor to the forward method to eliminate redundant code execution paths.

Comment on lines +92 to +111
if self.calibrate_all_experts:
# Send ALL tokens to ALL experts for calibration
num_tokens = hidden_states.shape[0]
all_expert_indices = torch.arange(
self.n_routed_experts, device=hidden_states.device
).unsqueeze(0).expand(num_tokens, -1)
all_expert_weights = torch.ones(
num_tokens, self.n_routed_experts,
dtype=hidden_states.dtype,
device=hidden_states.device
) / self.n_routed_experts

# Run calibration pass through all experts
_ = self.experts(hidden_states, all_expert_indices, all_expert_weights)

# Use actual routing for output
hidden_states = self.experts(hidden_states, topk_indices, topk_weights)
else:
# Standard routing
hidden_states = self.experts(hidden_states, topk_indices, topk_weights)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The call to self.experts(hidden_states, topk_indices, topk_weights) is duplicated in both branches of the if self.calibrate_all_experts block. This can be refactored to improve maintainability by moving the common call outside the conditional block.

        if self.calibrate_all_experts:
            # Send ALL tokens to ALL experts for calibration
            num_tokens = hidden_states.shape[0]
            all_expert_indices = torch.arange(
                self.n_routed_experts, device=hidden_states.device
            ).unsqueeze(0).expand(num_tokens, -1)
            all_expert_weights = torch.ones(
                num_tokens, self.n_routed_experts,
                dtype=hidden_states.dtype,
                device=hidden_states.device
            ) / self.n_routed_experts

            # Run calibration pass through all experts
            _ = self.experts(hidden_states, all_expert_indices, all_expert_weights)

        # Standard routing for output
        hidden_states = self.experts(hidden_states, topk_indices, topk_weights)

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds MoE calibration support for GLM-4.7-Flash models by introducing a dedicated calibration wrapper for the Glm4MoeLiteMoE architecture, ensuring expert activation statistics are collected instead of silently skipping MoE calibration.

Changes:

  • Added CalibrationGlm4MoeLiteMoE wrapper with GLM-4.7-Flash group-based routing and “all experts see tokens” calibration behavior.
  • Registered the new wrapper via llmcompressor.modeling package import side effects.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
src/llmcompressor/modeling/glm4_moe_lite.py New MoE calibration wrapper for Glm4MoeLiteMoE, including routing + calibration-only all-expert pass.
src/llmcompressor/modeling/init.py Imports the new wrapper to trigger registry registration.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +94 to +105
num_tokens = hidden_states.shape[0]
all_expert_indices = torch.arange(
self.n_routed_experts, device=hidden_states.device
).unsqueeze(0).expand(num_tokens, -1)
all_expert_weights = torch.ones(
num_tokens, self.n_routed_experts,
dtype=hidden_states.dtype,
device=hidden_states.device
) / self.n_routed_experts

# Run calibration pass through all experts
_ = self.experts(hidden_states, all_expert_indices, all_expert_weights)
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra calibration-only expert pass (_ = self.experts(...)) doesn’t use its output, but it will still build an autograd graph when calibration runs with grads enabled, increasing memory/compute significantly. Wrap this call (and the temporary tensor construction if desired) in a torch.no_grad() block (or otherwise explicitly disable grads) since it’s only for collecting activation stats.

Suggested change
num_tokens = hidden_states.shape[0]
all_expert_indices = torch.arange(
self.n_routed_experts, device=hidden_states.device
).unsqueeze(0).expand(num_tokens, -1)
all_expert_weights = torch.ones(
num_tokens, self.n_routed_experts,
dtype=hidden_states.dtype,
device=hidden_states.device
) / self.n_routed_experts
# Run calibration pass through all experts
_ = self.experts(hidden_states, all_expert_indices, all_expert_weights)
with torch.no_grad():
num_tokens = hidden_states.shape[0]
all_expert_indices = torch.arange(
self.n_routed_experts, device=hidden_states.device
).unsqueeze(0).expand(num_tokens, -1)
all_expert_weights = torch.ones(
num_tokens, self.n_routed_experts,
dtype=hidden_states.dtype,
device=hidden_states.device
) / self.n_routed_experts
# Run calibration pass through all experts
_ = self.experts(hidden_states, all_expert_indices, all_expert_weights)

Copilot uses AI. Check for mistakes.
Comment on lines +17 to +27
@MoECalibrationModule.register("Glm4MoeLiteMoE")
class CalibrationGlm4MoeLiteMoE(MoECalibrationModule):
"""
Calibration version of Glm4MoeLiteMoE that sends all tokens to all experts.

GLM-4.7-Flash uses Glm4MoeLiteNaiveMoe which has a batched expert interface:
experts(hidden_states, top_k_index, top_k_weights)

During calibration with calibrate_all_experts=True, we override routing to
send all tokens to all experts, ensuring proper quantization statistics.
"""
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This introduces a new MoE calibration wrapper but there’s no corresponding unit test under tests/llmcompressor/modeling/ (other calibration wrappers like GLM4 MoE have targeted tests). Add a test that (1) verifies moe_calibration_context replaces Glm4MoeLiteMoE modules with CalibrationGlm4MoeLiteMoE, and (2) when calibrate_all_experts=True, all routed experts receive a forward call (e.g., via forward hooks), similar to tests/llmcompressor/modeling/test_calib_glm4_moe.py.

Copilot uses AI. Check for mistakes.
@Nottlespike Nottlespike marked this pull request as draft March 30, 2026 23:34
GLM-4.7-Flash Lite stores routed experts as packed 3D tensors in
Glm4MoeLiteNaiveMoe, so the existing calibration path not only skipped
MoE-aware calibration but also kept routed experts invisible to
Linear-targeted quantization.

Unpack the routed experts into per-expert Glm4MoeLiteMLP modules,
preserve the unpacked structure for quantization and checkpoint save,
and add focused modeling tests for expert activation, output parity,
and Linear visibility.

Signed-off-by: Jason Lu <[email protected]>
@Nottlespike Nottlespike force-pushed the add-glm4-moe-lite-calibration branch from d2c0afa to e460938 Compare March 30, 2026 23:59
Copy link
Copy Markdown
Collaborator

@brian-dellabetta brian-dellabetta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Nottlespike , thanks for preparing this. It looks like a lot of the code is shared with what is in src/llmcompressor/modeling/glm_moe_dsa.py. Have you explored what it would look like to import and subclass the classes from that file directly? I know transformers sticks to the approach of no shared code across model definitions, but given we want to apply the same operation to the 3D expert tensors in both, maybe we won't have to repeat our code

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants