[on-policy distillation] support and related data handling #673

yitianlian merged 5 commits into THUDM:main
Conversation
```shell
#### clear after training
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
```
I think this part could be deleted
```diff
 for key, val in rollout_data.items():
-    if key == "tokens" or key == "loss_masks" or key == "sample_indices":
+    if key == "tokens" or key == "loss_masks" or key == "sample_indices" or key == "teacher_token_ids" or key == "rewards":
```
Why do we exclude `key == "rewards"`?
I'll add the reward back. However, I'm concerned that the reward here is meaningless; it's just the average of the teacher log probabilities.
```python
teacher_log_probs = [t_log_prob.to(device=device) for t_log_prob in teacher_log_probs]
teacher_log_probs = [t_log_prob[-response_length:] for t_log_prob, response_length in zip(teacher_log_probs, response_lengths)]
advantages = [teacher_log_prob - student_log_prob for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)]
```
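For context, here is a minimal self-contained sketch of the per-token advantage computed above (the tensor values and lengths are made up for illustration; this is not the slime implementation):

```python
import torch

# Hypothetical per-sample log probs over response tokens (variable lengths).
student_log_probs = [torch.tensor([-0.5, -1.2, -0.3]), torch.tensor([-0.9, -0.1])]
# Teacher log probs scored over the full sequence (prompt + response);
# only the response tail is relevant for distillation.
teacher_log_probs = [torch.tensor([-2.0, -0.4, -1.0, -0.2]), torch.tensor([-1.5, -0.8, -0.05])]
response_lengths = [3, 2]

# Keep only the response portion of each teacher sequence.
teacher_log_probs = [t[-n:] for t, n in zip(teacher_log_probs, response_lengths)]
# Per-token advantage: teacher log prob minus student log prob for the
# tokens the student actually sampled.
advantages = [t - s for t, s in zip(teacher_log_probs, student_log_probs)]
```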
You should do the slice for the teacher log probs when creating samples.
Remove the teacher token IDs.
Will do. The teacher token IDs are for debugging.
```python
def post_process_rewards(args, samples: list[Sample], **kwargs):
    rewards = [sample.get_reward_value(args) for sample in samples]
    teacher_log_probs = [
        torch.tensor(
            [item[0] for item in reward["meta_info"]["input_token_logprobs"][1:]],
            dtype=torch.float32,
        )
        for reward in rewards
    ]
    teacher_token_ids = [
        torch.tensor(
            [item[1] for item in reward["meta_info"]["input_token_logprobs"][1:]],
            dtype=torch.int32,
        )
        for reward in rewards
    ]
```
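A hedged sketch of this parsing with the response-length slice the reviewer suggests below; the reward payload here is mock data shaped like an sglang response with `return_logprob=True`, not a live server response:

```python
import torch

# Mock payload: each input_token_logprobs entry is a (logprob, token_id)
# pair, and the first entry has no logprob (no preceding context), hence
# the [1:] skip in the snippet above.
reward = {
    "meta_info": {
        "input_token_logprobs": [
            (None, 101),
            (-0.7, 42),
            (-1.3, 7),
            (-0.2, 99),
        ]
    }
}
response_length = 2  # only the response tokens feed the distillation loss

items = reward["meta_info"]["input_token_logprobs"][1:]
teacher_log_probs = torch.tensor([lp for lp, _ in items], dtype=torch.float32)
# Slice to the response tail here, at sample-creation time, rather than
# later in the loss computation.
teacher_log_probs = teacher_log_probs[-response_length:]
```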
Can we do `[-response_length:]` here?
Yes, we can and will do.
```python
for sample, t_log_probs, t_token_ids in zip(samples, teacher_log_probs, teacher_token_ids):
    sample.teacher_log_probs = t_log_probs
    sample.teacher_token_ids = t_token_ids
```
I believe we don’t need teacher token IDs. We should maintain only one token ID list per sample.
```shell
echo "Starting teacher model server..."

# Wait for the server to be ready
until curl -sf http://127.0.0.1:$TEACHER_PORT/health_generate > /dev/null; do
```
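A generic, standalone version of this wait loop; the `wait_for` helper name and the retry budget are illustrative, not part of the PR:

```shell
#!/usr/bin/env bash
# Poll a command until it succeeds or the retry budget is exhausted.
wait_for() {
  local retries=$1
  shift
  local i
  for ((i = 0; i < retries; i++)); do
    if "$@" > /dev/null 2>&1; then
      return 0
    fi
    sleep 1
  done
  return 1
}

# Example usage (hypothetical port):
# wait_for 60 curl -sf "http://127.0.0.1:13141/health_generate"
```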
The IP should use the master address or 0.0.0.0.
This example only uses one node, so it should be 127.0.0.1 or localhost.
I think it would be better to use 0.0.0.0, as users might directly copy this script and run it in a multi-node setting.
I think 0.0.0.0 won't work; it should be either the local host (127.0.0.1) or a specific remote host (xxx.xxx.xxx.xx). In this case, the teacher IP is the local host (127.0.0.1).
I'll set it like this:

```shell
teacher_IP="127.0.0.1"  # set to your teacher server's IP
teacher_port="13141"
```
slime/ray/rollout.py (outdated)
```python
if "teacher_token_ids" in samples[0].__dict__:
    train_data["teacher_token_ids"] = [sample.teacher_token_ids for sample in samples]
```
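The same `__dict__` check generalizes to any optional per-sample attribute attached by post-processing; a minimal sketch (the `Sample` class here is a stand-in, not slime's):

```python
from dataclasses import dataclass, field


@dataclass
class Sample:
    # Stand-in for slime's Sample; post-processing attaches extra
    # attributes dynamically, so they may or may not be present.
    tokens: list = field(default_factory=list)


samples = [Sample(tokens=[1, 2, 3]), Sample(tokens=[4, 5])]
for s in samples:
    s.teacher_token_ids = [t + 100 for t in s.tokens]  # illustrative values

train_data = {"tokens": [s.tokens for s in samples]}
# Forward the optional field only if the first sample carries it,
# mirroring the __dict__ check in the diff above.
if "teacher_token_ids" in samples[0].__dict__:
    train_data["teacher_token_ids"] = [s.teacher_token_ids for s in samples]
```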
```shell
RM_ARGS=(
   --custom-rm-path examples.on_policy_distillation.on_policy_distillation.reward_func
   --custom-reward-post-process-path examples.on_policy_distillation.on_policy_distillation.post_process_rewards
   --rm-url http://127.0.0.1:$TEACHER_PORT/generate
```
Also change here? I think a better script could be:

```shell
teacher_IP="0.0.0.0"
teacher_port="13141"
...
```
Formatted the code to satisfy CI; let me know if anything else is required.
```python
async def reward_func(args, sample, **kwargs):
    payload = {
        "text": sample.prompt + sample.response,
```
@ahxt we should probably use input_ids here, or there might be a discrepancy.
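Following the input_ids suggestion, a hedged sketch of how the scoring payload could be built from token IDs instead of re-tokenized text; the `StudentSample` class and `build_reward_payload` helper are hypothetical, and the field names (`input_ids`, `return_logprob`, `logprob_start_len`) follow sglang's generate API but should be verified against your sglang version:

```python
class StudentSample:  # stand-in for slime's Sample
    def __init__(self, tokens):
        self.tokens = tokens


def build_reward_payload(sample, max_new_tokens: int = 0) -> dict:
    # Send the exact token IDs the student sampled so the teacher scores
    # the same sequence, avoiding a re-tokenization mismatch that can
    # occur when concatenating prompt + response as text.
    return {
        "input_ids": list(sample.tokens),
        "sampling_params": {"max_new_tokens": max_new_tokens},
        "return_logprob": True,
        "logprob_start_len": 0,  # score every input position
    }
```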
PR: On-Policy Distillation Support

This PR introduces On-Policy Distillation to the slime framework, extending its reinforcement learning (RL) pipeline to support teacher–student distillation directly within on-policy training.

Thanks to the modular design of slime, integrating On-Policy Distillation is straightforward. In this PR, the teacher model acts as a reward model (RM) by providing teacher log probabilities as the supervision signal.

1. Add `on_policy_distillation` example folder
   - `examples/on_policy_distillation/on_policy_distillation.py`: implements `reward_func` and `post_process_rewards`
   - `run-qwen3-8B-opd.sh`: example training script with Qwen3-8B as the student model and Qwen3-32B as the teacher model
2. Advantage estimator extension (`loss.py`)
   - `on_policy_distillation` advantage estimator
3. Data pipeline integration (`rollout.py`, `data.py`)
   - `teacher_log_probs`
   - `teacher_token_ids`
4. Teacher model server