Skip to content

Commit e460938

Browse files
committed
Add GLM-4.7-Flash Lite MoE calibration wrapper
GLM-4.7-Flash Lite stores its routed experts as packed 3D tensors inside Glm4MoeLiteNaiveMoe. As a result, the existing calibration path skipped MoE-aware calibration, and the routed experts were also invisible to Linear-targeted quantization. This change unpacks the routed experts into per-expert Glm4MoeLiteMLP modules, preserves the unpacked structure for quantization and checkpoint save, and adds focused modeling tests covering expert activation, output parity, and Linear-module visibility.

Signed-off-by: Jason Lu <Nottlespike@users.noreply.github.com>
1 parent e48353f commit e460938

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed

src/llmcompressor/modeling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# trigger registration
1313
from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
1414
from .glm4_moe import CalibrationGlm4MoeMoE # noqa: F401
15+
from .glm4_moe_lite import CalibrationGlm4MoeLiteMoE # noqa: F401
1516
from .glm_moe_dsa import CalibrationGlmMoeDsaMoE # noqa: F401
1617
from .llama4 import SequentialLlama4TextMoe # noqa: F401
1718
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import torch
6+
7+
from llmcompressor.modeling.moe_context import MoECalibrationModule
8+
from llmcompressor.utils.dev import skip_weights_initialize
9+
10+
if TYPE_CHECKING:
11+
from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import (
12+
Glm4MoeLiteConfig,
13+
)
14+
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
15+
Glm4MoeLiteMoE,
16+
Glm4MoeLiteNaiveMoe,
17+
)
18+
19+
20+
@MoECalibrationModule.register("Glm4MoeLiteMoE")
class CalibrationGlm4MoeLiteMoE(MoECalibrationModule):
    """
    Calibration version of Glm4MoeLiteMoE that unfuses 3D expert parameters into
    individual MLP modules (nn.Linear) so they can be quantized.

    GLM-4.7-Flash Lite stores routed experts in a `Glm4MoeLiteNaiveMoe` module
    using 3D parameters (`gate_up_proj`, `down_proj`) instead of `nn.Linear`
    submodules. Since llm-compressor targets `Linear` modules, the original routed
    experts are invisible to quantization and remain BF16 unless they are unpacked.

    is_permanent = True so the unpacked `nn.Linear` expert structure persists for
    quantization and checkpoint save.
    """

    # Tells the MoE-context machinery NOT to restore the fused module after
    # calibration; see restore() below, which returns self for the same reason.
    is_permanent = True

    def __init__(
        self,
        original: Glm4MoeLiteMoE,
        config: Glm4MoeLiteConfig,
        calibrate_all_experts: bool = True,
        num_calibrate_experts: int | None = None,
    ):
        """
        :param original: fused Glm4MoeLiteMoE module to wrap; its gate and
            shared experts are reused as-is, its routed experts are unpacked
        :param config: model config supplying the router hyperparameters
        :param calibrate_all_experts: when True, every expert runs on every
            token during forward (for calibration statistics), though only the
            routed tokens contribute to the output
        :param num_calibrate_experts: NOTE(review): accepted but never used in
            this class — confirm whether partial-expert calibration was
            intended, or drop the parameter upstream.
        """
        super().__init__()
        # Router hyperparameters copied from the config so routing below
        # matches the Hugging Face implementation exactly.
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.n_routed_experts
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor

        # Unpack the fused 3D expert tensors into per-expert MLPs whose
        # nn.Linear projections are visible to Linear-targeted quantization.
        self.experts = SequentialGlm4MoeLiteExperts(config, original.experts)
        self.gate = original.gate
        self.shared_experts = original.shared_experts
        self.calibrate_all_experts = calibrate_all_experts

    def route_tokens_to_experts(
        self, router_logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Match the Hugging Face GLM-4.7-Flash Lite group router."""
        # Scores are sigmoid probabilities; the bias is only used for expert
        # *selection*, while the returned weights come from the unbiased scores.
        router_logits = router_logits.sigmoid()
        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
        # Per-group score = sum of that group's top-2 expert scores.
        group_scores = (
            router_logits_for_choice.view(
                -1, self.n_group, self.n_routed_experts // self.n_group
            )
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        # Keep only the topk_group best groups; mask out every other group's
        # experts before the final per-token top-k selection.
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = router_logits_for_choice.masked_fill(
            ~score_mask.bool(), 0.0
        )
        topk_indices = torch.topk(
            scores_for_choice, k=self.top_k, dim=-1, sorted=False
        )[1]
        # Gather the *unbiased* scores for the chosen experts.
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            # Epsilon guards against an all-zero row after masking.
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Routed-expert mixture plus shared-expert residual.

        NOTE(review): assumes `self.gate` flattens its input and returns 2D
        logits of shape (batch*seq, n_routed_experts), as the HF router does —
        confirm against the transformers implementation.
        """
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # Run unpacked experts sequentially so routed MLPs stay visible
        # as Linear modules.
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        with torch.no_grad():
            # expert_mask[e, k, t] == 1 iff token t picked expert e at rank k.
            expert_mask = torch.nn.functional.one_hot(
                topk_indices, num_classes=self.num_experts
            )
            expert_mask = expert_mask.permute(2, 1, 0)

            for i in range(self.num_experts):
                top_k_pos, token_idx = torch.where(expert_mask[i])
                has_tokens = token_idx.numel() > 0

                if self.calibrate_all_experts:
                    # Feed ALL tokens through the expert so calibration hooks
                    # observe full activation statistics, then keep only the
                    # outputs for tokens actually routed here.
                    expert_out_all = self.experts[i](hidden_states)
                    if not has_tokens:
                        continue
                    expert_out = expert_out_all[token_idx]
                else:
                    if not has_tokens:
                        continue
                    expert_out = self.experts[i](hidden_states[token_idx])

                # Weight each routed token's output by its router score and
                # accumulate; index_add_ handles duplicate token indices.
                weighted_output = expert_out * topk_weights[token_idx, top_k_pos, None]
                final_hidden_states.index_add_(
                    0, token_idx, weighted_output.to(final_hidden_states.dtype)
                )

        hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
        # Shared experts always process every token (dense residual path).
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """Keep the unpacked expert structure for quantization and checkpoint save."""
        return self
134+
135+
136+
class SequentialGlm4MoeLiteExperts(torch.nn.ModuleList):
    """
    Unpacks 3D expert parameter tensors into individual Glm4MoeLiteMLP modules so
    each routed expert has standard `nn.Linear` projections visible to
    `targets="Linear"`.
    """

    def __init__(self, config: Glm4MoeLiteConfig, original: Glm4MoeLiteNaiveMoe):
        """
        :param config: model config providing `n_routed_experts` and
            `moe_intermediate_size`
        :param original: fused expert module holding 3D `gate_up_proj` and
            `down_proj` parameters indexed by expert along dim 0
        """
        # Local import keeps the transformers dependency lazy, mirroring the
        # TYPE_CHECKING-only imports at module level.
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteMLP,
        )

        num_experts = config.n_routed_experts
        intermediate_size = config.moe_intermediate_size

        # Skip random weight initialization: every projection weight is
        # overwritten below with slices of the original fused tensors.
        with skip_weights_initialize():
            super().__init__(
                [
                    Glm4MoeLiteMLP(config, intermediate_size=intermediate_size)
                    for _ in range(num_experts)
                ]
            )

        # Assign plain attributes only AFTER Module.__init__ has run, instead
        # of relying on nn.Module.__setattr__ tolerating assignment on an
        # uninitialized Module (which breaks for tensor/module values).
        self.num_experts = num_experts

        gate_up_data = original.gate_up_proj.data
        down_data = original.down_proj.data

        for i in range(num_experts):
            # NOTE(review): assumes the fused layout concatenates [gate; up]
            # along the output dimension, first half = gate_proj — confirm
            # against Glm4MoeLiteNaiveMoe's packing.
            gate_proj, up_proj = gate_up_data[i].chunk(2, dim=0)
            # Rebind `.data` so weights take the unpacked values without
            # autograd tracking or copying into uninitialized storage.
            self[i].gate_proj.weight.data = gate_proj.contiguous()
            self[i].up_proj.weight.data = up_proj.contiguous()
            self[i].down_proj.weight.data = down_data[i].contiguous()
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from functools import partial
2+
3+
import pytest
4+
import torch
5+
6+
_glm_cfg = pytest.importorskip(
7+
"transformers.models.glm4_moe_lite.configuration_glm4_moe_lite",
8+
reason="glm4_moe_lite requires transformers >= 5.x",
9+
)
10+
_glm_mod = pytest.importorskip(
11+
"transformers.models.glm4_moe_lite.modeling_glm4_moe_lite",
12+
reason="glm4_moe_lite requires transformers >= 5.x",
13+
)
14+
Glm4MoeLiteConfig = _glm_cfg.Glm4MoeLiteConfig
15+
Glm4MoeLiteMoE = _glm_mod.Glm4MoeLiteMoE
16+
17+
from llmcompressor.modeling.glm4_moe_lite import CalibrationGlm4MoeLiteMoE # noqa: E402
18+
from llmcompressor.utils.helpers import calibration_forward_context # noqa: E402
19+
from tests.testing_utils import requires_gpu # noqa: E402
20+
21+
22+
def _tiny_config():
    """Small config for fast unit tests (8 experts instead of 64)."""
    settings = dict(
        hidden_size=64,
        intermediate_size=128,
        moe_intermediate_size=32,
        n_routed_experts=8,
        num_experts_per_tok=2,
        n_shared_experts=1,
        n_group=1,
        topk_group=1,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        vocab_size=256,
        pad_token_id=0,
        eos_token_id=[0],
    )
    return Glm4MoeLiteConfig(**settings)
