
[on-policy distillation] support and related data handling#673

Merged
yitianlian merged 5 commits into THUDM:main from ahxt:feature/on_policy_distillation
Nov 12, 2025
Conversation


@ahxt ahxt commented Nov 3, 2025

PR: On-Policy Distillation Support

This PR introduces On-Policy Distillation to the slime framework, extending its reinforcement learning (RL) pipeline to support teacher–student distillation directly within on-policy training.

Thanks to the modular design of slime, integrating On-Policy Distillation is straightforward. In this PR, the teacher model acts as a reward model (RM) by providing teacher log probabilities as the supervision signal.

1. Added example folder examples/on_policy_distillation/

  • on_policy_distillation.py — implements reward_func and post_process_rewards
  • run-qwen3-8B-opd.sh — example training script using Qwen3-8B as the student model and Qwen3-32B as the teacher model

2. Advantage Estimator Extension (loss.py)

  • Added on_policy_distillation advantage estimator
  • Computes advantages as the difference between teacher and student log probabilities
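
The estimator described above can be sketched as follows (a minimal, pure-Python stand-in with hypothetical names; the PR's actual implementation lives in loss.py and operates on torch tensors):

```python
# Hypothetical sketch of the on_policy_distillation advantage estimator:
# the per-token advantage is the teacher log prob minus the student log prob,
# computed over the response tokens of each sample.
def opd_advantages(teacher_log_probs, student_log_probs):
    """Per-token advantage = teacher logprob - student logprob.

    Both inputs are lists (one entry per sample) of equal-length
    per-token log-prob sequences covering the response tokens only.
    """
    return [
        [t - s for t, s in zip(teacher, student)]
        for teacher, student in zip(teacher_log_probs, student_log_probs)
    ]
```

A positive advantage on a token means the teacher assigns it higher probability than the student, pushing the policy gradient toward the teacher's distribution.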

3. Data Pipeline Integration (rollout.py, data.py)

  • Extended rollout data structure to include:
    • teacher_log_probs
    • teacher_token_ids
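
To make the extension concrete, here is a hypothetical shape for one extended rollout sample (the field names follow the PR; the literal values and the dict container are illustrative only):

```python
# Hypothetical shape of one rollout sample after the extension.
# Field names follow the PR; values and container are illustrative.
sample = {
    "tokens": [101, 7, 8, 9],                  # prompt + response token ids
    "loss_masks": [0, 1, 1, 1],                # train only on response tokens
    "teacher_log_probs": [-0.3, -0.1, -0.2],   # one per response token
    "teacher_token_ids": [7, 8, 9],            # kept for debugging per the thread
}
```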

4. Teacher Model Server

  • Runs a separate SGLang server to serve the teacher model and return log probabilities
  • The teacher runs only the prefill stage (no decoding), since it just scores existing tokens
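
One way to build such a prefill-only query is to ask for log probabilities while requesting zero new tokens. The payload below is a hedged sketch: the field names are assumptions modeled on the meta_info["input_token_logprobs"] access used elsewhere in this PR, not a confirmed copy of the PR's request code.

```python
# Hypothetical request builder for a prefill-only teacher query against an
# SGLang-style /generate endpoint. Payload fields are assumptions based on
# the meta_info["input_token_logprobs"] access used elsewhere in this PR.
def build_teacher_request(text):
    return {
        "text": text,
        "sampling_params": {"max_new_tokens": 0, "temperature": 0.0},
        "return_logprob": True,   # ask for per-token log probabilities
        "logprob_start_len": 0,   # score every input token
    }
```

With max_new_tokens set to 0, the server never enters the decode loop, so the teacher only pays the prefill cost.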

Comment on lines +184 to +194

#### clear after training
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

Collaborator

I think this part could be deleted


  for key, val in rollout_data.items():
-     if key == "tokens" or key == "loss_masks" or key == "sample_indices":
+     if key == "tokens" or key == "loss_masks" or key == "sample_indices" or key == "teacher_token_ids" or key == "rewards":
Collaborator

Why do we exclude key == "rewards"?

Contributor Author

I’ll add the reward back. However, I’m concerned that the reward here is meaningless—it’s just the average of the teacher log probabilities.

Comment on lines +294 to +296
teacher_log_probs = [t_log_prob.to(device=device) for t_log_prob in teacher_log_probs]
teacher_log_probs = [t_log_prob[-response_length:] for t_log_prob, response_length in zip(teacher_log_probs, response_lengths)]
advantages = [teacher_log_prob - student_log_prob for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)]
Collaborator

You should do the slice for teacher logs when creating samples.
Remove teacher token ids.

Contributor Author

Will do. The teacher token ids are for debugging.

Comment on lines +29 to +32
def post_process_rewards(args, samples: list[Sample], **kwargs):
    rewards = [sample.get_reward_value(args) for sample in samples]
    teacher_log_probs = [torch.tensor([item[0] for item in reward["meta_info"]["input_token_logprobs"][1:]], dtype=torch.float32) for reward in rewards]
    teacher_token_ids = [torch.tensor([item[1] for item in reward["meta_info"]["input_token_logprobs"][1:]], dtype=torch.int32) for reward in rewards]
Collaborator

Can we do [-response_length:] here?

Contributor Author

Yes, we can and will do that.
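
The reviewer's suggestion can be sketched like this (a pure-Python stand-in with a hypothetical helper name; the PR uses torch tensors): slice the teacher log probs down to the response span once, at sample-creation time, so the advantage estimator no longer needs to.

```python
# Hypothetical sketch of slicing teacher log probs to the response span
# at sample-creation time rather than inside the advantage estimator.
def slice_to_response(teacher_log_probs, response_length):
    """Keep only the log probs for the last `response_length` tokens."""
    return teacher_log_probs[-response_length:]
```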


for sample, t_log_probs, t_token_ids in zip(samples, teacher_log_probs, teacher_token_ids):
    sample.teacher_log_probs = t_log_probs
    sample.teacher_token_ids = t_token_ids
Collaborator

I believe we don’t need teacher token IDs. We should maintain only one token ID list per sample.

echo "Starting teacher model server..."

## Wait for the server to be ready
until curl -sf http://127.0.0.1:$TEACHER_PORT/health_generate > /dev/null; do
  sleep 1
done
Collaborator

The IP should use the master address or 0.0.0.0.

Contributor Author

This example only uses one node, so it should be 127.0.0.1 or localhost.

Collaborator

I think it would be better to use 0.0.0.0, as users might copy this script directly and run it in a multi-node setting.

Contributor Author

I think 0.0.0.0 won’t work — it should be either the local host (127.0.0.1) or a specific remote host (xxx.xxx.xxx.xx). In this case, the teacher_ip is the local host (127.0.0.1).

I’ll set it like this:

teacher_IP="127.0.0.1" # set to your teacher server's IP
teacher_port="13141"
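
For context on the 0.0.0.0 question, a common pattern is to distinguish the bind address from the connect address: a server can listen on 0.0.0.0 (all interfaces), while clients must target a concrete IP. A hypothetical sketch (variable names illustrative, launch and health-check lines commented out):

```shell
# Hypothetical sketch: separate the server bind address from the client
# connect address. 0.0.0.0 is a valid bind target but not a routable
# destination, so clients use a concrete IP.
teacher_bind="0.0.0.0"    # server listens on every interface
teacher_ip="127.0.0.1"    # same-node clients connect here
teacher_port="13141"
# python -m sglang.launch_server --host "$teacher_bind" --port "$teacher_port" ...
# curl -sf "http://$teacher_ip:$teacher_port/health_generate"
```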

Comment on lines +254 to +255
if "teacher_token_ids" in samples[0].__dict__:
    train_data["teacher_token_ids"] = [sample.teacher_token_ids for sample in samples]
Collaborator

Remove this


ahxt commented Nov 12, 2025

  1. I added the first-step loss of this example for reference:

(MegatronTrainRayActor pid=46249) step 0: {'train/loss': 0.1438787877559662, 'train/pg_loss': 0.1438787877559662, 'train/entropy_loss': 0.24523311853408813, 'train/pg_clipfrac': 0.0, 'train/ppo_kl': 0.0, 'train/train_rollout_logprob_abs_diff': 0.012880997732281685, 'train/kl_loss': 0.0, 'train/grad_norm': 1.7898532502527087, 'train/lr-pg_0': 1e-06, 'train/lr-pg_1': 1e-06}

  2. I also ran experiments with this code on the OpenThoughts3 dataset. Results (Math500, pass@1, 8 samples):
  • Qwen3-8B-Base + SFT: 76%
  • Qwen3-8B-Base + SFT + On-Policy-Distillation: 94%

RM_ARGS=(
--custom-rm-path examples.on_policy_distillation.on_policy_distillation.reward_func
--custom-reward-post-process-path examples.on_policy_distillation.on_policy_distillation.post_process_rewards
--rm-url http://127.0.0.1:$TEACHER_PORT/generate
Collaborator

Also change it here? I think a better script would be:

teacher_IP="0.0.0.0"
teacher_port="13141"
....

Collaborator

@yitianlian yitianlian left a comment

LGTM!

@yitianlian yitianlian added the ci label Nov 12, 2025

ahxt commented Nov 12, 2025

Formatted the code to satisfy CI; let me know if anything else is required.

@yitianlian yitianlian merged commit 12dd6b2 into THUDM:main Nov 12, 2025
3 of 4 checks passed
llltttwww pushed a commit to llltttwww/slime that referenced this pull request Nov 30, 2025

async def reward_func(args, sample, **kwargs):
    payload = {
        "text": sample.prompt + sample.response,
@liujiahua123123 liujiahua123123 Dec 14, 2025


@ahxt we should probably use input_ids here, or there might be discrepancy.
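
The concern here is that re-tokenizing the concatenated prompt + response text may not reproduce the rollout's original token ids (tokenization is not always stable under concatenation), misaligning teacher and student log probs. A hedged sketch of preferring token ids, assuming the teacher endpoint accepts an input_ids field (an assumption; the helper name and Sample attributes are illustrative):

```python
# Hypothetical payload builder that prefers the rollout's exact token ids
# over re-tokenized text, avoiding tokenization drift between the student
# rollout and teacher scoring. The input_ids field is an assumption about
# the teacher server's API.
def build_teacher_payload(sample):
    payload = {"sampling_params": {"max_new_tokens": 0}, "return_logprob": True}
    tokens = getattr(sample, "tokens", None)
    if tokens:
        payload["input_ids"] = list(tokens)        # exact rollout tokens
    else:
        payload["text"] = sample.prompt + sample.response  # fallback path
    return payload
```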

Yangruipis pushed a commit to rednote-ai/slime that referenced this pull request Feb 28, 2026