Skip to content

Commit aa02800

Browse files
authored
[bugfix] fix process_weights_after_loading & non_thinking_prefix (#9519)
1 parent 8553751 commit aa02800

12 files changed

Lines changed: 158 additions & 108 deletions

File tree

swift/megatron/trainers/gkd_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from swift.megatron.model import get_mcore_model
1818
from swift.rlhf_trainers.gkd_loss import DataSource, TeacherOutput, build_opsd_teacher_data, gkd_loss
1919
from swift.rlhf_trainers.utils import (assemble_teacher_topk_logprobs, build_teacher_infer_request,
20-
parse_prompt_logprobs, replace_assistant_response_with_ids)
20+
get_non_thinking_prefix_ids, parse_prompt_logprobs,
21+
replace_assistant_response_with_ids)
2122
from swift.rlhf_trainers.vllm_client import VLLMInferClient
2223
from swift.template import Template
2324
from swift.utils import get_cu_seqlens_from_position_ids, get_logger, is_last_rank, to_device
@@ -159,9 +160,11 @@ def _encode_batch(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
159160
template = self.template
160161
args = self.args
161162
max_length = template.max_length + self.max_completion_length
163+
non_thinking_prefix_ids = get_non_thinking_prefix_ids(template)
162164
for data in batch:
163165
if 'response_token_ids' in data:
164-
data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])
166+
data['messages'] = replace_assistant_response_with_ids(
167+
data['messages'], data['response_token_ids'], non_thinking_prefix_ids=non_thinking_prefix_ids)
165168

166169
with self._template_context(template, max_length=max_length):
167170
encoded_list = [template.encode(data, return_length=True) for data in batch]

swift/megatron/trainers/grpo_trainer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments
2727
from swift.megatron.utils import RouterReplayHelper, get_padding_to, set_router_replay_data
2828
from swift.rlhf_trainers.grpo_trainer import DataType
29-
from swift.rlhf_trainers.utils import (aggressive_empty_cache, detect_async_reward_indices, make_reward_weights, nanstd,
30-
pad_logps_back_to_batch, profiling_context, profiling_decorator,
31-
replace_assistant_response_with_ids, resolve_reward_funcs,
29+
from swift.rlhf_trainers.utils import (aggressive_empty_cache, detect_async_reward_indices, get_non_thinking_prefix_ids,
30+
make_reward_weights, nanstd, pad_logps_back_to_batch, profiling_context,
31+
profiling_decorator, replace_assistant_response_with_ids, resolve_reward_funcs,
3232
set_expandable_segments)
3333
from swift.rollout import MultiTurnScheduler, multi_turns
3434
from swift.template import Template, TemplateInputs
@@ -1125,15 +1125,19 @@ def _disable_maxlength_template_context(self, template: Template):
11251125

11261126
def _maybe_replace_response_token(self, batch):
11271127
# maybe replace the response token with the response token ids to avoid repetitive tokenize
1128+
non_thinking_prefix_ids = get_non_thinking_prefix_ids(self.template)
11281129

11291130
for data in batch:
11301131
if 'response_token_ids' in data and data['response_token_ids']:
11311132
loss_mask = None
11321133
if 'response_loss_mask' in data and data['response_loss_mask']:
11331134
loss_mask = data['response_loss_mask']
11341135
# token in token out
1135-
data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'],
1136-
loss_mask)
1136+
data['messages'] = replace_assistant_response_with_ids(
1137+
data['messages'],
1138+
data['response_token_ids'],
1139+
loss_mask,
1140+
non_thinking_prefix_ids=non_thinking_prefix_ids)
11371141
return batch
11381142

11391143
@property

