Skip to content

Commit c59d8f0

Browse files
authored
[on-policy distillation] add support for on-policy distillation and the related data handling (THUDM#673)
1 parent 3aabe8d commit c59d8f0

File tree

5 files changed

+29
-3
lines changed

5 files changed

+29
-3
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc
484484
temperature=self.args.rollout_temperature,
485485
)
486486
packed_batch["cur_log_probs"] = log_probs
487-
487+
488488
shifted_logits = logits.squeeze(0)[:-1]
489489
log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
490490
probs = torch.softmax(shifted_logits, dim=-1)
@@ -554,7 +554,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc
554554

555555
entropy = torch.cat([batch["entropy"] for batch in unpacked_batches], dim=0)
556556
entropy_loss = sum_of_sample_mean(entropy, response_lengths, loss_masks)
557-
557+
558558
loss = pg_loss - self.args.entropy_coef * entropy_loss
559559

560560
if self.args.use_kl_loss:

slime/backends/megatron_utils/loss.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
286286
)
287287
returns = advantages
288288

289+
elif args.advantage_estimator == "on_policy_distillation":
290+
student_log_probs = log_probs
291+
teacher_log_probs = rollout_data.get("teacher_log_probs")
292+
response_lengths = rollout_data.get("response_lengths")
293+
device = student_log_probs[0].device
294+
teacher_log_probs = [t_log_prob.to(device=device) for t_log_prob in teacher_log_probs]
295+
teacher_log_probs = [
296+
t_log_prob[-response_length:] for t_log_prob, response_length in zip(teacher_log_probs, response_lengths)
297+
]
298+
advantages = [
299+
teacher_log_prob - student_log_prob
300+
for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)
301+
]
302+
returns = advantages
303+
289304
else:
290305
raise NotImplementedError(f"advantage_estimator {args.advantage_estimator} is not supported. ")
291306

slime/ray/rollout.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ def _convert_samples_to_train_data(self, samples: Union[list[Sample], list[list[
249249
if samples[0].train_metadata is not None:
250250
train_data["metadata"] = [sample.train_metadata for sample in samples]
251251

252+
if "teacher_log_probs" in samples[0].__dict__:
253+
train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]
254+
252255
return train_data
253256

254257

slime/utils/arguments.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,14 @@ def add_algo_arguments(parser):
672672
parser.add_argument(
673673
"--advantage-estimator",
674674
type=str,
675-
choices=["grpo", "gspo", "reinforce_plus_plus", "reinforce_plus_plus_baseline", "ppo"],
675+
choices=[
676+
"grpo",
677+
"gspo",
678+
"reinforce_plus_plus",
679+
"reinforce_plus_plus_baseline",
680+
"ppo",
681+
"on_policy_distillation",
682+
],
676683
default="grpo",
677684
)
678685
parser.add_argument(

slime/utils/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def get_partition(val):
211211
"sample_indices",
212212
"rollout_log_probs",
213213
"prompt",
214+
"teacher_log_probs",
214215
]:
215216
if key not in data:
216217
continue

0 commit comments

Comments (0)