Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/llmcompressor/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# trigger registration
from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
from .glm4_moe import CalibrationGlm4MoeMoE # noqa: F401
from .glm4_moe_lite import CalibrationGlm4MoeLiteMoE # noqa: F401
from .glm_moe_dsa import CalibrationGlmMoeDsaMoE # noqa: F401
from .llama4 import SequentialLlama4TextMoe # noqa: F401
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
Expand Down
169 changes: 169 additions & 0 deletions src/llmcompressor/modeling/glm4_moe_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize

if TYPE_CHECKING:
from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import (
Glm4MoeLiteConfig,
)
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
Glm4MoeLiteMoE,
Glm4MoeLiteNaiveMoe,
)


@MoECalibrationModule.register("Glm4MoeLiteMoE")
class CalibrationGlm4MoeLiteMoE(MoECalibrationModule):
    """
    Calibration version of Glm4MoeLiteMoE that unfuses 3D expert parameters into
    individual MLP modules (nn.Linear) so they can be quantized.

    GLM-4.7-Flash Lite stores routed experts in a `Glm4MoeLiteNaiveMoe` module
    using 3D parameters (`gate_up_proj`, `down_proj`) instead of `nn.Linear`
    submodules. Since llm-compressor targets `Linear` modules, the original routed
    experts are invisible to quantization and remain BF16 unless they are unpacked.

    is_permanent = True so the unpacked `nn.Linear` expert structure persists for
    quantization and checkpoint save.
    """

    is_permanent = True

    def __init__(
        self,
        original: Glm4MoeLiteMoE,
        config: Glm4MoeLiteConfig,
        calibrate_all_experts: bool = True,
        num_calibrate_experts: int | None = None,
    ):
        """
        :param original: the stock Glm4MoeLiteMoE module being replaced; its
            gate and shared experts are reused as-is
        :param config: model config supplying the routing hyperparameters
        :param calibrate_all_experts: when True, every expert runs on all
            tokens during forward (only routed tokens contribute to the
            output), so calibration statistics cover every expert
        :param num_calibrate_experts: accepted for interface compatibility;
            not referenced anywhere in this module
        """
        super().__init__()
        # Routing hyperparameters copied from the config.
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.n_routed_experts
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor

        # Unpack fused 3D expert weights into per-expert MLPs whose
        # projections are nn.Linear; gate and shared experts stay original.
        self.experts = SequentialGlm4MoeLiteExperts(config, original.experts)
        self.gate = original.gate
        self.shared_experts = original.shared_experts
        self.calibrate_all_experts = calibrate_all_experts

    def route_tokens_to_experts(
        self, router_logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Match the Hugging Face GLM-4.7-Flash Lite group router.

        :param router_logits: raw gate outputs; assumed already flattened to
            (num_tokens, n_routed_experts) by the gate -- TODO confirm
            against the HF Glm4MoeLite gate implementation
        :return: (topk_indices, topk_weights), each (num_tokens, top_k)
        """
        router_logits = router_logits.sigmoid()
        # The correction bias is used only for expert *selection*; the final
        # mixing weights are gathered from the unbiased sigmoid scores below.
        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
        # Score each expert group by the sum of its top-2 biased scores.
        group_scores = (
            router_logits_for_choice.view(
                -1, self.n_group, self.n_routed_experts // self.n_group
            )
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        # Keep only the best `topk_group` groups; experts in dropped groups
        # are masked to 0 so they cannot win the final top-k.
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = router_logits_for_choice.masked_fill(
            ~score_mask.bool(), 0.0
        )
        topk_indices = torch.topk(
            scores_for_choice, k=self.top_k, dim=-1, sorted=False
        )[1]
        # Gather the *unbiased* scores for the selected experts.
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            # Normalize per-token weights to sum to 1 (epsilon avoids 0/0
            # when all selected scores were masked to 0).
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the unpacked experts, then add the
        shared-expert path computed on the unrouted input."""
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        # Flatten leading dims: (..., hidden) -> (num_tokens, hidden).
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # Run unpacked experts sequentially so routed MLPs stay visible
        # as Linear modules.
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        with torch.no_grad():
            # expert_mask[e, k, t] == 1 iff token t selected expert e in
            # top-k slot k (after the permute below).
            expert_mask = torch.nn.functional.one_hot(
                topk_indices, num_classes=self.num_experts
            )
            expert_mask = expert_mask.permute(2, 1, 0)

        for i in range(self.num_experts):
            top_k_pos, token_idx = torch.where(expert_mask[i])
            has_tokens = token_idx.numel() > 0

            if self.calibrate_all_experts:
                # Feed *all* tokens so calibration observes this expert even
                # when the router selected it for none of them; only routed
                # tokens are kept for the output.
                expert_out_all = self.experts[i](hidden_states)
                if not has_tokens:
                    continue
                expert_out = expert_out_all[token_idx]
            else:
                if not has_tokens:
                    continue
                expert_out = self.experts[i](hidden_states[token_idx])

            # Scale each routed token's output by its routing weight and
            # accumulate into the result buffer.
            weighted_output = expert_out * topk_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(
                0, token_idx, weighted_output.to(final_hidden_states.dtype)
            )

        hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """Keep the unpacked expert structure for quantization and checkpoint save."""
        return self


class SequentialGlm4MoeLiteExperts(torch.nn.ModuleList):
    """
    Unpacks 3D expert parameter tensors into individual Glm4MoeLiteMLP modules so
    each routed expert has standard `nn.Linear` projections visible to
    `targets="Linear"`.
    """

    def __init__(self, config: Glm4MoeLiteConfig, original: Glm4MoeLiteNaiveMoe):
        # Local import avoids a hard import-time dependency on the
        # transformers GLM-4 Lite modeling file.
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteMLP,
        )

        self.num_experts = config.n_routed_experts

        # Build empty per-expert MLPs; their weights are overwritten below,
        # so skip the default (and wasteful) random initialization.
        with skip_weights_initialize():
            super().__init__(
                Glm4MoeLiteMLP(
                    config, intermediate_size=config.moe_intermediate_size
                )
                for _ in range(self.num_experts)
            )

        # Copy each expert's slice out of the fused 3D parameters.
        # gate_up_proj stacks [gate; up] along dim 0 of each expert slice.
        for expert, fused_gate_up, fused_down in zip(
            self, original.gate_up_proj.data, original.down_proj.data
        ):
            gate_weight, up_weight = fused_gate_up.chunk(2, dim=0)
            expert.gate_proj.weight.data = gate_weight.contiguous()
            expert.up_proj.weight.data = up_weight.contiguous()
            expert.down_proj.weight.data = fused_down.contiguous()
