diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index f0d94bfbcaba..c42966121e61 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1362,6 +1362,10 @@ def weight_loader(
     def load_weights(
         self, weights: Iterable[tuple[str, torch.Tensor]]
     ) -> Iterable[str]:
+        from vllm.model_executor.model_loader.weight_utils import (
+            remap_expert_weight_name,
+        )
+
         if (expert_mapping := self.expert_mapping) is None:
             raise ValueError(
                 "`self.expert_mapping` must be provided to "
@@ -1372,7 +1376,10 @@ def load_weights(
         for param_name, weight_name, expert_id, shard_id in expert_mapping:
             if weight_name not in qual_name:
                 continue
-            weight_name = qual_name.replace(weight_name, param_name)
+            # Remap expert weight name (handles base_layer suffix correctly)
+            weight_name = remap_expert_weight_name(
+                qual_name, weight_name, param_name
+            )
             param_name = weight_name.removeprefix(f"{self.layer_name}.")
             param = getattr(self, param_name)
             success = self.weight_loader(
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 0c5961561a7d..97decc3104a2 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -1178,3 +1178,38 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+def remap_expert_weight_name(
+    name: str,
+    weight_name: str,
+    param_name: str,
+) -> str:
+    """Remap expert weight names, handling base_layer prefix for LoRA.
+
+    When loading expert weights, this function maps from checkpoint weight
+    names to model parameter names. It handles the special case where
+    LoRA wraps the original layer with a `base_layer` prefix.
+
+    For example:
+    - Input: name="model.layers.0.mlp.experts.0.up_proj.base_layer.weight"
+             weight_name="experts.0.up_proj."
+             param_name="experts.w13_"
+    - Output: "model.layers.0.mlp.experts.base_layer.w13_weight"
+
+    Args:
+        name: The full checkpoint weight name.
+        weight_name: The weight name pattern to match (e.g., "experts.0.up_proj.").
+        param_name: The parameter name to substitute (e.g., "experts.w13_").
+
+    Returns:
+        The remapped weight name with proper base_layer handling.
+    """
+    prefix, _, suffix = name.partition(weight_name)
+    middle = param_name
+    base = "base_layer"
+    if suffix.startswith(f"{base}."):
+        param_list = param_name.split(".", 1)
+        param_list.insert(1, base)
+        middle = ".".join(param_list)
+    return prefix + middle + suffix.removeprefix(f"{base}.")
diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py
index f5dfe4306741..8f9359876590 100644
--- a/vllm/model_executor/models/afmoe.py
+++ b/vllm/model_executor/models/afmoe.py
@@ -36,6 +36,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
@@ -533,7 +534,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Do not modify `name` since the loop may continue here
                     # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name_mapped = remap_expert_weight_name(
+                        name, weight_name, param_name
+                    )
 
                     if is_pp_missing_parameter(name_mapped, self):
                         continue
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index 0200984c0ec8..9f8e19f322e9 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -38,7 +38,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -609,7 +612,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            for param_name, weight_name, shard_id in expert_params_mapping:
                if weight_name not in name:
                    continue
-               name = name.replace(weight_name, param_name)
+               # Remap expert weight name (handles base_layer suffix correctly)
+               name = remap_expert_weight_name(name, weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index 4bccee752174..313bab6d99bf 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -55,7 +55,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -524,7 +527,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index db4fe61b0d85..d4fc97bc53a9 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -31,6 +31,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -411,7 +412,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 8f6b4a4b021f..af84f1577368 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -18,6 +18,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2DecoderLayer,
@@ -155,7 +156,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name = remap_expert_weight_name(name, weight_name, param_name)
 
                     param = params_dict[name]
                     weight_loader = param.weight_loader
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index c25e8422da15..dcfcdea03c62 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -22,6 +22,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -359,7 +360,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Do not modify `name` since the loop may continue here
                     # Instead, create a new variable
-                    name_mapped = chunk_name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name_mapped = remap_expert_weight_name(
+                        chunk_name, weight_name, param_name
+                    )
 
                     param = params_dict[name_mapped]
                     # We should ask the weight loader to return success or
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index b22cdb6d6c80..19b617c33aa8 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -72,6 +72,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.platforms import current_platform
@@ -1643,7 +1644,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Do not modify `name` since the loop may continue here
                     # Instead, create a new variable
-                    name_mapped = chunk_name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name_mapped = remap_expert_weight_name(
+                        chunk_name, weight_name, param_name
+                    )
 
                     if is_pp_missing_parameter(name_mapped, self):
                         continue
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index 870a37039f15..f9938dae4a93 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -59,6 +59,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -464,7 +465,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index fbbd31a48538..365307a06aed 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -60,6 +60,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import set_default_rope_theta
@@ -563,7 +564,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Do not modify `name` since the loop may continue here
                     # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name_mapped = remap_expert_weight_name(
+                        name, weight_name, param_name
+                    )
+
                     # Skip layers on other devices.
                     if is_pp_missing_parameter(name_mapped, self):
                         continue
diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py
index 72f9957fc882..0ca92f10877b 100644
--- a/vllm/model_executor/models/ernie45_vl_moe.py
+++ b/vllm/model_executor/models/ernie45_vl_moe.py
@@ -56,6 +56,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import set_default_rope_theta
@@ -736,7 +737,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     moe_offset = int(name.split(".")[-3])
                     is_text_expert = moe_offset <= self.config.moe_num_experts[0] - 1
 
-                    name = name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name = remap_expert_weight_name(name, weight_name, param_name)
                     if is_text_expert:
                         name = name.replace(".experts.", ".text_experts.")
                     else:
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 6fb09be7c67f..b1639755496b 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -59,6 +59,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -554,7 +555,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Do not modify `name` since the loop may continue here
                     # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name_mapped = remap_expert_weight_name(
+                        name, weight_name, param_name
+                    )
 
                     if is_pp_missing_parameter(name_mapped, self):
                         continue
diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py
index e34ae6c85a4f..0f3b494474b5 100644
--- a/vllm/model_executor/models/glm4_moe_mtp.py
+++ b/vllm/model_executor/models/glm4_moe_mtp.py
@@ -38,7 +38,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 
 from .glm4_moe import (
@@ -293,7 +296,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 param = params_dict[name]
                 weight_loader = param.weight_loader
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 0b1064b6343e..5624719a0191 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -56,6 +56,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
@@ -401,7 +402,8 @@ def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]
            param_name, weight_name, expert_id, shard_id = mapping
            if weight_name not in name:
                continue
-           name = name.replace(weight_name, param_name)
+           # Remap expert weight name (handles base_layer suffix correctly)
+           name = remap_expert_weight_name(name, weight_name, param_name)
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 3434716b8378..c037ff501e8c 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -28,7 +28,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 
 from .granitemoe import GraniteMoeMoE
@@ -465,7 +468,8 @@ def _load_quant_expert(name, loaded_weight):
            if weight_name not in name:
                continue
 
-           name_mapped = name.replace(weight_name, param_name)
+           # Remap expert weight name (handles base_layer suffix correctly)
+           name_mapped = remap_expert_weight_name(name, weight_name, param_name)
 
            # Skip layers on other devices.
            if is_pp_missing_parameter(name_mapped, self):
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 0a2e5cf39ffd..6b9da1a2a2e8 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -52,6 +52,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -426,7 +427,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index 0e82e84c4edb..b1764d7da3b4 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -63,6 +63,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -848,7 +849,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
 
                 # Do not modify `name` since the loop may continue here
                 # Instead, create a new variable
-                name_mapped = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name_mapped = remap_expert_weight_name(
+                    name, weight_name, param_name
+                )
+
                 if is_pp_missing_parameter(name_mapped, self):
                     continue
                 param = params_dict[name_mapped]
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index b2ad12be1e35..03fceb14ae52 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -33,7 +33,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
 from vllm.sequence import IntermediateTensors
 
@@ -427,7 +430,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 if is_pp_missing_parameter(name, self):
                     continue
 
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py
index 4562b2202c5e..c454e802c839 100644
--- a/vllm/model_executor/models/kimi_linear.py
+++ b/vllm/model_executor/models/kimi_linear.py
@@ -38,6 +38,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
@@ -609,7 +610,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             ):
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py
index 85267ccda8a9..3112f088df99 100644
--- a/vllm/model_executor/models/kimi_vl.py
+++ b/vllm/model_executor/models/kimi_vl.py
@@ -65,6 +65,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
 from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
@@ -528,7 +529,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             ) in enumerate(expert_params_mapping):
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py
index 70804e0a843e..6bc35ee722fc 100644
--- a/vllm/model_executor/models/lfm2_moe.py
+++ b/vllm/model_executor/models/lfm2_moe.py
@@ -35,7 +35,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Lfm2MoeConfig
 