swift/megatron/trainers/rollout_mixin.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,15 +435,12 @@ def _export_and_load_weights(self):
435435
if self.vllm_mode == 'colocate':
436436
llm_model = self.engine.inner_model
437437
patch_vllm_moe_model_weight_loader(llm_model)
438-
# Re-run process_weights_after_loading on FusedMoE layers so
439-
# the kernel-format layout is rebuilt after the in-place reload
440-
# (workaround for vLLM issue #42821).
441-
try:
442-
llm_model.load_weights(weight_iterator)
443-
finally:
444-
finish_vllm_weight_reload(llm_model)
438+
llm_model.load_weights(weight_iterator)
439+
_model_config = self.engine.engine.model_config
440+
finish_vllm_weight_reload(llm_model, model_config=_model_config, target_device=self.device)
445441
elif self.vllm_mode == 'server':
446442
self._load_weights_to_server_in_buckets(weight_iterator)
443+
self.vllm_client.process_weights_after_loading()
447444

448445
def _get_vllm_param_names_for_mapping(self):
449446
"""Get vLLM runtime parameter names for base_layer mapping.

swift/pipelines/infer/rollout.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,13 +284,19 @@ def update_flattened_params(self, metadatas: list[Dict]) -> None:
284284
named_params = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor).reconstruct_tensors()
285285

286286
patch_vllm_moe_model_weight_loader(self.model_runner.model)
287-
# Re-run process_weights_after_loading on FusedMoE layers so the
288-
# kernel-format layout is rebuilt after the in-place reload
289-
# (workaround for vLLM issue #42821).
290-
try:
291-
self.model_runner.model.load_weights(weights=list(named_params.items()))
292-
finally:
293-
finish_vllm_weight_reload(self.model_runner.model)
287+
self.model_runner.model.load_weights(weights=list(named_params.items()))
288+
289+
def process_weights_after_loading(self) -> None:
290+
"""Re-run process_weights_after_loading once after ALL weight
291+
buckets have been loaded, so the kernel-format layout is rebuilt
292+
on complete weights rather than partial ones.
293+
294+
Uses vLLM's built-in ``process_weights_after_loading`` when
295+
*model_config* and *target_device* are available (same as verl);
296+
falls back to FusedMoE-only path otherwise.
297+
"""
298+
model_config = self.model_runner.model_config
299+
finish_vllm_weight_reload(self.model_runner.model, model_config=model_config, target_device=self.device)
294300

295301
def close_communicator(self) -> None:
296302
"""
@@ -512,12 +518,13 @@ def _broadcast_obj(obj):
512518
if metadata.get('is_last'):
513519
break
514520

515-
# Re-run process_weights_after_loading on FusedMoE layers so the
516-
# kernel-format layout is rebuilt after the in-place reload
517-
# (workaround for vLLM issue #42821). Skipped for LoRA sync
518-
# because the adapter path doesn't call ``load_weights``.
521+
# Re-run process_weights_after_loading so the kernel-format
522+
# layout is rebuilt after the in-place reload (vLLM issue
523+
# #42821). Skipped for LoRA sync because the adapter path
524+
# doesn't call ``load_weights``.
519525
if not is_lora_sync:
520-
finish_vllm_weight_reload(self.model_runner.model)
526+
model_config = self.model_runner.model_config
527+
finish_vllm_weight_reload(self.model_runner.model, model_config=model_config, target_device=self.device)
521528

522529
if is_lora_sync and all_lora_weights:
523530
req_kw = dict(
@@ -698,6 +705,7 @@ def _register_rl_rollout_app(self):
698705
self.app.post('/update_adapter_flattened_param/')(self.update_adapter_flattened_param)
699706
self.app.post('/update_adapter_param/')(self.update_adapter_param)
700707
self.app.post('/update_flattened_params/')(self.update_flattened_params)
708+
self.app.post('/process_weights_after_loading/')(self.process_weights_after_loading)
701709
self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache)
702710
self.app.post('/reset_encoder_cache/')(self.reset_encoder_cache)
703711
self.app.post('/reset_mm_cache/')(self.reset_mm_cache)
@@ -926,6 +934,18 @@ async def update_flattened_params(self, request: UpdateFlattenedParamsRequest):
926934