115 changes: 115 additions & 0 deletions tests/llmcompressor/modeling/test_calib_glm4_moe_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from functools import partial

import pytest
import torch

_glm_cfg = pytest.importorskip(
"transformers.models.glm4_moe_lite.configuration_glm4_moe_lite",
reason="glm4_moe_lite requires transformers >= 5.x",
)
_glm_mod = pytest.importorskip(
"transformers.models.glm4_moe_lite.modeling_glm4_moe_lite",
reason="glm4_moe_lite requires transformers >= 5.x",
)
Glm4MoeLiteConfig = _glm_cfg.Glm4MoeLiteConfig
Glm4MoeLiteMoE = _glm_mod.Glm4MoeLiteMoE

from llmcompressor.modeling.glm4_moe_lite import CalibrationGlm4MoeLiteMoE # noqa: E402
from llmcompressor.utils.helpers import calibration_forward_context # noqa: E402
from tests.testing_utils import requires_gpu # noqa: E402


def _tiny_config():
    """Small config for fast unit tests (8 experts instead of 64)."""
    overrides = dict(
        hidden_size=64,
        intermediate_size=128,
        moe_intermediate_size=32,
        n_routed_experts=8,
        num_experts_per_tok=2,
        n_shared_experts=1,
        n_group=1,
        topk_group=1,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        vocab_size=256,
        pad_token_id=0,
        eos_token_id=[0],
    )
    return Glm4MoeLiteConfig(**overrides)


@requires_gpu
def test_calib_glm4_moe_lite_all_experts_triggered():
    """With calibrate_all_experts=True, every routed expert must run."""
    config = _tiny_config()
    with torch.device("cuda"):
        original = Glm4MoeLiteMoE(config)
        for param in original.parameters():
            param.data.normal_(mean=0.0, std=0.02)

    module = CalibrationGlm4MoeLiteMoE(original, config, calibrate_all_experts=True)

    expert_triggered = [False] * len(module.experts)

    # Forward hooks record which routed experts actually executed.
    def mark_triggered(idx, _module, _inputs, _output):
        expert_triggered[idx] = True

    for idx, expert in enumerate(module.experts):
        expert.register_forward_hook(partial(mark_triggered, idx))

    batch, seq_len = 4, 32
    sample = torch.randn(batch, seq_len, config.hidden_size, device="cuda")

    with calibration_forward_context(module), torch.no_grad():
        _ = module(sample)

    assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"


@requires_gpu
def test_calib_glm4_moe_lite_output_matches():
    """Calibration wrapper output must closely match the stock module."""
    config = _tiny_config()
    with torch.device("cuda"):
        original = Glm4MoeLiteMoE(config)
        for param in original.parameters():
            param.data.normal_(mean=0.0, std=0.02)

    batch, seq_len = 4, 32
    sample = torch.randn(batch, seq_len, config.hidden_size, device="cuda")

    with calibration_forward_context(original):
        true_out = original(sample)

    # Both calibration modes must reproduce the reference output.
    for calibrate_all in (True, False):
        module = CalibrationGlm4MoeLiteMoE(
            original, config, calibrate_all_experts=calibrate_all
        )
        with calibration_forward_context(module):
            out = module(sample)
        assert torch.nn.functional.mse_loss(true_out, out) < 0.1


@requires_gpu
def test_calib_glm4_moe_lite_experts_are_linear():
    """Verify that unpacked routed experts expose Linear modules for quantization."""
    config = _tiny_config()
    with torch.device("cuda"):
        original = Glm4MoeLiteMoE(config)

    module = CalibrationGlm4MoeLiteMoE(original, config, calibrate_all_experts=True)

    linear_names = [
        name
        for name, submodule in module.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    ]
    # 3 projections (gate/up/down) per routed expert, plus 3 for the
    # shared-expert MLP.
    expected_total = config.n_routed_experts * 3 + 3
    assert len(linear_names) == expected_total, (
        f"Expected {expected_total} Linear modules, "
        f"found {len(linear_names)}: {linear_names}"
    )
Loading