Skip to content

Commit 12dd6b2

Browse files
authored
[on-policy distillation] support and related data handling (#673)
1 parent 1b2fa31 commit 12dd6b2

File tree

7 files changed

+260
-3
lines changed

7 files changed

+260
-3
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import aiohttp
2+
import torch
3+
4+
from slime.utils.types import Sample
5+
6+
7+
async def reward_func(args, sample, **kwargs):
    """Ask the teacher server for token log-probs of ``prompt + response``.

    Posts a zero-generation request (``max_new_tokens=0`` with
    ``return_logprob=True``) to ``args.rm_url`` so the server only scores the
    input tokens, and returns the raw JSON reply.  The per-token log-probs
    are expected under ``meta_info.input_token_logprobs`` in that reply
    (consumed by ``post_process_rewards``).
    """
    request_body = {
        "text": sample.prompt + sample.response,
        "sampling_params": {
            "temperature": 0,
            # Generate nothing: we only want the prompt/response logprobs.
            "max_new_tokens": 0,
            "skip_special_tokens": False,
        },
        "return_logprob": True,
        # Score every input token from position 0.
        "logprob_start_len": 0,
    }
    session_kwargs = {}
    async with aiohttp.ClientSession(**session_kwargs) as http_session:
        async with http_session.post(args.rm_url, json=request_body) as response:
            response.raise_for_status()
            return await response.json()
23+
24+
25+
def post_process_rewards(args, samples: "list[Sample]", **kwargs):
    """Convert raw teacher-server replies into per-sample teacher log-probs.

    Each sample's reward value is expected to be the JSON reply produced by
    ``reward_func`` (a generate response with
    ``meta_info.input_token_logprobs``, where each item's first element is a
    log-prob).  The leading entry is skipped — presumably the first input
    token carries no log-prob; verify against the server's reply format —
    and only the trailing ``response_length`` entries are kept, i.e. the
    response portion.

    Side effect: the resulting tensor is attached to every sample as
    ``sample.teacher_log_probs``.

    Returns:
        The list of teacher log-prob tensors, twice (matching the two-value
        convention of reward post-process hooks in this codebase).
    """
    rewards = [sample.get_reward_value(args) for sample in samples]
    response_lengths = [sample.response_length for sample in samples]

    teacher_log_probs = [
        torch.tensor(
            [item[0] for item in reward["meta_info"]["input_token_logprobs"][1:]],
            dtype=torch.float32,
        )
        for reward in rewards
    ]
    # Keep only the response tokens.  A plain t[-response_length:] would
    # return the WHOLE tensor when response_length == 0 (slicing with -0 is
    # slicing with 0), so compute an explicit non-negative start index; this
    # also matches the original behavior when response_length exceeds the
    # tensor length (whole tensor).
    teacher_log_probs = [
        t_log_prob[max(t_log_prob.numel() - response_length, 0):]
        for t_log_prob, response_length in zip(teacher_log_probs, response_lengths)
    ]

    for sample, t_log_probs in zip(samples, teacher_log_probs):
        sample.teacher_log_probs = t_log_probs

    return teacher_log_probs, teacher_log_probs
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
#!/bin/bash
2+
3+
# usage: bash examples/on_policy_distillation/run-qwen3-8B-opd.sh
4+
5+
set -ex
6+
7+
8+
# Start the teacher model server
9+
TEACHER_IP="127.0.0.1" # Use localhost here, you can change it to your IP
10+
TEACHER_PORT=13141
11+
LOG_FILE="/tmp/sglang_$(head /dev/urandom | tr -dc A-Za-z0-9 | head -c 6).log"
12+
13+
## Launch the teacher model server in the background
14+
CUDA_VISIBLE_DEVICES=7 python3 -m sglang.launch_server \
15+
--model-path /root/Qwen3-32B \
16+
--host 0.0.0.0 \
17+
--port $TEACHER_PORT \
18+
--tp 1 \
19+
--chunked-prefill-size 4096 \
20+
--mem-fraction-static 0.6 \
21+
> "$LOG_FILE" 2>&1 &
22+
23+
echo "Starting teacher model server..."
24+
25+
## Wait for the teacher model server to be ready
26+
until curl -sf http://$TEACHER_IP:$TEACHER_PORT/health_generate > /dev/null; do
27+
echo "Waiting for the teacher model server to start..."
28+
tail -n 10 "$LOG_FILE"
29+
sleep 5
30+
done
31+
32+
echo "Teacher model server is up and running at $TEACHER_IP:$TEACHER_PORT."
33+
sleep 10
34+
35+
36+
export PYTHONBUFFERED=16
37+
38+
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
39+
if [ "$NVLINK_COUNT" -gt 0 ]; then
40+
HAS_NVLINK=1
41+
else
42+
HAS_NVLINK=0
43+
fi
44+
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
45+
46+
source "/root/slime/scripts/models/qwen3-8B.sh"
47+
48+
49+
CKPT_ARGS=(
50+
--hf-checkpoint /root/Qwen3-8B
51+
--ref-load /root/Qwen3-8B_torch_dist
52+
--load /root/Qwen3-8B_slime/
53+
--save /root/Qwen3-8B_slime/
54+
--save-interval 20
55+
)
56+
57+
ROLLOUT_ARGS=(
58+
--prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
59+
--input-key prompt
60+
--apply-chat-template
61+
--rollout-shuffle
62+
--num-rollout 300
63+
--rollout-batch-size 16
64+
--n-samples-per-prompt 4
65+
--rollout-max-response-len 16384
66+
--rollout-temperature 0.8
67+
68+
--global-batch-size 64
69+
--balance-data
70+
)
71+
72+
RM_ARGS=(
73+
--custom-rm-path examples.on_policy_distillation.on_policy_distillation.reward_func
74+
--custom-reward-post-process-path examples.on_policy_distillation.on_policy_distillation.post_process_rewards
75+
--rm-url http://$TEACHER_IP:$TEACHER_PORT/generate
76+
)
77+
78+
EVAL_ARGS=(
79+
# --eval-interval 20
80+
# --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl
81+
# --n-samples-per-eval-prompt 16
82+
# --eval-max-response-len 16384
83+
# --eval-top-p 0.7
84+
)
85+
86+
PERF_ARGS=(
87+
--tensor-model-parallel-size 2
88+
--sequence-parallel
89+
--pipeline-model-parallel-size 1
90+
--context-parallel-size 1
91+
--expert-model-parallel-size 1
92+
--expert-tensor-parallel-size 1
93+
94+
--recompute-granularity full
95+
--recompute-method uniform
96+
--recompute-num-layers 1
97+
98+
# --micro-batch-size 1
99+
--use-dynamic-batch-size
100+
--max-tokens-per-gpu 16384
101+
)
102+
103+
GRPO_ARGS=(
104+
--advantage-estimator on_policy_distillation
105+
--use-kl-loss
106+
--kl-loss-coef 0.00
107+
--kl-loss-type low_var_kl
108+
--entropy-coef 0.00
109+
)
110+
111+
OPTIMIZER_ARGS=(
112+
--optimizer adam
113+
--lr 1e-6
114+
--lr-decay-style constant
115+
--weight-decay 0.1
116+
--adam-beta1 0.9
117+
--adam-beta2 0.98
118+
)
119+
120+
WANDB_ARGS=(
121+
#--use-wandb
122+
# --wandb-project slime-dev
123+
# --wandb-group qwen3-8B-test
124+
# --wandb-key ${WANDB_KEY}
125+
)
126+
127+
SGLANG_ARGS=(
128+
--rollout-num-gpus-per-engine 1
129+
--sglang-mem-fraction-static 0.4
130+
)
131+
132+
133+
MISC_ARGS=(
134+
--attention-dropout 0.0
135+
--hidden-dropout 0.0
136+
--accumulate-allreduce-grads-in-fp32
137+
--attention-softmax-in-fp32
138+
--attention-backend flash
139+
)
140+
141+
142+
143+
144+
# launch the master node of ray in container
145+
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
146+
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
147+
148+
149+
ray job submit --address="http://127.0.0.1:8265" \
150+
--runtime-env-json='{
151+
"env_vars": {
152+
"PYTHONPATH": "/root/Megatron-LM/",
153+
"CUDA_DEVICE_MAX_CONNECTIONS": "1"
154+
}
155+
}' \
156+
-- python3 train.py \
157+
--actor-num-nodes 1 \
158+
--actor-num-gpus-per-node 2 \
159+
--rollout-num-gpus 4 \
160+
${MODEL_ARGS[@]} \
161+
${CKPT_ARGS[@]} \
162+
${ROLLOUT_ARGS[@]} \
163+
${OPTIMIZER_ARGS[@]} \
164+
${GRPO_ARGS[@]} \
165+
${WANDB_ARGS[@]} \
166+
${PERF_ARGS[@]} \
167+
${EVAL_ARGS[@]} \
168+
${SGLANG_ARGS[@]} \
169+
${MISC_ARGS[@]} \
170+
${RM_ARGS[@]}
171+
172+
173+
174+
####clear after training
175+
pkill -9 sglang
176+
sleep 3
177+
ray stop --force
178+
pkill -9 ray
179+
pkill -9 python
180+
sleep 3
181+
pkill -9 ray
182+
pkill -9 python
183+
184+
185+
186+
187+
188+
189+
190+
191+
192+

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc
484484
temperature=self.args.rollout_temperature,
485485
)
486486
packed_batch["cur_log_probs"] = log_probs
487-
487+
488488
shifted_logits = logits.squeeze(0)[:-1]
489489
log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
490490
probs = torch.softmax(shifted_logits, dim=-1)
@@ -554,7 +554,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc
554554

