Hardening: Restore Teacher Distillation Pipeline and KL Loss #456

@RUFFY-369

Describe the Issue

The "Teacher Distillation" feature was found to be functionally incomplete. While the TeacherDistillationEnv correctly collected top_k token IDs and logprobs from the teacher model, this data was dropped at the Trainer level:

  1. Data Loader Neglect: pad_data_to_good_offset in example_trainer/data.py did not extract or pad the distillation fields from the rollout payload.
  2. Missing Loss Logic: compute_grpo_loss in example_trainer/training.py lacked the KL-divergence term required to align the student's distribution with the teacher's signals.

This made distillation a "no-op" feature: it consumed GPU resources for teacher inference but provided zero learning signal to the student.

Environment/API Details

  • Environment Class/Name: gsm8k_server_teacher_distill.py / TeacherDistillationEnv
  • Environment Configuration:
    teacher_enabled: true
    teacher_top_k: 4
    distill_alpha: 0.1
    
  • API Endpoint/Method Involved: pad_data_to_good_offset in example_trainer/data.py and compute_grpo_loss in example_trainer/training.py

Steps to Reproduce

  1. Launch a training run with TeacherDistillationEnv.
  2. Observe that the teacher is queried (inference logs).
  3. Observe that distill_loss is missing from the metrics, or that the student model shows no improvement toward the teacher's reasoning patterns despite valid teacher rollouts.

Interaction Details (if applicable)

  • Expected Behavior:
    1. The data loader should extract, pad, and causally shift distill_token_ids and distill_logprobs.
    2. The trainer should compute an on-policy KL-divergence loss between the student's logits and the teacher's top-K distribution.
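The data-loader half of the expected behavior can be sketched as follows. This is a minimal illustration, not the actual pad_data_to_good_offset code: the helper name pad_and_shift_distill, the padding values, and the returned mask are all assumptions made for the example. It pads the per-position top-k teacher fields to the padded sequence length, then drops the first position so that the teacher data lines up with next-token labels.

```python
from typing import List, Tuple


def pad_and_shift_distill(
    distill_token_ids: List[List[int]],
    distill_logprobs: List[List[float]],
    padded_len: int,
    top_k: int,
    pad_id: int = 0,           # assumed padding id, not taken from the repo
    pad_logprob: float = 0.0,  # assumed filler; padded rows are masked out anyway
) -> Tuple[List[List[int]], List[List[float]], List[int]]:
    """Pad per-position top-k teacher data to padded_len, then shift by one."""
    if len(distill_token_ids) != len(distill_logprobs):
        raise ValueError("token ids and logprobs must cover the same positions")
    n_real = len(distill_token_ids)
    if n_real > padded_len:
        raise ValueError("padded_len is shorter than the sequence")

    n_pad = padded_len - n_real
    ids = [list(row) for row in distill_token_ids] + [[pad_id] * top_k for _ in range(n_pad)]
    lps = [list(row) for row in distill_logprobs] + [[pad_logprob] * top_k for _ in range(n_pad)]
    mask = [1] * n_real + [0] * n_pad  # 1 = real teacher data, 0 = padding

    # Causal shift: the logits at position t predict token t + 1, so the
    # teacher data for position 0 has no matching student prediction.
    return ids[1:], lps[1:], mask[1:]
```

The mask lets the loss skip padded positions instead of relying on a sentinel logprob value.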

Setup Details

  • OS: Linux
  • Python Version: 3.10+
  • Atropos Version: commit c20c852
  • Relevant Libraries/Versions: torch, vllm

Additional Context & Logs

The fix implements a re-normalized KL divergence over the top-K tokens, allowing the student to learn from the teacher's distribution without requiring full-vocab logit computation.
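To make the re-normalization concrete, here is a small pure-Python sketch for a single token position. It is an illustration under stated assumptions, not the code from the fix: the function names are hypothetical, and the direction KL(teacher || student) is assumed because the issue does not state which direction is used. Both distributions are restricted to the teacher's top-K token IDs and re-normalized to sum to 1 over those K entries, so no softmax over the full vocabulary is needed.

```python
import math
from typing import List


def log_softmax(values: List[float]) -> List[float]:
    """Numerically stable log-softmax over a short list."""
    m = max(values)
    lse = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - lse for v in values]


def topk_renormalized_kl(
    student_logits: List[float],
    teacher_token_ids: List[int],
    teacher_logprobs: List[float],
) -> float:
    """KL(teacher || student) over the teacher's top-K tokens (assumed direction)."""
    # Select the student's logits at the teacher's top-K ids. A log-softmax
    # over just these K values equals the full-vocab distribution
    # re-normalized to the top-K set, because the normalizer cancels.
    student_topk = [student_logits[i] for i in teacher_token_ids]
    log_q = log_softmax(student_topk)

    # Teacher top-K logprobs generally sum to less than 1 in probability
    # space; log-softmax re-normalizes them over the same K entries.
    log_p = log_softmax(teacher_logprobs)

    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

One plausible way to use this, again as an assumption, is to average the per-position values over unmasked positions and add the result to the GRPO loss scaled by distill_alpha.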
