Skip to content

Commit ffa5d74

Browse files
Authored by Harry Mellor
Enable loading of fused expert weights in the Transformers modelling backend (vllm-project#36997)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 74fe80e commit ffa5d74

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,22 +1342,41 @@ def load_weights(
13421342
weight_name = qual_name.replace(weight_name, param_name)
13431343
param_name = weight_name.removeprefix(f"{self.layer_name}.")
13441344
param = getattr(self, param_name)
1345-
success = self.weight_loader(
1346-
param=param,
1347-
loaded_weight=loaded_weight,
1348-
weight_name=weight_name,
1349-
shard_id=shard_id,
1350-
expert_id=expert_id,
1351-
return_success=True,
1352-
)
1353-
if success:
1354-
logger.debug(
1355-
"Loaded %s for expert %d into %s",
1356-
param_name,
1357-
expert_id,
1358-
self.layer_name,
1345+
# Fused expert weights can be identified by their 3D tensors
1346+
if loaded_weight.dim() == 3:
1347+
# Repurpose expert_id as shard_idx for deconcatenating w1 and w3
1348+
if shard_id in {"w1", "w3"}:
1349+
shard_idx = expert_id
1350+
experts_shard = loaded_weight.chunk(2, dim=1)[shard_idx]
1351+
else:
1352+
experts_shard = loaded_weight
1353+
start = 0
1354+
else:
1355+
# loaded_weight is a single expert weight, so we add a dummy expert
1356+
# dimension to unify the loading logic with the fused case
1357+
experts_shard = loaded_weight.unsqueeze(0)
1358+
start = expert_id
1359+
1360+
# Unified loading logic for fused and non-fused experts
1361+
loaded_experts = experts_shard.unbind()
1362+
for expert_id, loaded_expert in enumerate(loaded_experts, start=start):
1363+
success = self.weight_loader(
1364+
param=param,
1365+
loaded_weight=loaded_expert,
1366+
weight_name=weight_name,
1367+
shard_id=shard_id,
1368+
expert_id=expert_id,
1369+
return_success=True,
13591370
)
1360-
yield param_name
1371+
if success:
1372+
logger.debug(
1373+
"Loaded expert %d of shard %s into %s for layer %s",
1374+
expert_id,
1375+
shard_id,
1376+
param_name,
1377+
self.layer_name,
1378+
)
1379+
yield param_name
13611380

13621381
def get_expert_weights(self) -> Iterable[torch.Tensor]:
13631382
def _maybe_make_contiguous(

vllm/model_executor/models/transformers/moe.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
156156
Params for weights, fp8 weight scales, fp8 activation scales
157157
(param_name, weight_name, expert_id, shard_id)
158158
"""
159+
# Models saved with fused experts. These are checkpoints released:
160+
# - After Transformers v5
161+
# - Before Transformers v5, but re-saved with save_original_format=False
162+
# In the fused experts case, we repurpose the expert_id as shard_idx for
163+
# deconcatenating w1 and w3 in FusedMoE.load_weights.
164+
expert_mapping = [
165+
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
166+
("experts.w13_weight", "experts.gate_up_proj", 1, "w3"),
167+
("experts.w2_weight", "experts.down_proj", 0, "w2"),
168+
]
169+
# Models saved with ModuleList experts
159170
ckpt_names = [
160171
# (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
161172
("gate_proj", "down_proj", "up_proj"), # Most common MoE style
@@ -164,7 +175,6 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
164175
]
165176
num_experts = self.model_config.get_num_experts()
166177
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
167-
expert_mapping = []
168178
for gate_proj, down_proj, up_proj in ckpt_names:
169179
expert_mapping.extend(
170180
FusedMoE.make_expert_params_mapping(

0 commit comments

Comments (0)