Skip to content

Commit f6d355d

Browse files
committed
[dev] split fused moe experts to ensure quantization
Signed-off-by: Li Wei <liwei.109@outlook.com>
1 parent 370c04c commit f6d355d

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from llmcompressor import model_free_ptq

MODEL_ID = "Qwen/Qwen3.5-35B-A3B"

# Output directory is derived from the model's base name.
SAVE_DIR = MODEL_ID.rstrip("/").rpartition("/")[2] + "-W8A8-INT8"

# Apply the W8A8 scheme to the model; once quantized, the model is
# saved using compressed-tensors to SAVE_DIR.
model_free_ptq(
    model_stub=MODEL_ID,
    save_directory=SAVE_DIR,
    scheme="W8A8",
    # Modules excluded from quantization: output head, MoE router gates,
    # norms, embeddings, vision tower, and conv layers.
    ignore=[
        "lm_head",
        "re:.*mlp.gate$",
        "re:.*mlp.shared_expert_gate.*",
        "re:.*norm.*",
        "re:.*embed_tokens.*",
        "re:.*visual.*",
        "re:.*conv1d.*",
    ],
    max_workers=15,
    device="cuda:0",
)

src/llmcompressor/entrypoints/model_free/process.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
get_fused_names,
2121
is_microscale_scheme,
2222
)
23+
from loguru import logger
2324

2425
__all__ = [
2526
"validate_file",
@@ -77,6 +78,7 @@ def process_file(
7778
"""
7879
assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
7980
tensors = load_file(file_path)
81+
tensors = split_fused_moe_experts(tensors)
8082

8183
if converter is not None:
8284
converter.process(tensors)
@@ -194,3 +196,59 @@ def process_file_microscale_scheme(
194196
total_size = sum(tensor.nbytes for tensor in tensors.values())
195197
weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
196198
return total_size, weight_map
199+
200+
201+
def split_fused_moe_experts(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Split fused MoE expert weights into individual per-expert tensors.

    Fused checkpoints store all experts of a layer in one 3D tensor
    (``experts.gate_up_proj`` / ``experts.down_proj``). Quantization works on
    2D linear weights, so each expert is sliced out and renamed to
    ``experts.{idx}.gate_proj/up_proj/down_proj.weight``.

    Args:
        tensors: Dictionary of loaded tensors from a safetensors file

    Returns:
        New dictionary with fused expert weights replaced by per-expert 2D
        tensors; all other entries are passed through unchanged.
    """
    split_tensors: dict[str, torch.Tensor] = {}

    for name, tensor in tensors.items():
        # Fused MoE expert weights are 3D, stacked along dim 0 (one slice
        # per expert).
        is_fused = tensor.ndim == 3 and (
            "experts.gate_up_proj" in name or "experts.down_proj" in name
        )
        if not is_fused:
            # Non-MoE or non-3D tensors, keep as is
            split_tensors[name] = tensor
            continue

        num_experts = tensor.shape[0]

        if "gate_up_proj" in name:
            # Assumed layout: [num_experts, 2*intermediate, hidden] — gate
            # and up projections concatenated along dim 0 of each expert
            # slice. TODO(review): confirm against the checkpoint layout;
            # some models store the transposed [num_experts, hidden, 2*i].
            if tensor.shape[1] % 2 != 0:
                # Fix: was logger.info with a "Warning:" prefix, and the
                # bare `continue` dropped the tensor from the output dict,
                # silently losing the weight. Warn properly and pass the
                # tensor through untouched instead.
                logger.warning(
                    f"gate_up_proj {name} has odd second dimension: {tensor.shape}"
                )
                split_tensors[name] = tensor
                continue

            intermediate_size = tensor.shape[1] // 2
            for expert_idx in range(num_experts):
                expert_tensor = tensor[expert_idx]  # [2*intermediate, hidden]
                # Split into the gate and up projections.
                gate_proj, up_proj = expert_tensor.split(intermediate_size, dim=0)
                base_key = name.replace(
                    "mlp.experts.gate_up_proj", f"mlp.experts.{expert_idx}"
                )
                split_tensors[base_key + ".gate_proj.weight"] = gate_proj
                split_tensors[base_key + ".up_proj.weight"] = up_proj

            logger.info(f"Split {name} into {num_experts} experts")

        else:
            # down_proj assumed layout: [num_experts, hidden, intermediate]
            for expert_idx in range(num_experts):
                new_key = (
                    name.replace("mlp.experts.down_proj", f"mlp.experts.{expert_idx}")
                    + ".down_proj.weight"
                )
                split_tensors[new_key] = tensor[expert_idx]  # [hidden, intermediate]

            logger.info(f"Split {name} into {num_experts} experts")

    return split_tensors

0 commit comments

Comments
 (0)