From 85cc95fa59f24a9a8324c015e740b76e7d507290 Mon Sep 17 00:00:00 2001
From: Li Wei
Date: Wed, 11 Mar 2026 13:02:12 +0800
Subject: [PATCH] [dev] split fused MoE experts to ensure quantization

Fused MoE expert weights are stored as stacked 3D tensors
(gate_up_proj/down_proj). Split them into individual 2D per-expert
gate_proj/up_proj/down_proj weights at load time so that each expert is
quantized like any other Linear weight.

Signed-off-by: Li Wei
---
 examples/model_free_ptq/qwen3.5_int8.py | 24 ++++++++
 .../entrypoints/model_free/process.py    | 69 ++++++++++++++++++++
 2 files changed, 93 insertions(+)
 create mode 100644 examples/model_free_ptq/qwen3.5_int8.py

diff --git a/examples/model_free_ptq/qwen3.5_int8.py b/examples/model_free_ptq/qwen3.5_int8.py
new file mode 100644
index 0000000000..29703b7a64
--- /dev/null
+++ b/examples/model_free_ptq/qwen3.5_int8.py
@@ -0,0 +1,24 @@
+from llmcompressor import model_free_ptq
+
+MODEL_ID = "Qwen/Qwen3.5-35B-A3B"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-INT8"
+
+# Apply W8A8 (INT8) quantization to the model.
+# Once quantized, the model is saved in
+# compressed-tensors format to SAVE_DIR.
+model_free_ptq(
+    model_stub=MODEL_ID,
+    save_directory=SAVE_DIR,
+    scheme="W8A8",
+    ignore=[
+        "lm_head",
+        "re:.*mlp.gate$",
+        "re:.*mlp.shared_expert_gate.*",
+        "re:.*norm.*",
+        "re:.*embed_tokens.*",
+        "re:.*visual.*",
+        "re:.*conv1d.*",
+    ],
+    max_workers=15,
+    device="cuda:0",
+)
diff --git a/src/llmcompressor/entrypoints/model_free/process.py b/src/llmcompressor/entrypoints/model_free/process.py
index 44835c7f81..a6c4e3709b 100644
--- a/src/llmcompressor/entrypoints/model_free/process.py
+++ b/src/llmcompressor/entrypoints/model_free/process.py
@@ -77,6 +77,7 @@ def process_file(
     """
     assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
     tensors = load_file(file_path)
+    tensors = split_fused_moe_experts(tensors)
 
     if converter is not None:
         converter.process(tensors)
@@ -194,3 +195,71 @@ def process_file_microscale_scheme(
     total_size = sum(tensor.nbytes for tensor in tensors.values())
     weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
     return total_size, weight_map
+
+
+def split_fused_moe_experts(
+    tensors: dict[str, torch.Tensor],
+) -> dict[str, torch.Tensor]:
+    """
+    Find fused MoE expert weights (gate_up_proj/down_proj) and split each
+    stacked 3D tensor into individual 2D per-expert tensors.
+
+    Args:
+        tensors: Dictionary of tensors loaded from a safetensors file
+
+    Returns:
+        New dictionary with fused expert weights split per expert
+    """
+    _tensors = {}
+
+    for name, tensor in tensors.items():
+        # fused MoE expert weights are stacked 3D tensors: [num_experts, ...]
+        is_fused_expert = tensor.ndim == 3 and (
+            "experts.gate_up_proj" in name or "experts.down_proj" in name
+        )
+        if not is_fused_expert:
+            # non-MoE or non-3D tensor, keep as-is
+            _tensors[name] = tensor
+            continue
+
+        num_experts = tensor.shape[0]
+
+        if "gate_up_proj" in name:
+            # gate_up_proj is typically [num_experts, 2 * intermediate, hidden]
+            if tensor.shape[1] % 2 != 0:
+                print(
+                    f"Warning: gate_up_proj {name} has odd second dimension: "
+                    f"{tensor.shape}"
+                )
+                # keep the tensor as-is rather than silently dropping it
+                _tensors[name] = tensor
+                continue
+
+            intermediate_size = tensor.shape[1] // 2
+
+            for expert_idx in range(num_experts):
+                expert_tensor = tensor[expert_idx]  # [2 * intermediate, hidden]
+
+                # split the fused projection into its gate and up halves
+                gate_proj = expert_tensor[:intermediate_size, :]
+                up_proj = expert_tensor[intermediate_size:, :]
+
+                base_key = name.replace(
+                    "mlp.experts.gate_up_proj", f"mlp.experts.{expert_idx}"
+                )
+                _tensors[base_key + ".gate_proj.weight"] = gate_proj
+                _tensors[base_key + ".up_proj.weight"] = up_proj
+
+        else:  # down_proj is typically [num_experts, hidden, intermediate]
+            for expert_idx in range(num_experts):
+                down_proj = tensor[expert_idx]  # [hidden, intermediate]
+
+                new_key = (
+                    name.replace("mlp.experts.down_proj", f"mlp.experts.{expert_idx}")
+                    + ".down_proj.weight"
+                )
+                _tensors[new_key] = down_proj
+
+        print(f"Split {name} into {num_experts} experts")
+
+    return _tensors
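
A minimal sanity-check sketch for `split_fused_moe_experts`, run against
synthetic tensors. The shapes and the `model.layers.0.mlp.experts.*` key
names are illustrative assumptions rather than values taken from an actual
Qwen3.5 checkpoint, and the import path simply mirrors the file touched by
this patch:

    import torch

    from llmcompressor.entrypoints.model_free.process import (
        split_fused_moe_experts,
    )

    # hypothetical fused entries, using the layout the patch assumes:
    #   gate_up_proj: [num_experts, 2 * intermediate, hidden]
    #   down_proj:    [num_experts, hidden, intermediate]
    num_experts, intermediate, hidden = 4, 8, 16
    tensors = {
        "model.layers.0.mlp.experts.gate_up_proj": torch.randn(
            num_experts, 2 * intermediate, hidden
        ),
        "model.layers.0.mlp.experts.down_proj": torch.randn(
            num_experts, hidden, intermediate
        ),
        "model.layers.0.input_layernorm.weight": torch.ones(hidden),
    }

    split = split_fused_moe_experts(tensors)

    # each expert now carries standalone 2D Linear-style weights
    assert split["model.layers.0.mlp.experts.0.gate_proj.weight"].shape == (
        intermediate,
        hidden,
    )
    assert split["model.layers.0.mlp.experts.3.down_proj.weight"].shape == (
        hidden,
        intermediate,
    )
    # non-expert tensors pass through unchanged
    assert "model.layers.0.input_layernorm.weight" in split

The renamed keys (`mlp.experts.{i}.gate_proj.weight`, etc.) match per-expert
Linear module names, which should let the quantization pipeline treat each
expert like any other 2D Linear weight.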
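
For completeness, the split is lossless under the assumed layout:
re-stacking the per-expert halves reconstructs the original fused tensor
(continuing from the sketch above).

    fused = tensors["model.layers.0.mlp.experts.gate_up_proj"]
    rebuilt = torch.stack(
        [
            torch.cat(
                [
                    split[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"],
                    split[f"model.layers.0.mlp.experts.{i}.up_proj.weight"],
                ],
                dim=0,
            )
            for i in range(num_experts)
        ]
    )
    assert torch.equal(fused, rebuilt)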