Commit 3021b23

fix(pu): fix _map_weight_name_for_sglang bug in text-only model (#55)

1 parent 572565d

File tree

1 file changed: +32 -24 lines


lightrft/strategy/utils/broadcast_utils.py

Lines changed: 32 additions & 24 deletions
@@ -65,36 +65,44 @@ def _map_weight_name_for_sglang(self, name: str) -> str:
         :param name: Original weight name from training model
         :return: Mapped weight name for SGLang
         """
-        # Step 0: Handle PEFT/LoRA and other potential wrapping prefixes
+        # Step 0: Handle PEFT/LoRA wrapping prefixes
         # PEFT models have weights like base_model.model.<original_name>
-        # We recursively strip "base_model.model." or "model." prefixes until we find
-        # core components like "visual" or "language_model"
-        while name.startswith("base_model.model.") or name.startswith("model."):
-            if name.startswith("base_model.model."):
-                name = name[len("base_model.model."):]
-            elif name.startswith("model."):
-                # We strip "model." and let the following steps handle it.
-                # If "language_model" follows, it will be added back as "model."
-                # for SGLang's expectation.
-                name = name[len("model."):]
+        # Strip "base_model.model." prefix (possibly nested) to get the original name.
+        while name.startswith("base_model.model."):
+            name = name[len("base_model.model."):]

         # PEFT models also rename original weights to include ".base_layer."
         # we need to strip this to match standard weight names
         name = name.replace(".base_layer.", ".")

-        # Step 2: Handle language_model prefix mapping
-        if name.startswith("language_model."):
-            # Remove "language_model." prefix
-            name = name[15:]  # Remove "language_model."
-
-            # For lm_head, keep as is (no "model." prefix)
-            if name.startswith("lm_head"):
-                return name
-
-            # For other components (embed_tokens, layers, norm), add "model." prefix
-            return f"model.{name}"
-
-        # Step 3: Return as is for other cases (e.g., visual.xxx)
+        # Step 1: Handle VLM models wrapped by ActorVL
+        # ActorVL wraps the HF model as self.model, so parameter names get an extra "model." prefix:
+        # Training (ActorVL): model.visual.xxx, model.model.layers.xxx, model.lm_head.xxx
+        # SGLang expects: visual.xxx, model.layers.xxx, lm_head.xxx
+        # Also handle the "model.language_model." pattern (some VLM architectures):
+        # Training: model.language_model.model.layers.xxx
+        # SGLang expects: model.layers.xxx
+        if name.startswith("model.language_model."):
+            inner = name[len("model.language_model."):]
+            if inner.startswith("lm_head"):
+                return inner
+            return f"model.{inner}"
+
+        if name.startswith("model.visual."):
+            return name[len("model."):]
+
+        if name.startswith("model.lm_head"):
+            return name[len("model."):]
+
+        # Handle VLM's double "model.model." prefix (ActorVL.model -> HF model.layers)
+        # model.model.layers.xxx -> model.layers.xxx
+        # model.model.embed_tokens.xxx -> model.embed_tokens.xxx
+        if name.startswith("model.model."):
+            return name[len("model."):]
+
+        # Step 2: For text-only models (e.g., Qwen2.5-0.5B-Instruct), parameter names
+        # are already in SGLang's expected format: model.layers.xxx, model.embed_tokens.xxx,
+        # model.norm.xxx, lm_head.xxx. Return as-is without stripping "model." prefix.
         return name

     def _deepspeed_broadcast(self):
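
For reference, here is a condensed, standalone sketch of the new mapping logic, followed by a few example names showing the fixed behavior. The sample weight names are illustrative assumptions (typical Qwen2 / Qwen2.5-VL-style parameter names), not taken from the repository's tests, and the three VLM prefix checks are folded into one tuple check for brevity.

# Condensed sketch of the new _map_weight_name_for_sglang logic, for illustration only.
# The sample weight names below are assumptions, not repository test data.
def map_weight_name_for_sglang(name: str) -> str:
    # Step 0: undo PEFT/LoRA wrapping and ".base_layer." renames.
    while name.startswith("base_model.model."):
        name = name[len("base_model.model."):]
    name = name.replace(".base_layer.", ".")

    # Step 1: unwrap the extra "model." prefix added by the ActorVL wrapper.
    if name.startswith("model.language_model."):
        inner = name[len("model.language_model."):]
        return inner if inner.startswith("lm_head") else f"model.{inner}"
    if name.startswith(("model.visual.", "model.lm_head", "model.model.")):
        return name[len("model."):]

    # Step 2: text-only models are already in SGLang's expected format.
    return name

# Text-only model (the case this commit fixes): names pass through unchanged,
# whereas the old code's `while name.startswith("model.")` loop stripped the prefix.
assert map_weight_name_for_sglang("model.layers.0.self_attn.q_proj.weight") == \
    "model.layers.0.self_attn.q_proj.weight"
assert map_weight_name_for_sglang("lm_head.weight") == "lm_head.weight"

# VLM wrapped by ActorVL: the wrapper's extra "model." prefix is stripped.
assert map_weight_name_for_sglang("model.visual.patch_embed.proj.weight") == \
    "visual.patch_embed.proj.weight"
assert map_weight_name_for_sglang("model.model.layers.0.mlp.gate_proj.weight") == \
    "model.layers.0.mlp.gate_proj.weight"
assert map_weight_name_for_sglang("model.lm_head.weight") == "lm_head.weight"

# PEFT/LoRA wrapping is undone first, then the name falls through Step 2.
assert map_weight_name_for_sglang(
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight") == \
    "model.layers.0.self_attn.q_proj.weight"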

0 commit comments