927935
return {'message': 'Request received, updating flattened parameters'}
928936

937+
async def process_weights_after_loading(self):
938+
"""
939+
Triggers process_weights_after_loading on all workers.
940+
"""
941+
kwargs = {'method': 'process_weights_after_loading', 'args': ()}
942+
for connection in self.connections:
943+
connection.send({'type': 'call', 'method': 'collective_rpc', 'kwargs': kwargs})
944+
# Wait for all workers to complete before returning
945+
loop = asyncio.get_running_loop()
946+
await asyncio.gather(*(loop.run_in_executor(None, connection.recv) for connection in self.connections))
947+
return {'message': 'Weights processed after loading'}
948+
929949
async def reset_prefix_cache(self):
930950
"""
931951
Resets the prefix cache for the model.

swift/ray/megatron/gkd_trainer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from swift.infer_engine.protocol import RequestConfig, RolloutOutput
1313
from swift.rlhf_trainers.gkd_loss import DataSource, TeacherOutput, build_opsd_teacher_data
14-
from swift.rlhf_trainers.utils import (build_teacher_infer_request, parse_prompt_logprobs,
14+
from swift.rlhf_trainers.utils import (build_teacher_infer_request, get_non_thinking_prefix_ids, parse_prompt_logprobs,
1515
replace_assistant_response_with_ids)
1616
from swift.utils import get_logger
1717
from .base_trainer import BaseRayTrainer
@@ -228,13 +228,16 @@ def _encode_rollout_batch(self, rollout_batch):
228228
"""
229229
template = self.template
230230
samples = []
231+
non_thinking_prefix_ids = get_non_thinking_prefix_ids(template)
231232
with self._extended_max_length():
232233
for orig_item in rollout_batch:
233234
item = orig_item
234235
if item.get('response_token_ids'):
235236
item = dict(item)
236237
item['messages'] = replace_assistant_response_with_ids(
237-
copy.deepcopy(item['messages']), item['response_token_ids'])
238+
copy.deepcopy(item['messages']),
239+
item['response_token_ids'],
240+
non_thinking_prefix_ids=non_thinking_prefix_ids)
238241
encoded = template.encode(item, return_length=True)
239242
sample = {'encoded': encoded}
240243
# OPSD: if the dataset row carries a `teacher_prompt`, also encode the
@@ -259,7 +262,9 @@ def _encode_opsd_teacher(item, template):
259262
opsd_item = opsd_list[0]
260263
if opsd_item.get('response_token_ids'):
261264
opsd_item['messages'] = replace_assistant_response_with_ids(
262-
copy.deepcopy(opsd_item['messages']), opsd_item['response_token_ids'])
265+
copy.deepcopy(opsd_item['messages']),
266+
opsd_item['response_token_ids'],
267+
non_thinking_prefix_ids=get_non_thinking_prefix_ids(template))
263268
return template.encode(opsd_item, return_length=True)
264269

265270
def _fetch_teacher_from_replicas(self, rollout_with_outputs, samples):
@@ -283,7 +288,9 @@ def _fetch_teacher_from_replicas(self, rollout_with_outputs, samples):
283288
opsd_item = build_opsd_teacher_data([item])[0]
284289
if opsd_item.get('response_token_ids'):
285290
opsd_item['messages'] = replace_assistant_response_with_ids(
286-
copy.deepcopy(opsd_item['messages']), opsd_item['response_token_ids'])
291+
copy.deepcopy(opsd_item['messages']),
292+
opsd_item['response_token_ids'],
293+
non_thinking_prefix_ids=non_thinking_prefix_ids)
287294
requests.append(build_teacher_infer_request(opsd_item))
288295
teacher_encodeds.append(opsd_encoded)
289296
else:

