2020 get_fused_names ,
2121 is_microscale_scheme ,
2222)
23+ from loguru import logger
2324
2425__all__ = [
2526 "validate_file" ,
@@ -77,6 +78,7 @@ def process_file(
7778 """
7879 assert not is_microscale_scheme (scheme ), "Use `_process_file_microscale_scheme`"
7980 tensors = load_file (file_path )
81+ tensors = split_fused_moe_experts (tensors )
8082
8183 if converter is not None :
8284 converter .process (tensors )
@@ -194,3 +196,59 @@ def process_file_microscale_scheme(
194196 total_size = sum (tensor .nbytes for tensor in tensors .values ())
195197 weight_map = {key : os .path .basename (save_path ) for key in tensors .keys ()}
196198 return total_size , weight_map
199+
200+
def split_fused_moe_experts(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Split fused MoE expert weights into per-expert 2D tensors.

    Fused checkpoints store expert weights as a single 3D tensor
    (``experts.gate_up_proj`` / ``experts.down_proj``) with the expert index
    as the leading dimension. This rewrites them into the per-expert layout
    ``experts.{i}.gate_proj.weight`` / ``experts.{i}.up_proj.weight`` /
    ``experts.{i}.down_proj.weight``.

    Args:
        tensors: Mapping of tensor names to tensors as loaded from a
            safetensors file.

    Returns:
        New mapping with fused expert tensors replaced by split per-expert
        tensors; all other entries are passed through unchanged.
    """
    split_tensors: dict[str, torch.Tensor] = {}

    for name, tensor in tensors.items():
        # Only 3D tensors under an `experts.` prefix are fused expert weights.
        is_fused = tensor.ndim == 3 and (
            "experts.gate_up_proj" in name or "experts.down_proj" in name
        )
        if not is_fused:
            # Non-MoE (or already-split) tensors are kept as-is.
            split_tensors[name] = tensor
            continue

        num_experts = tensor.shape[0]

        if "gate_up_proj" in name:
            # gate_up_proj is [num_experts, 2 * intermediate, hidden]:
            # gate and up projections are stacked along dim 1.
            if tensor.shape[1] % 2 != 0:
                # Cannot halve an odd dimension; keep the tensor untouched
                # instead of silently dropping it from the output.
                logger.warning(f"gate_up_proj {name} has odd second dimension: {tensor.shape}")
                split_tensors[name] = tensor
                continue

            intermediate_size = tensor.shape[1] // 2
            for expert_idx in range(num_experts):
                expert_tensor = tensor[expert_idx]  # [2 * intermediate, hidden]
                gate_proj, up_proj = expert_tensor.split(intermediate_size, dim=0)
                # Use the same substring as the detection above so names
                # without an `mlp.` prefix are rewritten correctly too.
                base_key = name.replace("experts.gate_up_proj", f"experts.{expert_idx}")
                split_tensors[base_key + ".gate_proj.weight"] = gate_proj
                split_tensors[base_key + ".up_proj.weight"] = up_proj

            logger.info(f"Split {name} into {num_experts} experts")
        else:
            # down_proj is [num_experts, hidden, intermediate].
            for expert_idx in range(num_experts):
                down_proj = tensor[expert_idx]  # [hidden, intermediate]
                new_key = name.replace("experts.down_proj", f"experts.{expert_idx}") + ".down_proj.weight"
                split_tensors[new_key] = down_proj

            logger.info(f"Split {name} into {num_experts} experts")

    return split_tensors
0 commit comments