@@ -536,7 +539,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 if weight_name not in name:
                     continue
 
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 7b3da3e10ab8..c7d2d61a2f21 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -46,6 +46,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.interfaces import MixtureOfExperts
 from vllm.model_executor.models.utils import sequence_parallel_chunk
@@ -465,7 +466,7 @@ def load_moe_expert_weights(
                 continue
 
             # Replace the weight name with the parameter name.
-            full_param_name = name.replace(weight_name, param_name)
+            full_param_name = remap_expert_weight_name(name, weight_name, param_name)
             # Skip if the current weight corresponds to a parameter that
             # does not exist on the current PP (pipeline parallel) rank.
             if is_pp_missing_parameter(full_param_name, self):
diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py
index 774737387639..872118555d38 100644
--- a/vllm/model_executor/models/longcat_flash.py
+++ b/vllm/model_executor/models/longcat_flash.py
@@ -60,7 +60,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention
 from vllm.sequence import IntermediateTensors
 
@@ -676,7 +679,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 if weight_name not in name:
                     continue
                 is_expert_weight = True
-                name_mapped = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name_mapped = remap_expert_weight_name(
+                    name, weight_name, param_name
+                )
                 # Skip mtp
                 if ".mtp." in name_mapped:
                     continue
diff --git a/vllm/model_executor/models/mimo_v2_flash.py b/vllm/model_executor/models/mimo_v2_flash.py
index 12b486f001e0..4651921aec02 100644
--- a/vllm/model_executor/models/mimo_v2_flash.py
+++ b/vllm/model_executor/models/mimo_v2_flash.py
@@ -40,6 +40,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
@@ -555,7 +556,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 if weight_name not in name:
                     continue
 
-                name_rewritten = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name_rewritten = remap_expert_weight_name(name, weight_name, param_name)
 
                 if is_pp_missing_parameter(name_rewritten, self):
                     continue
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index f104018d3aa6..674a517a4de7 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -58,7 +58,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -525,7 +528,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             for param_name, weight_name, expert_id in expert_params_mapping:
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py
index 9f3587a6d2fa..bded5c1bb7d5 100644
--- a/vllm/model_executor/models/minicpm_eagle.py
+++ b/vllm/model_executor/models/minicpm_eagle.py
@@ -40,7 +40,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
@@ -262,7 +265,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             for param_name, weight_name, expert_id in expert_params_mapping:
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
index 822bf9b5c93a..141d8cdb9ae7 100644
--- a/vllm/model_executor/models/minimax_m2.py
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -56,6 +56,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -449,7 +450,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 4bfe3c391c26..3d3e507bba94 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -45,7 +45,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.sequence import IntermediateTensors
 
@@ -806,7 +809,8 @@ def load_sparse_moe_weight(
                continue
            if weight_name not in name:
                continue
-           name = name.replace(weight_name, param_name)
+           # Remap expert weight name (handles base_layer suffix correctly)
+           name = remap_expert_weight_name(name, weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index e170c530ca29..c9daa3d0481f 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -57,6 +57,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -428,7 +429,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     continue
 
                 is_expert_weight = True
-                name_mapped = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name_mapped = remap_expert_weight_name(
+                    name, weight_name, param_name
+                )
 
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name_mapped, self):
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 8bc9ce6154d9..f27cd4b539cd 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -56,6 +56,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.interfaces import (
     HasInnerState,
@@ -700,7 +701,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Do not modify `name` since the loop may continue here
                     # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name_mapped = remap_expert_weight_name(
+                        name, weight_name, param_name
+                    )
 
                     if is_pp_missing_parameter(name_mapped, self):
                         continue
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index a5a926151c5c..0e08fef05686 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -46,7 +46,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -383,7 +386,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py
index 662ecef3ac8f..3c2063289873 100644
--- a/vllm/model_executor/models/openpangu.py
+++ b/vllm/model_executor/models/openpangu.py
@@ -62,6 +62,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.interfaces import (
     MixtureOfExperts,
@@ -1129,7 +1130,10 @@ def load_expert_weight(
            if origin_name not in weight_name:
                continue
            flag_dict["is_expert_weight"] = True
-           weight_name_mapped = weight_name.replace(origin_name, param_name)
+           # Remap expert weight name (handles base_layer suffix correctly)
+           weight_name_mapped = remap_expert_weight_name(
+               weight_name, origin_name, param_name
+           )
            if is_pp_missing_parameter(weight_name_mapped, self):
                continue
            param = params_dict[weight_name_mapped]
diff --git a/vllm/model_executor/models/openpangu_mtp.py b/vllm/model_executor/models/openpangu_mtp.py
index 436b7f981b1f..249867c4d6f7 100644
--- a/vllm/model_executor/models/openpangu_mtp.py
+++ b/vllm/model_executor/models/openpangu_mtp.py
@@ -34,7 +34,10 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMultiTokenPredictor,
     DeepSeekMultiTokenPredictorLayer,
@@ -201,7 +204,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 param = params_dict[name]
                 weight_loader = param.weight_loader
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 14f73d0c6458..d67ec5292dea 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -51,6 +51,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -565,7 +566,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 82837b77e537..e527daaf4da8 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -55,7 +55,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -476,7 +479,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 0be81ecc7dd3..f698a4574add 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -61,6 +61,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
@@ -567,7 +568,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Do not modify `name` since the loop may continue here
                     # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer suffix correctly)
+                    name_mapped = remap_expert_weight_name(
+                        name, weight_name, param_name
+                    )
 
                     if is_pp_missing_parameter(name_mapped, self):
                         continue
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index ccf6cc6e5894..6a626f5a5fd4 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -65,6 +65,7 @@
 )
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
+    remap_expert_weight_name,
     sharded_weight_loader,
 )
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
@@ -1083,7 +1084,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py
index 83694caa5248..8abd82e34665 100644
--- a/vllm/model_executor/models/qwen3_next_mtp.py
+++ b/vllm/model_executor/models/qwen3_next_mtp.py
@@ -18,7 +18,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.models.qwen3_next import (
     Qwen3NextDecoderLayer,
     Qwen3NextRMSNorm,
@@ -184,7 +187,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
index 3186804488e5..705fcb878d0f 100644
--- a/vllm/model_executor/models/qwen3_vl_moe.py
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -42,6 +42,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -238,7 +239,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # Anyway, this is an expert weight and should not be
                 # attempted to load as other weights later
                 is_expert_weight = True
-                name_mapped = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name_mapped = remap_expert_weight_name(
+                    name, weight_name, param_name
+                )
                 if is_pp_missing_parameter(name_mapped, self):
                     continue
                 if is_fused_expert:
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 7077f1a22e8d..9d3782ed6fc6 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -34,7 +34,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
 
@@ -498,7 +501,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer suffix correctly)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue