2121 get_fused_names ,
2222 is_microscale_scheme ,
2323)
24+ from loguru import logger
2425
2526__all__ = [
2627 "validate_file" ,
@@ -82,6 +83,7 @@ def process_file(
8283 assert not is_microscale_scheme (scheme ), "Use `process_file_microscale_scheme`"
8384
8485 tensors = _load_tensors_from_inverse_weights_map (inverse_weights_map , device )
86+ tensors = split_fused_moe_experts (tensors )
8587
8688 if converter is not None :
8789 converter .process (tensors )
@@ -253,3 +255,72 @@ def _load_tensors_from_inverse_weights_map(
253255 )
254256 tensors [tensor_name ] = f .get_tensor (tensor_name )
255257 return tensors
258+
259+
def split_fused_moe_experts(
    tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """
    Split fused MoE expert weights into per-expert 2D tensors.

    Fused checkpoints stack expert weights as 3D tensors along the expert
    dimension:

      * ``experts.gate_up_proj``: ``[num_experts, 2*intermediate, hidden]``
        (gate rows assumed concatenated before up rows, not interleaved —
        TODO confirm against the source checkpoint layout)
      * ``experts.down_proj``: ``[num_experts, hidden, intermediate]``

    These entries are rewritten to the unfused per-expert naming scheme
    ``experts.{i}.gate_proj.weight`` / ``experts.{i}.up_proj.weight`` /
    ``experts.{i}.down_proj.weight``. All other tensors pass through
    unchanged.

    Args:
        tensors: Dictionary of loaded tensors from safetensors file

    Returns:
        split_tensors: New dictionary with split expert weights; non-fused
            entries are carried over as-is.
    """
    split_tensors: dict[str, torch.Tensor] = {}

    for name, tensor in tensors.items():
        # Check if this is a MoE expert weight (3D tensor for experts)
        is_fused_expert = tensor.ndim == 3 and (
            "experts.gate_up_proj" in name or "experts.down_proj" in name
        )
        if not is_fused_expert:
            # Non-MoE or non-3D tensors, keep as is
            split_tensors[name] = tensor
            continue

        num_experts = tensor.shape[0]

        if "gate_up_proj" in name:
            # gate_up_proj fuses gate and up along dim 1, so it must be even.
            if tensor.shape[1] % 2 != 0:
                logger.warning(
                    f"gate_up_proj {name} has odd second dimension: {tensor.shape}"
                )
                # Keep the tensor instead of silently dropping it so the
                # returned dict never loses weights.
                split_tensors[name] = tensor
                continue

            intermediate_size = tensor.shape[1] // 2
            for expert_idx in range(num_experts):
                expert_tensor = tensor[expert_idx]  # [2*intermediate, hidden]
                # Split gate and up projections
                gate_proj, up_proj = expert_tensor.split(intermediate_size, dim=0)
                # Replace on the same substring the detection above matched,
                # so names without an "mlp." prefix are still renamed
                # (otherwise every expert would overwrite a single key).
                base_key = name.replace(
                    "experts.gate_up_proj", f"experts.{expert_idx}"
                )
                split_tensors[base_key + ".gate_proj.weight"] = gate_proj
                split_tensors[base_key + ".up_proj.weight"] = up_proj

            logger.info(f"Split {name} into {num_experts} experts")
        else:
            # down_proj: [num_experts, hidden, intermediate] -> one
            # 2D [hidden, intermediate] tensor per expert.
            for expert_idx in range(num_experts):
                new_key = (
                    name.replace("experts.down_proj", f"experts.{expert_idx}")
                    + ".down_proj.weight"
                )
                split_tensors[new_key] = tensor[expert_idx]

            logger.info(f"Split {name} into {num_experts} experts")

    return split_tensors
0 commit comments