swift/ray/megatron/grpo_trainer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
from swift.dataset import RowPreprocessor
1313
from swift.infer_engine.protocol import RolloutInferRequest, RolloutOutput
14-
from swift.rlhf_trainers.utils import compute_grpo_advantages, make_reward_weights, resolve_reward_funcs
14+
from swift.rlhf_trainers.utils import (compute_grpo_advantages, get_non_thinking_prefix_ids, make_reward_weights,
15+
replace_assistant_response_with_ids, resolve_reward_funcs)
1516
from swift.rollout import MultiTurnScheduler, invoke_async_hook, multi_turns, run_multi_turn
1617
from swift.utils import get_logger
1718
from .base_trainer import BaseRayTrainer
@@ -497,8 +498,7 @@ def encode_rollout_batch(
497498
rollout_batch: Sequence[Dict[str, Any]],
498499
) -> List[Dict[str, Any]]:
499500
"""Encode rollout samples and keep them as per-sample payloads."""
500-
from swift.rlhf_trainers.utils import replace_assistant_response_with_ids
501-
501+
non_thinking_prefix_ids = get_non_thinking_prefix_ids(self.template)
502502
rollout_for_encode: List[Dict[str, Any]] = []
503503
for data in rollout_batch:
504504
item = dict(data)
@@ -508,8 +508,11 @@ def encode_rollout_batch(
508508
loss_mask = None
509509
if 'response_loss_mask' in item and item['response_loss_mask']:
510510
loss_mask = item['response_loss_mask']
511-
item['messages'] = replace_assistant_response_with_ids(item['messages'], item['response_token_ids'],
512-
loss_mask)
511+
item['messages'] = replace_assistant_response_with_ids(
512+
item['messages'],
513+
item['response_token_ids'],
514+
loss_mask,
515+
non_thinking_prefix_ids=non_thinking_prefix_ids)
513516
rollout_for_encode.append(item)
514517

515518
encoded_list, error_list = self._batch_encode_parallel(rollout_for_encode, strict=True)

swift/ray/megatron/megatron_worker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,6 @@ def _build_routed_experts_batch(
644644
raise AssertionError(
645645
f'The seq_len of routed_experts({experts_seq_len}) does not match encoded length '
646646
f'({expected_len}); expected same length or one less.')
647-
648647
target_len = int(cur_seq_len.item()) if template.padding_free else max_seq_len
649648
routed = self._pad_or_trim_routed_experts(routed, target_len, padding_right=padding_right)
650649
routed_tensors.append(routed)

swift/rlhf_trainers/gkd_trainer.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from swift.infer_engine.protocol import RequestConfig
2020
from swift.rlhf_trainers.gkd_loss import DataSource, TeacherOutput, build_opsd_teacher_data, gkd_loss
2121
from swift.rlhf_trainers.utils import (assemble_teacher_topk_logprobs, build_teacher_infer_request,
22-
parse_prompt_logprobs, prepare_fsdp, replace_assistant_response_with_ids)
22+
get_non_thinking_prefix_ids, parse_prompt_logprobs, prepare_fsdp,
23+
replace_assistant_response_with_ids)
2324
from swift.rlhf_trainers.vllm_client import VLLMInferClient
2425
from swift.template import TemplateInputs
2526
from swift.trainers import SwiftMixin, disable_gradient_checkpointing
@@ -369,11 +370,13 @@ def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False)
369370
mode = 'transformers' if encode_prompt_only else 'train'
370371
original_mode = template.mode
371372
template.set_mode(mode)
373+
non_thinking_prefix_ids = get_non_thinking_prefix_ids(template)
372374
try:
373375
for data in inputs:
374376
if 'response_token_ids' in data and data['response_token_ids']:
375377
data = {**data}
376-
data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])
378+
data['messages'] = replace_assistant_response_with_ids(
379+
data['messages'], data['response_token_ids'], non_thinking_prefix_ids=non_thinking_prefix_ids)
377380