555555
entropy = torch.cat([batch["entropy"] for batch in unpacked_batches], dim=0)
556556
entropy_loss = sum_of_sample_mean(entropy, response_lengths, loss_masks)
557-
557+
558558
loss = pg_loss - self.args.entropy_coef * entropy_loss
559559

560560
if self.args.use_kl_loss:

slime/backends/megatron_utils/loss.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
286286
)
287287
returns = advantages
288288

289+
elif args.advantage_estimator == "on_policy_distillation":
290+
student_log_probs = log_probs
291+
teacher_log_probs = rollout_data.get("teacher_log_probs")
292+
response_lengths = rollout_data.get("response_lengths")
293+
device = student_log_probs[0].device
294+
teacher_log_probs = [t_log_prob.to(device=device) for t_log_prob in teacher_log_probs]
295+
teacher_log_probs = [
296+
t_log_prob[-response_length:] for t_log_prob, response_length in zip(teacher_log_probs, response_lengths)
297+
]
298+
advantages = [
299+
teacher_log_prob - student_log_prob
300+
for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)
301+
]
302+
returns = advantages
303+
289304
else:
290305
raise NotImplementedError(f"advantage_estimator {args.advantage_estimator} is not supported. ")
291306

slime/ray/rollout.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ def _convert_samples_to_train_data(self, samples: Union[list[Sample], list[list[
249249
if samples[0].train_metadata is not None:
250250
train_data["metadata"] = [sample.train_metadata for sample in samples]
251251

252+
if "teacher_log_probs" in samples[0].__dict__:
253+
train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]
254+
252255
return train_data
253256

254257

slime/utils/arguments.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,14 @@ def add_algo_arguments(parser):
672672
parser.add_argument(
673673
"--advantage-estimator",
674674
type=str,
675-
choices=["grpo", "gspo", "reinforce_plus_plus", "reinforce_plus_plus_baseline", "ppo"],
675+
choices=[
676+
"grpo",
677+
"gspo",
678+
"reinforce_plus_plus",
679+
"reinforce_plus_plus_baseline",
680+
"ppo",
681+
"on_policy_distillation",
682+
],
676683
default="grpo",
677684
)
678685
parser.add_argument(

slime/utils/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def get_partition(val):
211211
"sample_indices",
212212
"rollout_log_probs",
213213
"prompt",
214+
"teacher_log_probs",
214215
]:
215216
if key not in data:
216217
continue

0 commit comments

Comments
 (0)