-
Notifications
You must be signed in to change notification settings - Fork 473
Expand file tree
/
Copy pathglm4_moe_lite.py
More file actions
169 lines (142 loc) · 6.48 KB
/
glm4_moe_lite.py
File metadata and controls
169 lines (142 loc) · 6.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize
if TYPE_CHECKING:
from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import (
Glm4MoeLiteConfig,
)
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
Glm4MoeLiteMoE,
Glm4MoeLiteNaiveMoe,
)
@MoECalibrationModule.register("Glm4MoeLiteMoE")
class CalibrationGlm4MoeLiteMoE(MoECalibrationModule):
    """
    Drop-in calibration replacement for `Glm4MoeLiteMoE`.

    The original block keeps its routed experts in a `Glm4MoeLiteNaiveMoe`
    module as fused 3D parameters (`gate_up_proj`, `down_proj`) rather than
    `nn.Linear` submodules. llm-compressor targets `Linear` modules, so those
    experts would be skipped by quantization and stay in BF16. This class
    swaps them for one MLP per expert (see `SequentialGlm4MoeLiteExperts`)
    and runs them one after another in `forward`.

    `is_permanent = True`: the unpacked structure is kept for quantization and
    checkpoint save, and `restore` returns this module unchanged.
    """

    is_permanent = True

    def __init__(
        self,
        original: Glm4MoeLiteMoE,
        config: Glm4MoeLiteConfig,
        calibrate_all_experts: bool = True,
        num_calibrate_experts: int | None = None,
    ):
        """
        Build the calibration module from an existing MoE block.

        :param original: MoE block whose `experts`, `gate` and `shared_experts`
            are reused (the experts are unpacked, the other two are shared
            as-is)
        :param config: model config supplying the routing hyperparameters
        :param calibrate_all_experts: when True, `forward` feeds every token
            to every expert, while only routed tokens contribute to the output
        :param num_calibrate_experts: accepted by the signature but not read
            anywhere in this class
        """
        super().__init__()

        # Routing hyperparameters copied from the config.
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.n_routed_experts
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor

        # Submodules: unpacked routed experts plus the original router and
        # shared experts.
        self.experts = SequentialGlm4MoeLiteExperts(config, original.experts)
        self.gate = original.gate
        self.shared_experts = original.shared_experts

        self.calibrate_all_experts = calibrate_all_experts

    def route_tokens_to_experts(
        self, router_logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Group-limited top-k routing, mirroring the Hugging Face
        GLM-4.7-Flash Lite router.

        :param router_logits: raw gate outputs; treated as 2D with
            `n_routed_experts` entries along the last dimension
        :return: `(topk_indices, topk_weights)`, the selected expert ids per
            token and their routing weights (sigmoid scores, optionally
            normalized, then multiplied by `routed_scaling_factor`)
        """
        experts_per_group = self.n_routed_experts // self.n_group

        probs = router_logits.sigmoid()
        # The correction bias only influences which experts are chosen; the
        # returned weights are gathered from the unbiased scores.
        biased_probs = probs + self.gate.e_score_correction_bias

        # A group's score is the sum of its two highest biased expert scores.
        per_group = biased_probs.view(-1, self.n_group, experts_per_group)
        best_two, _ = per_group.topk(2, dim=-1)
        group_scores = best_two.sum(dim=-1)

        # Mark the `topk_group` best groups and broadcast that mark to every
        # expert inside them.
        _, group_idx = torch.topk(
            group_scores, k=self.topk_group, dim=-1, sorted=False
        )
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = group_mask.unsqueeze(-1)
        score_mask = score_mask.expand(-1, self.n_group, experts_per_group)
        score_mask = score_mask.reshape(-1, self.n_routed_experts)

        # Experts in unselected groups have their score zeroed before the
        # per-token top-k.
        scores_for_choice = biased_probs.masked_fill(~score_mask.bool(), 0.0)
        _, topk_indices = torch.topk(
            scores_for_choice, k=self.top_k, dim=-1, sorted=False
        )

        topk_weights = probs.gather(1, topk_indices)
        if self.norm_topk_prob:
            # Small epsilon guards against dividing by an all-zero row.
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        topk_weights = topk_weights * self.routed_scaling_factor

        return topk_indices, topk_weights

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply the routed experts one by one, then add the shared experts.

        :param hidden_states: input activations; flattened to 2D over the last
            dimension for expert dispatch and restored to the input shape
            afterwards
        :return: routed expert output plus `shared_experts(input)`, in the
            input's shape
        """
        input_shape = hidden_states.shape
        shared_input = hidden_states

        topk_indices, topk_weights = self.route_tokens_to_experts(
            self.gate(hidden_states)
        )
        flat_states = hidden_states.view(-1, input_shape[-1])

        # Accumulate in the routing-weight dtype; cast back at the end.
        accumulated = torch.zeros_like(flat_states, dtype=topk_weights.dtype)

        with torch.no_grad():
            # Permuted so that indexing by expert yields a mask whose nonzero
            # coordinates are (top-k slot, token).
            assignment = torch.nn.functional.one_hot(
                topk_indices, num_classes=self.num_experts
            ).permute(2, 1, 0)

        # Experts are visited sequentially so each routed MLP runs as its own
        # set of Linear modules.
        for expert_idx in range(self.num_experts):
            slot_idx, token_idx = torch.where(assignment[expert_idx])
            routed_any = token_idx.numel() > 0

            if self.calibrate_all_experts:
                # Every expert sees every token, even when nothing was routed
                # to it; only the routed rows are kept for the output.
                full_output = self.experts[expert_idx](flat_states)
                if not routed_any:
                    continue
                expert_out = full_output[token_idx]
            elif routed_any:
                expert_out = self.experts[expert_idx](flat_states[token_idx])
            else:
                continue

            scale = topk_weights[token_idx, slot_idx, None]
            accumulated.index_add_(
                0, token_idx, (expert_out * scale).to(accumulated.dtype)
            )

        routed_output = accumulated.type(flat_states.dtype).view(*input_shape)
        return routed_output + self.shared_experts(shared_input)

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """
        Return this module itself, so the unpacked experts remain in place
        for quantization and checkpoint save; `original` is ignored.
        """
        return self
class SequentialGlm4MoeLiteExperts(torch.nn.ModuleList):
    """
    List of per-expert `Glm4MoeLiteMLP` modules built from fused 3D expert
    parameters, giving every routed expert ordinary `nn.Linear` projections
    that `targets="Linear"` can match.
    """

    def __init__(self, config: Glm4MoeLiteConfig, original: Glm4MoeLiteNaiveMoe):
        """
        Create one MLP per routed expert and load its weights from `original`.

        :param config: model config; supplies `n_routed_experts` and
            `moe_intermediate_size`
        :param original: fused experts module exposing `gate_up_proj` and
            `down_proj`, each indexed by expert along dimension 0
        """
        # Imported here rather than at module level, where it is only needed
        # for type checking.
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteMLP,
        )

        self.num_experts = config.n_routed_experts
        mlp_width = config.moe_intermediate_size

        # Weight initialization is skipped because every projection is
        # overwritten from the fused tensors below.
        with skip_weights_initialize():
            expert_mlps = [
                Glm4MoeLiteMLP(config, intermediate_size=mlp_width)
                for _ in range(self.num_experts)
            ]
            super().__init__(expert_mlps)

        fused_gate_up = original.gate_up_proj.data
        fused_down = original.down_proj.data

        for expert_idx, mlp in enumerate(self):
            # Each expert's fused tensor is split in half along dim 0: the
            # first half feeds gate_proj, the second half up_proj.
            gate_weight, up_weight = fused_gate_up[expert_idx].chunk(2, dim=0)

            # NOTE(review): `.contiguous()` returns the same tensor when the
            # slice is already contiguous, so these weights may share storage
            # with the fused parameters - confirm that aliasing is intended.
            mlp.gate_proj.weight.data = gate_weight.contiguous()
            mlp.up_proj.weight.data = up_weight.contiguous()
            mlp.down_proj.weight.data = fused_down[expert_idx].contiguous()