@@ -367,36 +367,44 @@ <h1>Source code for lightrft.strategy.utils.broadcast_utils</h1><div class="high
367367< span class ="sd "> :param name: Original weight name from training model</ span >
368368< span class ="sd "> :return: Mapped weight name for SGLang</ span >
369369< span class ="sd "> """</ span >
370- < span class ="c1 "> # Step 0: Handle PEFT/LoRA and other potential wrapping prefixes</ span >
370+ < span class ="c1 "> # Step 0: Handle PEFT/LoRA wrapping prefixes</ span >
371371 < span class ="c1 "> # PEFT models have weights like base_model.model.<original_name></ span >
372- < span class ="c1 "> # We recursively strip "base_model.model." or "model." prefixes until we find</ span >
373- < span class ="c1 "> # core components like "visual" or "language_model"</ span >
374- < span class ="k "> while</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "base_model.model."</ span > < span class ="p "> )</ span > < span class ="ow "> or</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "model."</ span > < span class ="p "> ):</ span >
375- < span class ="k "> if</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "base_model.model."</ span > < span class ="p "> ):</ span >
376- < span class ="n "> name</ span > < span class ="o "> =</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="s2 "> "base_model.model."</ span > < span class ="p "> ):]</ span >
377- < span class ="k "> elif</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "model."</ span > < span class ="p "> ):</ span >
378- < span class ="c1 "> # We strip "model." and let the following steps handle it.</ span >
379- < span class ="c1 "> # If "language_model" follows, it will be added back as "model."</ span >
380- < span class ="c1 "> # for SGLang's expectation.</ span >
381- < span class ="n "> name</ span > < span class ="o "> =</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="s2 "> "model."</ span > < span class ="p "> ):]</ span >
372+ < span class ="c1 "> # Strip "base_model.model." prefix (possibly nested) to get the original name.</ span >
373+ < span class ="k "> while</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "base_model.model."</ span > < span class ="p "> ):</ span >
374+ < span class ="n "> name</ span > < span class ="o "> =</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="s2 "> "base_model.model."</ span > < span class ="p "> ):]</ span >
382375
383376 < span class ="c1 "> # PEFT models also rename original weights to include ".base_layer."</ span >
384377 < span class ="c1 "> # we need to strip this to match standard weight names</ span >
385378 < span class ="n "> name</ span > < span class ="o "> =</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> replace</ span > < span class ="p "> (</ span > < span class ="s2 "> ".base_layer."</ span > < span class ="p "> ,</ span > < span class ="s2 "> "."</ span > < span class ="p "> )</ span >
386379
387- < span class ="c1 "> # Step 2: Handle language_model prefix mapping</ span >
388- < span class ="k "> if</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "language_model."</ span > < span class ="p "> ):</ span >
389- < span class ="c1 "> # Remove "language_model." prefix</ span >
390- < span class ="n "> name</ span > < span class ="o "> =</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="mi "> 15</ span > < span class ="p "> :]</ span > < span class ="c1 "> # Remove "language_model."</ span >
391-
392- < span class ="c1 "> # For lm_head, keep as is (no "model." prefix)</ span >
393- < span class ="k "> if</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "lm_head"</ span > < span class ="p "> ):</ span >
394- < span class ="k "> return</ span > < span class ="n "> name</ span >
395-
396- < span class ="c1 "> # For other components (embed_tokens, layers, norm), add "model." prefix</ span >
397- < span class ="k "> return</ span > < span class ="sa "> f</ span > < span class ="s2 "> "model.</ span > < span class ="si "> {</ span > < span class ="n "> name</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span >
398-
399- < span class ="c1 "> # Step 3: Return as is for other cases (e.g., visual.xxx)</ span >
380+ < span class ="c1 "> # Step 1: Handle VLM models wrapped by ActorVL</ span >
381+ < span class ="c1 "> # ActorVL wraps the HF model as self.model, so parameter names get an extra "model." prefix:</ span >
382+ < span class ="c1 "> # Training (ActorVL): model.visual.xxx, model.model.layers.xxx, model.lm_head.xxx</ span >
383+ < span class ="c1 "> # SGLang expects: visual.xxx, model.layers.xxx, lm_head.xxx</ span >
384+ < span class ="c1 "> # Also handle the "model.language_model." pattern (some VLM architectures):</ span >
385+ < span class ="c1 "> # Training: model.language_model.layers.xxx</ span >
386+ < span class ="c1 "> # SGLang expects: model.layers.xxx</ span >
387+ < span class ="k "> if</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "model.language_model."</ span > < span class ="p "> ):</ span >
388+ < span class ="n "> inner</ span > < span class ="o "> =</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="s2 "> "model.language_model."</ span > < span class ="p "> ):]</ span >
389+ < span class ="k "> if</ span > < span class ="n "> inner</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "lm_head"</ span > < span class ="p "> ):</ span >
390+ < span class ="k "> return</ span > < span class ="n "> inner</ span >
391+ < span class ="k "> return</ span > < span class ="sa "> f</ span > < span class ="s2 "> "model.</ span > < span class ="si "> {</ span > < span class ="n "> inner</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span >
392+
393+ < span class ="k "> if</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "model.visual."</ span > < span class ="p "> ):</ span >
394+ < span class ="k "> return</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="s2 "> "model."</ span > < span class ="p "> ):]</ span >
395+
396+ < span class ="k "> if</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "model.lm_head"</ span > < span class ="p "> ):</ span >
397+ < span class ="k "> return</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="s2 "> "model."</ span > < span class ="p "> ):]</ span >
398+
399+ < span class ="c1 "> # Handle VLM's double "model.model." prefix (ActorVL.model -> HF model.layers)</ span >
400+ < span class ="c1 "> # model.model.layers.xxx -> model.layers.xxx</ span >
401+ < span class ="c1 "> # model.model.embed_tokens.xxx -> model.embed_tokens.xxx</ span >
402+ < span class ="k "> if</ span > < span class ="n "> name</ span > < span class ="o "> .</ span > < span class ="n "> startswith</ span > < span class ="p "> (</ span > < span class ="s2 "> "model.model."</ span > < span class ="p "> ):</ span >
403+ < span class ="k "> return</ span > < span class ="n "> name</ span > < span class ="p "> [</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="s2 "> "model."</ span > < span class ="p "> ):]</ span >
404+
405+ < span class ="c1 "> # Step 2: For text-only models (e.g., Qwen2.5-0.5B-Instruct), parameter names</ span >
406+ < span class ="c1 "> # are already in SGLang's expected format: model.layers.xxx, model.embed_tokens.xxx,</ span >
407+ < span class ="c1 "> # model.norm.xxx, lm_head.xxx. Return as-is without stripping "model." prefix.</ span >
400408 < span class ="k "> return</ span > < span class ="n "> name</ span >
401409
402410 < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> _deepspeed_broadcast</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ):</ span >
0 commit comments