378381
if encode_prompt_only:
379382
# Remove response content for prompt-only encoding
@@ -641,36 +644,6 @@ def _fetch_and_assemble_teacher_logprobs(self, chunks):
641644
c['_teacher_topk_logprobs'] = topk_lp
642645
c['_teacher_topk_indices'] = topk_ix
643646

644-
def _inline_fetch_teacher_logprobs(self, encoded_inputs: Dict[str, torch.Tensor], raw_data) -> None:
645-
"""Fetch teacher logprobs with gather+broadcast (used in eval/prediction_step).
646-
647-
Same synchronization pattern as _fetch_and_assemble_teacher_logprobs:
648-
only main_process has teacher_client, so we gather raw → fetch on rank0 → broadcast.
649-
"""
650-
all_raw = gather_object(list(raw_data))
651-
652-
if self.accelerator.is_main_process:
653-
requests = [build_teacher_infer_request(d) for d in all_raw]
654-
request_config = RequestConfig(prompt_logprobs=self.gkd_logits_topk, max_tokens=1, temperature=0.0)
655-
responses = self.teacher_client.infer(requests, request_config=request_config, use_tqdm=False)
656-
parsed_global = [parse_prompt_logprobs(r, topk=self.gkd_logits_topk) for r in responses]
657-
else:
658-
parsed_global = None
659-
660-
container = [parsed_global]
661-
broadcast_object_list(container, from_process=0)
662-
parsed_global = container[0]
663-
664-
# Slice this rank's portion (gather_object returns rank-ordered list)
665-
n_local = len(raw_data)
666-
rank = self.accelerator.process_index
667-
parsed = parsed_global[rank * n_local:(rank + 1) * n_local]
668-
669-
target = encoded_inputs.get('_opsd_teacher_inputs') or encoded_inputs
670-
topk_lp, topk_ix = self._assemble_topk_for_chunk(parsed, target)
671-
encoded_inputs['_teacher_topk_logprobs'] = topk_lp
672-
encoded_inputs['_teacher_topk_indices'] = topk_ix
673-
674647
@profiling_decorator
675648
def training_step(self,
676649
model: nn.Module,

swift/rlhf_trainers/grpo_trainer.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,10 @@
6060
start_event_loop_in_daemon, to_device, unwrap_model_for_generation)
6161
from .arguments import GRPOConfig
6262
from .rollout_mixin import DataType, RolloutTrainerMixin, SyncRefModelCallback
63-
from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator,
64-
load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch, patch_save_last_checkpoint,
65-
profiling_context, profiling_decorator, replace_assistant_response_with_ids)
63+
from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, get_non_thinking_prefix_ids,
64+
identity_data_collator, load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch,
65+
patch_save_last_checkpoint, profiling_context, profiling_decorator,
66+
replace_assistant_response_with_ids)
6667

6768
try:
6869
from trl.trainer.utils import entropy_from_logits
@@ -804,6 +805,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
804805
template = self.template
805806
gas_chunks = self.split_by_mini_batches(inputs)
806807
ga_batch_encoded_inputs = []
808+
non_thinking_prefix_ids = get_non_thinking_prefix_ids(template)
807809
for batch in gas_chunks:
808810
# Encode and process each batch (size=bs)
809811
with self._template_context(template):
@@ -812,8 +814,11 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
812814
loss_mask = None
813815
if 'response_loss_mask' in data and data['response_loss_mask']:
814816
loss_mask = data['response_loss_mask']
815-
data['messages'] = replace_assistant_response_with_ids(data['messages'],
816-
data['response_token_ids'], loss_mask)
817+
data['messages'] = replace_assistant_response_with_ids(
818+
data['messages'],
819+
data['response_token_ids'],
820+
loss_mask,
821+
non_thinking_prefix_ids=non_thinking_prefix_ids)
817822
batch_encoded_inputs = [template.encode(data, return_length=True) for data in batch]
818823
for encoded_inputs in batch_encoded_inputs:
819824
extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}

0 commit comments

Comments
 (0)