@@ -1178,3 +1178,38 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+def remap_expert_weight_name(
+    name: str,
+    weight_name: str,
+    param_name: str,
+) -> str:
+    """Remap expert weight names, handling base_layer prefix for LoRA.
+
+    When loading expert weights, this function maps from checkpoint weight
+    names to model parameter names. It handles the special case where
+    LoRA wraps the original layer with a `base_layer` prefix.
+
+    For example:
+    - Input: name="model.layers.0.mlp.experts.0.up_proj.base_layer.weight"
+             weight_name="experts.0.up_proj."
+             param_name="experts.w13_"
+    - Output: "model.layers.0.mlp.experts.base_layer.w13_weight"
+
+    Args:
+        name: The full checkpoint weight name.
+        weight_name: The weight name pattern to match (e.g., "experts.0.up_proj.").
+        param_name: The parameter name to substitute (e.g., "experts.w13_").
+
+    Returns:
+        The remapped weight name with proper base_layer handling.
+    """
+    prefix, _, suffix = name.partition(weight_name)
+    middle = param_name
+    base = "base_layer"
+    if suffix.startswith(f"{base}."):
+        param_list = param_name.split(".", 1)
+        param_list.insert(1, base)
+        middle = ".".join(param_list)
+    return prefix + middle + suffix.removeprefix(f"{base}.")
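A minimal usage sketch of the new helper, reusing the inputs from its docstring example plus a plain non-LoRA name; the import path is not shown in this hunk and is left as an assumption.

```python
# Assumed import: remap_expert_weight_name comes from the module patched above.
# from <module patched in this diff> import remap_expert_weight_name

# LoRA case: "base_layer." directly follows the matched pattern, so the helper
# moves it between the experts module path and the fused parameter name.
remapped = remap_expert_weight_name(
    name="model.layers.0.mlp.experts.0.up_proj.base_layer.weight",
    weight_name="experts.0.up_proj.",
    param_name="experts.w13_",
)
assert remapped == "model.layers.0.mlp.experts.base_layer.w13_weight"

# Non-LoRA case: no base_layer prefix, so the pattern is replaced directly.
remapped = remap_expert_weight_name(
    name="model.layers.0.mlp.experts.0.up_proj.weight",
    weight_name="experts.0.up_proj.",
    param_name="experts.w13_",
)
assert remapped == "model.layers.0.mlp.experts.w13_weight"
```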