9 changes: 8 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -1362,6 +1362,10 @@ def weight_loader(
def load_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[str]:
from vllm.model_executor.model_loader.weight_utils import (
remap_expert_weight_name,
)

if (expert_mapping := self.expert_mapping) is None:
raise ValueError(
"`self.expert_mapping` must be provided to "
@@ -1372,7 +1376,10 @@ def load_weights
for param_name, weight_name, expert_id, shard_id in expert_mapping:
if weight_name not in qual_name:
continue
weight_name = qual_name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
weight_name = remap_expert_weight_name(
qual_name, weight_name, param_name
)
param_name = weight_name.removeprefix(f"{self.layer_name}.")
param = getattr(self, param_name)
success = self.weight_loader(
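For context, the `expert_mapping` entries in the loop above unpack as `(param_name, weight_name, expert_id, shard_id)`, and the old plain `str.replace` call mishandled LoRA checkpoints, where each expert projection is wrapped in a `base_layer` submodule. A minimal before/after sketch, using the illustrative names from the new helper's docstring rather than values from a real mapping:

# Illustrative only: LoRA adds a `base_layer` component to each expert weight key.
qual_name = "model.layers.0.mlp.experts.0.up_proj.base_layer.weight"
weight_name, param_name = "experts.0.up_proj.", "experts.w13_"

# Old behaviour: plain substitution fuses `base_layer` into the parameter name.
qual_name.replace(weight_name, param_name)
# -> "model.layers.0.mlp.experts.w13_base_layer.weight" (no matching parameter)

# New behaviour: remap_expert_weight_name() keeps `base_layer` as its own component.
# -> "model.layers.0.mlp.experts.base_layer.w13_weight"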
35 changes: 35 additions & 0 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -1178,3 +1178,38 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:

# If there were no matches, return the untouched param name
return name


def remap_expert_weight_name(
name: str,
weight_name: str,
param_name: str,
) -> str:
"""Remap expert weight names, handling base_layer prefix for LoRA.

When loading expert weights, this function maps from checkpoint weight
names to model parameter names. It handles the special case where
LoRA wraps the original layer with a `base_layer` prefix.

For example:
- Input: name="model.layers.0.mlp.experts.0.up_proj.base_layer.weight"
weight_name="experts.0.up_proj."
param_name="experts.w13_"
- Output: "model.layers.0.mlp.experts.base_layer.w13_weight"

Args:
name: The full checkpoint weight name.
weight_name: The weight name pattern to match (e.g., "experts.0.up_proj.").
param_name: The parameter name to substitute (e.g., "experts.w13_").

Returns:
The remapped weight name with proper base_layer handling.
"""
prefix, _, suffix = name.partition(weight_name)
middle = param_name
base = "base_layer"
if suffix.startswith(f"{base}."):
param_list = param_name.split(".", 1)
param_list.insert(1, base)
middle = ".".join(param_list)
return prefix + middle + suffix.removeprefix(f"{base}.")
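The docstring example above translates directly into a quick check of both branches of the helper. A minimal pytest-style sketch (not part of this PR; the test function name is illustrative):

# Minimal sketch exercising both branches of remap_expert_weight_name.
from vllm.model_executor.model_loader.weight_utils import remap_expert_weight_name


def test_remap_expert_weight_name():
    # Plain checkpoint name: behaves like a straight str.replace().
    assert remap_expert_weight_name(
        "model.layers.0.mlp.experts.0.up_proj.weight",
        "experts.0.up_proj.",
        "experts.w13_",
    ) == "model.layers.0.mlp.experts.w13_weight"

    # LoRA checkpoint name: `base_layer` stays a separate name component
    # instead of being fused into the parameter name.
    assert remap_expert_weight_name(
        "model.layers.0.mlp.experts.0.up_proj.base_layer.weight",
        "experts.0.up_proj.",
        "experts.w13_",
    ) == "model.layers.0.mlp.experts.base_layer.w13_weight"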
6 changes: 5 additions & 1 deletion vllm/model_executor/models/afmoe.py
@@ -36,6 +36,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
@@ -533,7 +534,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name_mapped = remap_expert_weight_name(
name, weight_name, param_name
)

if is_pp_missing_parameter(name_mapped, self):
continue
8 changes: 6 additions & 2 deletions vllm/model_executor/models/arctic.py
@@ -38,7 +38,10 @@
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
remap_expert_weight_name,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -609,7 +612,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
for param_name, weight_name, shard_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
8 changes: 6 additions & 2 deletions vllm/model_executor/models/bailing_moe.py
@@ -55,7 +55,10 @@
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
@@ -524,7 +527,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)

if is_pp_missing_parameter(name, self):
continue
4 changes: 3 additions & 1 deletion vllm/model_executor/models/dbrx.py
@@ -31,6 +31,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

@@ -411,7 +412,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
4 changes: 3 additions & 1 deletion vllm/model_executor/models/deepseek_eagle.py
@@ -18,6 +18,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer,
@@ -155,7 +156,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)

param = params_dict[name]
weight_loader = param.weight_loader
6 changes: 5 additions & 1 deletion vllm/model_executor/models/deepseek_mtp.py
@@ -22,6 +22,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -359,7 +360,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name_mapped = remap_expert_weight_name(
chunk_name, weight_name, param_name
)

param = params_dict[name_mapped]
# We should ask the weight loader to return success or
6 changes: 5 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -72,6 +72,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
@@ -1643,7 +1644,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name_mapped = remap_expert_weight_name(
chunk_name, weight_name, param_name
)

if is_pp_missing_parameter(name_mapped, self):
continue
4 changes: 3 additions & 1 deletion vllm/model_executor/models/dots1.py
@@ -59,6 +59,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

@@ -464,7 +465,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)

if is_pp_missing_parameter(name, self):
continue
7 changes: 6 additions & 1 deletion vllm/model_executor/models/ernie45_moe.py
@@ -60,6 +60,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import set_default_rope_theta
@@ -563,7 +564,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name_mapped = remap_expert_weight_name(
name, weight_name, param_name
)

# Skip layers on other devices.
if is_pp_missing_parameter(name_mapped, self):
continue
4 changes: 3 additions & 1 deletion vllm/model_executor/models/ernie45_vl_moe.py
@@ -56,6 +56,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import set_default_rope_theta
@@ -736,7 +737,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
moe_offset = int(name.split(".")[-3])
is_text_expert = moe_offset <= self.config.moe_num_experts[0] - 1

name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)
if is_text_expert:
name = name.replace(".experts.", ".text_experts.")
else:
6 changes: 5 additions & 1 deletion vllm/model_executor/models/glm4_moe.py
@@ -59,6 +59,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

@@ -554,7 +555,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name_mapped = remap_expert_weight_name(
name, weight_name, param_name
)

if is_pp_missing_parameter(name_mapped, self):
continue
8 changes: 6 additions & 2 deletions vllm/model_executor/models/glm4_moe_mtp.py
@@ -38,7 +38,10 @@
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

from .glm4_moe import (
@@ -293,7 +296,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)

param = params_dict[name]
weight_loader = param.weight_loader
4 changes: 3 additions & 1 deletion vllm/model_executor/models/granitemoe.py
@@ -56,6 +56,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
@@ -401,7 +402,8 @@ def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
8 changes: 6 additions & 2 deletions vllm/model_executor/models/granitemoehybrid.py
@@ -28,7 +28,10 @@
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

from .granitemoe import GraniteMoeMoE
@@ -465,7 +468,8 @@ def _load_quant_expert(name, loaded_weight):
if weight_name not in name:
continue

name_mapped = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name_mapped = remap_expert_weight_name(name, weight_name, param_name)

# Skip layers on other devices.
if is_pp_missing_parameter(name_mapped, self):
4 changes: 3 additions & 1 deletion vllm/model_executor/models/grok1.py
@@ -52,6 +52,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

@@ -426,7 +427,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name = remap_expert_weight_name(name, weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
7 changes: 6 additions & 1 deletion vllm/model_executor/models/hunyuan_v1.py
@@ -63,6 +63,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
remap_expert_weight_name,
)
from vllm.sequence import IntermediateTensors

@@ -848,7 +849,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
# Remap expert weight name (handles base_layer suffix correctly)
name_mapped = remap_expert_weight_name(
name, weight_name, param_name
)

if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]