Commit 5c39293

[BugFix] LoRA: Support loading base_layer of experts
Signed-off-by: Hollow Man <[email protected]>
1 parent c016c95 commit 5c39293

40 files changed: 223 additions, 58 deletions

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 6 additions & 1 deletion
@@ -42,6 +42,9 @@
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     is_flashinfer_supporting_global_sf,
 )
+from vllm.model_executor.model_loader.weight_utils import (
+    remap_expert_weight_name,
+)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
 from vllm.utils.math_utils import cdiv, round_up
@@ -1376,7 +1379,9 @@ def load_weights(
         for param_name, weight_name, expert_id, shard_id in expert_mapping:
             if weight_name not in qual_name:
                 continue
-            weight_name = qual_name.replace(weight_name, param_name)
+            weight_name = remap_expert_weight_name(
+                qual_name, weight_name, param_name
+            )
             param_name = weight_name.removeprefix(f"{self.layer_name}.")
             param = getattr(self, param_name)
             success = self.weight_loader(
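
For context, the old qual_name.replace(weight_name, param_name) call mis-maps LoRA-wrapped checkpoint names: the base_layer segment stays glued to the weight suffix instead of sitting between the experts module and the fused parameter name. A minimal sketch of the failure mode, using the checkpoint name and mapping strings from the docstring example in weight_utils.py below (actual names seen by this fused-MoE path may differ):

# Old mapping: plain str.replace leaves `base_layer` fused into the suffix,
# yielding a parameter name that does not exist on the LoRA-wrapped module.
qual_name = "model.layers.0.mlp.experts.0.up_proj.base_layer.weight"
old_name = qual_name.replace("experts.0.up_proj.", "experts.w13_")
assert old_name == "model.layers.0.mlp.experts.w13_base_layer.weight"

# New mapping via remap_expert_weight_name (defined in the next file) instead
# produces "model.layers.0.mlp.experts.base_layer.w13_weight".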

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 35 additions & 0 deletions
@@ -1178,3 +1178,38 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+def remap_expert_weight_name(
+    name: str,
+    weight_name: str,
+    param_name: str,
+) -> str:
+    """Remap expert weight names, handling base_layer prefix for LoRA.
+
+    When loading expert weights, this function maps from checkpoint weight
+    names to model parameter names. It handles the special case where
+    LoRA wraps the original layer with a `base_layer` prefix.
+
+    For example:
+    - Input: name="model.layers.0.mlp.experts.0.up_proj.base_layer.weight"
+             weight_name="experts.0.up_proj."
+             param_name="experts.w13_"
+    - Output: "model.layers.0.mlp.experts.base_layer.w13_weight"
+
+    Args:
+        name: The full checkpoint weight name.
+        weight_name: The weight name pattern to match (e.g., "experts.0.up_proj.").
+        param_name: The parameter name to substitute (e.g., "experts.w13_").
+
+    Returns:
+        The remapped weight name with proper base_layer handling.
+    """
+    prefix, _, suffix = name.partition(weight_name)
+    middle = param_name
+    base = "base_layer"
+    if suffix.startswith(f"{base}."):
+        param_list = param_name.split(".", 1)
+        param_list.insert(1, base)
+        middle = ".".join(param_list)
+    return prefix + middle + suffix.removeprefix(f"{base}.")
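
A minimal usage sketch of the new helper, assuming a vLLM build that includes this commit; the names are taken from the docstring example above:

from vllm.model_executor.model_loader.weight_utils import remap_expert_weight_name

# LoRA case: the `base_layer` segment is hoisted out of the per-expert suffix
# and placed between the experts module and the fused parameter name.
assert remap_expert_weight_name(
    "model.layers.0.mlp.experts.0.up_proj.base_layer.weight",
    "experts.0.up_proj.",
    "experts.w13_",
) == "model.layers.0.mlp.experts.base_layer.w13_weight"

# Non-LoRA case: the result matches the old name.replace(weight_name, param_name).
assert remap_expert_weight_name(
    "model.layers.0.mlp.experts.0.up_proj.weight",
    "experts.0.up_proj.",
    "experts.w13_",
) == "model.layers.0.mlp.experts.w13_weight"

Because the non-LoRA path reduces to the old replace-based rewrite, checkpoints without a base_layer prefix should load exactly as before.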

vllm/model_executor/models/afmoe.py

Lines changed: 5 additions & 1 deletion
@@ -36,6 +36,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
@@ -533,7 +534,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                 # Do not modify `name` since the loop may continue here
                 # Instead, create a new variable
-                name_mapped = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer prefix)
+                name_mapped = remap_expert_weight_name(
+                    name, weight_name, param_name
+                )
 
                 if is_pp_missing_parameter(name_mapped, self):
                     continue

vllm/model_executor/models/arctic.py

Lines changed: 6 additions & 2 deletions
@@ -38,7 +38,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -609,7 +612,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             for param_name, weight_name, shard_id in expert_params_mapping:
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer prefix)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]

vllm/model_executor/models/bailing_moe.py

Lines changed: 6 additions & 2 deletions
@@ -55,7 +55,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    remap_expert_weight_name,
+)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -524,7 +527,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer prefix)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 if is_pp_missing_parameter(name, self):
                     continue

vllm/model_executor/models/dbrx.py

Lines changed: 3 additions & 1 deletion
@@ -31,6 +31,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -411,7 +412,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer prefix)
+                name = remap_expert_weight_name(name, weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2DecoderLayer,
@@ -155,7 +156,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer prefix)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 param = params_dict[name]
                 weight_loader = param.weight_loader

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 5 additions & 3 deletions
@@ -22,6 +22,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -357,9 +358,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     # attempted to load as other weights later
                     is_expert_weight = True
 
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = chunk_name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer prefix)
+                    name_mapped = remap_expert_weight_name(
+                        chunk_name, weight_name, param_name
+                    )
 
                     param = params_dict[name_mapped]
                     # We should ask the weight loader to return success or

vllm/model_executor/models/deepseek_v2.py

Lines changed: 5 additions & 3 deletions
@@ -72,6 +72,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.platforms import current_platform
@@ -1635,9 +1636,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     # attempted to load as other weights later
                     is_expert_weight = True
 
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = chunk_name.replace(weight_name, param_name)
+                    # Remap expert weight name (handles base_layer prefix)
+                    name_mapped = remap_expert_weight_name(
+                        chunk_name, weight_name, param_name
+                    )
 
                     if is_pp_missing_parameter(name_mapped, self):
                         continue

vllm/model_executor/models/dots1.py

Lines changed: 3 additions & 1 deletion
@@ -59,6 +59,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    remap_expert_weight_name,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -464,7 +465,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 param_name, weight_name, expert_id, shard_id = mapping
                 if weight_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                # Remap expert weight name (handles base_layer prefix)
+                name = remap_expert_weight_name(name, weight_name, param_name)
 
                 if is_pp_missing_parameter(name, self):
                     continue