40+
41+
42+
@requires_gpu
def test_calib_glm4_moe_lite_all_experts_triggered():
    """With calibrate_all_experts=True, every unpacked expert must run."""
    config = _tiny_config()
    with torch.device("cuda"):
        original = Glm4MoeLiteMoE(config)
        for param in original.parameters():
            param.data.normal_(mean=0.0, std=0.02)

    module = CalibrationGlm4MoeLiteMoE(original, config, calibrate_all_experts=True)

    expert_triggered = [False] * len(module.experts)

    # Forward hooks flip the flag for any expert that actually executes.
    def hook_fn(i, module, input, output):
        expert_triggered[i] = True

    for idx, expert in enumerate(module.experts):
        expert.register_forward_hook(partial(hook_fn, idx))

    batch, seq_len = 4, 32
    sample = torch.randn(batch, seq_len, config.hidden_size, device="cuda")

    with calibration_forward_context(module), torch.no_grad():
        _ = module(sample)

    assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"
71+
72+
@requires_gpu
def test_calib_glm4_moe_lite_output_matches():
    """Wrapper output stays close to the fused module in both calibration modes."""
    config = _tiny_config()
    with torch.device("cuda"):
        original = Glm4MoeLiteMoE(config)
        for param in original.parameters():
            param.data.normal_(mean=0.0, std=0.02)

    batch, seq_len = 4, 32
    sample = torch.randn(batch, seq_len, config.hidden_size, device="cuda")

    with calibration_forward_context(original):
        true_out = original(sample)

    # Parity must hold whether or not every expert sees every token.
    for calibrate_all in (True, False):
        module = CalibrationGlm4MoeLiteMoE(
            original, config, calibrate_all_experts=calibrate_all
        )
        with calibration_forward_context(module):
            out = module(sample)
        assert torch.nn.functional.mse_loss(true_out, out) < 0.1
97+
98+
@requires_gpu
def test_calib_glm4_moe_lite_experts_are_linear():
    """Verify that unpacked routed experts expose Linear modules for quantization."""
    config = _tiny_config()
    with torch.device("cuda"):
        original = Glm4MoeLiteMoE(config)

    module = CalibrationGlm4MoeLiteMoE(original, config, calibrate_all_experts=True)

    linear_names = []
    for name, submodule in module.named_modules():
        if isinstance(submodule, torch.nn.Linear):
            linear_names.append(name)

    # gate/up/down per routed expert, plus the shared expert's three projections.
    expected_expert_linears = config.n_routed_experts * 3
    expected_shared_linears = 3
    total_expected = expected_expert_linears + expected_shared_linears
    assert len(linear_names) == total_expected, (
        f"Expected {expected_expert_linears + expected_shared_linears} Linear modules, "
        f"found {len(linear_names)}: {linear_names}"
    )

0 commit comments

Comments
 (0)