forked from vllm-project/llm-compressor
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathqwen3_5_moe.py
More file actions
145 lines (116 loc) · 5.3 KB
/
qwen3_5_moe.py
File metadata and controls
145 lines (116 loc) · 5.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize
if TYPE_CHECKING:
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeSparseMoeBlock,
)
@MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock")
class CalibrationQwen3_5MoeSparseMoeBlock(MoECalibrationModule):
    """
    Calibration replacement for ``Qwen3_5MoeSparseMoeBlock``.

    Unfuses the block's 3D expert parameters into per-expert MLP modules
    (plain ``nn.Linear`` layers) so each expert can be quantized
    individually.  During calibration, every token can be pushed through
    every expert (``calibrate_all_experts=True``) so all experts observe
    activations.

    ``is_permanent = True``: the unfused layout must survive past
    calibration so quantization can target the individual Linear weights.
    """

    is_permanent = True

    def __init__(
        self,
        original: Qwen3_5MoeSparseMoeBlock,
        config,
        calibrate_all_experts: bool = True,
    ):
        """
        :param original: the fused MoE block being replaced
        :param config: model config (or a wrapper exposing ``text_config``)
        :param calibrate_all_experts: if True, run every token through
            every expert during calibration (routing still decides which
            outputs are kept)
        """
        super().__init__()
        # Multimodal configs nest the relevant fields under ``text_config``.
        text_config = getattr(config, "text_config", config)
        self.num_experts = text_config.num_experts
        self.top_k = text_config.num_experts_per_tok
        self.hidden_size = text_config.hidden_size
        self.calibrate_all_experts = calibrate_all_experts
        # Router and shared-expert submodules are reused as-is.
        self.gate = original.gate
        self.shared_expert = original.shared_expert
        self.shared_expert_gate = original.shared_expert_gate
        # Only the fused experts are replaced by unfused per-expert MLPs.
        self.experts = SequentialQwen3_5MoeExperts(text_config, original.experts)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Route tokens through the (unfused) experts plus the shared expert.

        :param hidden_states: ``(batch, seq_len, hidden_dim)`` activations
        :return: tensor of the same shape
        """
        batch, seq_len, dim = hidden_states.shape
        flat = hidden_states.view(-1, dim)

        # Router returns (router_logits, router_scores, router_indices);
        # only the scores and the selected expert indices are needed here.
        _, scores, indices = self.gate(flat)

        # Routing mask laid out as (num_experts, top_k, num_tokens).
        mask = F.one_hot(indices, num_classes=self.num_experts).permute(2, 1, 0)

        output = torch.zeros(
            (batch * seq_len, dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        for expert_id, expert in enumerate(self.experts):
            slot_idx, tok_idx = torch.where(mask[expert_id])
            if self.calibrate_all_experts:
                # Feed every token through the expert so calibration sees
                # full activation statistics; keep only the routed outputs.
                expert_out = expert(flat)[tok_idx]
            else:
                # Normal sparse execution: only routed tokens hit the expert.
                expert_out = expert(flat[tok_idx])
            if tok_idx.numel() > 0:
                weighted = expert_out * scores[tok_idx, slot_idx, None]
                output.index_add_(0, tok_idx, weighted.to(hidden_states.dtype))

        # The shared expert runs on all tokens, scaled by a sigmoid gate.
        shared = self.shared_expert(flat)
        shared = F.sigmoid(self.shared_expert_gate(flat)) * shared
        output = output + shared

        return output.reshape(batch, seq_len, dim)

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        # Permanent module: keep the unfused structure instead of restoring
        # the fused original.
        return self
class SequentialQwen3_5MoeExperts(torch.nn.ModuleList):
    """
    A list of per-expert ``Qwen3_5MoeMLP`` modules built from a fused
    expert container.

    The fused module stores all experts as 3D parameter tensors; this
    class slices those tensors into individual ``nn.Linear`` weights so
    quantization can match them with ``targets="Linear"``.
    """

    def __init__(self, config, original):
        """
        :param config: text config providing ``num_experts`` and
            ``moe_intermediate_size``
        :param original: fused experts module exposing ``gate_up_proj``
            and ``down_proj`` 3D parameters
        """
        # Imported lazily: both are optional/heavy dependencies.
        from compressed_tensors.offload import disable_onloading
        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
            Qwen3_5MoeMLP,
        )

        self.num_experts = config.num_experts
        inter = config.moe_intermediate_size

        # Allocate empty per-expert MLPs; weights are overwritten below,
        # so skip (potentially expensive) random initialization.
        with skip_weights_initialize():
            super().__init__(
                [Qwen3_5MoeMLP(config, intermediate_size=inter) for _ in range(self.num_experts)]
            )

        # Read the fused weights on CPU to avoid GPU OOM:
        # disable_onloading() makes the offload cache hand back the
        # offloaded (CPU) tensors instead of moving them onto the GPU.
        with disable_onloading():
            fused_gate_up = original.gate_up_proj.data  # [num_experts, 2*inter, hidden]
            fused_down = original.down_proj.data  # [num_experts, hidden, inter]
            for expert_id, mlp in enumerate(self):
                gate_up = fused_gate_up[expert_id]  # [2*intermediate, hidden]
                down = fused_down[expert_id]  # [hidden, intermediate]
                # gate_up_proj stacks [gate; up] along dim 0, and
                # nn.Linear weights are [out_features, in_features],
                # so the two halves slice directly into the Linears.
                mlp.gate_proj.weight.data = gate_up[:inter, :].clone().contiguous()
                mlp.up_proj.weight.data = gate_up[inter:, :].clone().contiguous()
                mlp.down_proj.weight.data = down.clone().contiguous()