Describe the Issue
The "Teacher Distillation" feature was found to be functionally incomplete. While the TeacherDistillationEnv correctly collected top_k token IDs and logprobs from the teacher model, this data was dropped at the Trainer level:
- Data Loader Neglect:
pad_data_to_good_offset in example_trainer/data.py did not extract or pad the distillation fields from the rollout payload.
- Missing Loss Logic:
compute_grpo_loss in example_trainer/training.py lacked the KL-divergence term required to align the student's distribution with the teacher's signals.
This made distillation a "no-op" feature: it consumed GPU resources for teacher inference but provided zero learning signal to the student.
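For reference, the missing term can be written out as follows. This is a sketch under assumptions that the issue does not confirm. It uses the re-normalized top-K form mentioned under Additional Context and the forward KL(teacher ‖ student) direction. It also assumes that distill_alpha (α) weights the distillation term against the GRPO loss. Here K is teacher_top_k, M is the set of unmasked completion positions, ℓ^T_{t,i} is the teacher logprob of the i-th top-K token at position t, and z^S_{t,i} is the student logit for that same token id.

```latex
\tilde p^{T}_{t,i} = \frac{\exp \ell^{T}_{t,i}}{\sum_{j=1}^{K} \exp \ell^{T}_{t,j}},
\qquad
\tilde p^{S}_{t,i} = \frac{\exp z^{S}_{t,i}}{\sum_{j=1}^{K} \exp z^{S}_{t,j}}

\mathcal{L}_{\text{distill}} = \frac{1}{|M|} \sum_{t \in M} \sum_{i=1}^{K}
  \tilde p^{T}_{t,i} \left( \log \tilde p^{T}_{t,i} - \log \tilde p^{S}_{t,i} \right),
\qquad
\mathcal{L} = \mathcal{L}_{\text{GRPO}} + \alpha \, \mathcal{L}_{\text{distill}}
```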
Environment/API Details
- Environment Class/Name:
gsm8k_server_teacher_distill.py / TeacherDistillationEnv
- Environment Configuration:
teacher_enabled: true
teacher_top_k: 4
distill_alpha: 0.1
- API Endpoint/Method Involved:
pad_data_to_good_offset (example_trainer/data.py) and compute_grpo_loss (example_trainer/training.py)
Steps to Reproduce
- Launch a training run with TeacherDistillationEnv.
- Observe that the teacher is queried (inference logs).
- Observe that distill_loss is missing from the metrics, or that the student model shows no improvement toward the teacher's reasoning patterns despite valid teacher rollouts.
Interaction Details (if applicable)
- Expected Behavior:
- The data loader should extract, pad, and causally shift distill_token_ids and distill_logprobs.
- The trainer should compute an on-policy KL-divergence loss between the student's logits and the teacher's top-K distribution.
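A minimal PyTorch sketch of the expected data-loader behavior follows. Several parts are assumptions rather than the real pad_data_to_good_offset. The helper name pad_and_shift_distill is hypothetical. So is the nested-list input layout, with one list of top_k entries per token. The sketch also assumes that entry t holds the teacher's distribution for token t.

```python
import torch


def pad_and_shift_distill(token_ids, logprobs, max_len, top_k):
    """Hypothetical helper (not the real pad_data_to_good_offset).

    token_ids / logprobs: one entry per sequence; each is a list of T
    per-token lists holding top_k teacher token ids / logprobs.
    Assumes entry t is the teacher's distribution for token t.

    Returns ids and logprobs of shape [B, max_len - 1, top_k] plus a
    [B, max_len - 1] bool mask, shifted so index j refers to token j + 1.
    """
    batch = len(token_ids)
    ids = torch.zeros(batch, max_len, top_k, dtype=torch.long)
    # Pad logprobs with 0.0 (not -inf) so a later log_softmax stays finite;
    # padded positions are excluded through the mask instead.
    lps = torch.zeros(batch, max_len, top_k, dtype=torch.float)
    mask = torch.zeros(batch, max_len, dtype=torch.bool)

    for b, (seq_ids, seq_lps) in enumerate(zip(token_ids, logprobs)):
        t = min(len(seq_ids), max_len)
        if t == 0:
            continue  # nothing to copy for an empty sequence
        ids[b, :t] = torch.tensor(seq_ids[:t], dtype=torch.long)
        lps[b, :t] = torch.tensor(seq_lps[:t], dtype=torch.float)
        mask[b, :t] = True

    # Causal shift: drop position 0 so index j holds the teacher's
    # distribution over token j + 1, matching student logits[:, j].
    return ids[:, 1:], lps[:, 1:], mask[:, 1:]
```

Under this convention, the shifted tensors line up with the usual logits[:, :-1] / labels[:, 1:] alignment used for next-token losses. Padded rows receive logprob 0.0 rather than -inf, so that a row of K padded entries cannot turn into NaN inside a softmax. The mask is what keeps those rows out of the loss.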
Setup Details
- OS: Linux
- Python Version: 3.10+
- Atropos Version: commit c20c852
- Relevant Libraries/Versions:
torch, vllm
Additional Context & Logs
The fix implements a re-normalized KL divergence over the top-K tokens, allowing the student to learn from the teacher's distribution without requiring full-vocab logit computation.
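A minimal PyTorch sketch of a re-normalized top-K KL term is shown below. It is an illustration, not the actual fix. The function name topk_distill_kl, the forward KL(teacher ‖ student) direction, and the mean over unmasked positions are all assumptions. The sketch also assumes the teacher fields have already been padded and causally aligned with the student logits, using the shapes given in the docstring.

```python
import torch
import torch.nn.functional as F


def topk_distill_kl(student_logits, teacher_ids, teacher_logprobs, mask):
    """Hypothetical top-K distillation loss (forward KL, teacher || student).

    student_logits:   [B, T, V] logits, position t predicting token t + 1
    teacher_ids:      [B, T, K] long, teacher's top-K token ids per position
    teacher_logprobs: [B, T, K] teacher logprobs for those ids
    mask:             [B, T] bool, True for positions that count in the loss
    Returns a scalar: mean KL over masked-in positions (0 if none).
    """
    # Pick out the student's logits at the teacher's K token ids: [B, T, K].
    student_k = torch.gather(student_logits, dim=-1, index=teacher_ids)

    # Re-normalize both sides over the K ids only. log_softmax of logprobs
    # subtracts their logsumexp, i.e. renormalizes the truncated distribution.
    s_logp = F.log_softmax(student_k.float(), dim=-1)
    t_logp = F.log_softmax(teacher_logprobs.float(), dim=-1)

    # Per-position KL(teacher || student) over the K-token support: [B, T].
    kl = (t_logp.exp() * (t_logp - s_logp)).sum(dim=-1)

    m = mask.float()
    return (kl * m).sum() / m.sum().clamp(min=1.0)
```

Inside compute_grpo_loss, one plausible way to combine it (an assumption, not the confirmed implementation) is loss = grpo_loss + distill_alpha * distill_loss, with distill_loss also reported in the metrics. Only K logits per position are gathered from the student, so the KL itself costs O(B·T·K) rather than O(B·T·V).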