Skip to content

Commit 7cb6a1f

Browse files
committed
clean
1 parent 24590cb commit 7cb6a1f

4 files changed

Lines changed: 8 additions & 12 deletions

File tree

swift/megatron/trainers/rollout_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -436,9 +436,9 @@ def _export_and_load_weights(self):
436436
llm_model = self.engine.inner_model
437437
patch_vllm_moe_model_weight_loader(llm_model)
438438
llm_model.load_weights(weight_iterator)
439-
_model_config = getattr(getattr(self.engine, 'engine', None), 'model_config', None)
439+
_model_config = self.engine.engine.model_config
440440
finish_vllm_weight_reload(
441-
llm_model, model_config=_model_config, target_device=getattr(llm_model, 'device', None))
441+
llm_model, model_config=_model_config, target_device=next(llm_model.parameters()).device)
442442
elif self.vllm_mode == 'server':
443443
self._load_weights_to_server_in_buckets(weight_iterator)
444444
self.vllm_client.process_weights_after_loading()

swift/rlhf_trainers/gkd_trainer.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -619,10 +619,7 @@ def _fetch_and_assemble_teacher_logprobs(self, chunks):
619619
all_raw = gather_object(local_raw)
620620

621621
if self.accelerator.is_main_process:
622-
non_thinking_prefix_ids = get_non_thinking_prefix_ids(self.template)
623-
requests = [
624-
build_teacher_infer_request(d, non_thinking_prefix_ids=non_thinking_prefix_ids) for d in all_raw
625-
]
622+
requests = [build_teacher_infer_request(d) for d in all_raw]
626623
request_config = RequestConfig(prompt_logprobs=self.gkd_logits_topk, max_tokens=1, temperature=0.0)
627624
responses = self.teacher_client.infer(requests, request_config=request_config, use_tqdm=False)
628625
parsed_global = [parse_prompt_logprobs(r, topk=self.gkd_logits_topk) for r in responses]

swift/rlhf_trainers/rollout_mixin.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -813,11 +813,10 @@ def _move_full_model_to_vllm(self):
813813

814814
# Re-run process_weights_after_loading once after ALL groups loaded
815815
if self.vllm_mode == 'colocate':
816-
_model_config = getattr(getattr(self.engine, 'engine', None), 'model_config', None)
816+
_model_config = self.engine.engine.model_config
817+
llm_model = self.engine.inner_model
817818
finish_vllm_weight_reload(
818-
self.engine.inner_model,
819-
model_config=_model_config,
820-
target_device=getattr(self.engine.inner_model, 'device', None))
819+
llm_model, model_config=_model_config, target_device=next(llm_model.parameters()).device)
821820
elif self.vllm_mode == 'server' and self.accelerator.is_main_process:
822821
self.vllm_client.process_weights_after_loading()
823822

swift/rlhf_trainers/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1141,8 +1141,8 @@ def finish_vllm_weight_reload(vllm_model, model_config=None, target_device=None)
11411141
# Prefer vLLM's built-in
11421142
if model_config is not None and target_device is not None:
11431143
try:
1144-
from vllm.model_executor.model_loader.utils import process_weights_after_loading as _vllm_process
1145-
_vllm_process(vllm_model, model_config, target_device)
1144+
from vllm.model_executor.model_loader.utils import process_weights_after_loading
1145+
process_weights_after_loading(vllm_model, model_config, target_device)
11461146
return
11471147
except Exception as e:
11481148
logger.warning(

0 commit comments

Comments
 (0)