On-Policy Distillation

Overview

On-policy distillation trains the student using teacher guidance on trajectories sampled from its own policy, reducing distribution mismatch and improving stability. Combined with reinforcement learning, it lets the student imitate the teacher while exploring simultaneously.

AReaL previously supported RL for post-training. With this implementation, it now also supports on-policy knowledge distillation and the combined KDRL framework, enabling the student to learn from a teacher while exploring via RL on the same on-policy trajectories, improving both efficiency and stability.

The Core Concept

Knowledge distillation aims to train the student policy $\pi_\theta$ to mimic the behavior of a more powerful teacher $\pi_T$. The choice of divergence measure and sampling distribution significantly impacts the student's final performance and exposure bias.

Supervised Fine-Tuning (Forward KL):

A simple yet effective method is to maximize the log-likelihood on data generated by the teacher, known as supervised fine-tuning (SFT). This is equivalent to minimizing the Forward KL divergence between $\pi_T$ and $\pi_\theta$: $$\arg \min_{\theta} D_{KL}(\pi_T \parallel \pi_\theta) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_T(\cdot|q)} [\log \pi_\theta(o|q)]$$
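The equivalence above holds because the expected SFT loss (a cross-entropy) differs from the Forward KL only by the teacher's entropy, which does not depend on $\theta$. The following is a minimal sketch on toy categorical distributions; the function names and the example probabilities are illustrative assumptions, not part of AReaL.

```python
import math

def forward_kl(p_teacher, p_student):
    """D_KL(teacher || student) for two categorical distributions."""
    return sum(pt * math.log(pt / ps)
               for pt, ps in zip(p_teacher, p_student) if pt > 0)

def expected_nll(p_teacher, p_student):
    """Expected SFT loss: -E_{o ~ teacher}[log student(o)]."""
    return -sum(pt * math.log(ps)
                for pt, ps in zip(p_teacher, p_student) if pt > 0)

def entropy(p):
    """Shannon entropy of a categorical distribution (in nats)."""
    return -sum(x * math.log(x) for x in p if x > 0)

# Toy distributions over three possible outputs (made-up numbers).
teacher = [0.7, 0.2, 0.1]
student_a = [0.5, 0.3, 0.2]
student_b = [0.6, 0.25, 0.15]

# The gap between the SFT loss and the forward KL is the teacher entropy,
# so it is the same for every student and both objectives share a minimizer.
gap_a = expected_nll(teacher, student_a) - forward_kl(teacher, student_a)
gap_b = expected_nll(teacher, student_b) - forward_kl(teacher, student_b)
```

Here `gap_a` and `gap_b` both equal `entropy(teacher)`, regardless of which student is plugged in.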

On-Policy Distillation (Reverse KL):

While SFT is efficient, training on off-policy data induces exposure bias: a mismatch between training on teacher-generated prefixes and inference on self-generated prefixes. This is especially severe for reasoning LLMs with long response chains. To alleviate this, we can train on self-generated trajectories, which is equivalent to minimizing the Reverse KL divergence (RKL) [1]: $$\arg \min_{\theta} D_{KL}(\pi_\theta \parallel \pi_T) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_\theta(\cdot|q)} \left[ \log \frac{\pi_T(o|q)}{\pi_\theta(o|q)} \right]$$
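Because the expectation is taken over the student's own samples, the Reverse KL can be estimated without ever sampling from the teacher: only teacher log-probabilities on student outputs are needed. A minimal sketch on a toy categorical distribution, with hypothetical function names and made-up probabilities:

```python
import math
import random

def reverse_kl(p_student, p_teacher):
    """Exact D_KL(student || teacher) for two categorical distributions."""
    return sum(ps * math.log(ps / pt)
               for ps, pt in zip(p_student, p_teacher) if ps > 0)

def mc_reverse_kl(p_student, p_teacher, num_samples, seed=0):
    """Monte Carlo estimate of the reverse KL from student samples only."""
    rng = random.Random(seed)
    samples = rng.choices(range(len(p_student)), weights=p_student, k=num_samples)
    # Average log(student / teacher) over outputs drawn from the student.
    log_ratios = (math.log(p_student[o] / p_teacher[o]) for o in samples)
    return sum(log_ratios) / num_samples

# Toy distributions over three possible outputs (made-up numbers).
student = [0.5, 0.3, 0.2]
teacher = [0.7, 0.2, 0.1]
```

With enough samples, `mc_reverse_kl(student, teacher, n)` approaches `reverse_kl(student, teacher)`.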

Minimizing RKL is equivalent to REINFORCE where the "reward" is the log-ratio of teacher to student probabilities. By adopting the GRPO framework, we optimize [1]:

$$J_{RKL}(\theta) = \mathbb{E}_{q, \{o_i\} \sim \pi_{\theta_{old}}} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \frac{\pi_\theta(o_{i,t})}{\pi_{\theta_{old}}(o_{i,t})} R_{i,t} \right]$$

where the reward $R_{i,t} = \log \pi_T(o_{i,t}) - \log \pi_\theta(o_{i,t})$. This encourages the student to increase the probability of tokens the teacher prefers and suppress those it deems unlikely.
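The REINFORCE equivalence stated above follows from the score-function identity. As a short sketch of the sequence-level argument, writing $R(o) = \log \pi_T(o|q) - \log \pi_\theta(o|q)$:

$$\begin{aligned} \nabla_\theta \mathbb{E}_{o \sim \pi_\theta(\cdot|q)} [R(o)] &= \mathbb{E}_{o \sim \pi_\theta(\cdot|q)} \left[ R(o) \nabla_\theta \log \pi_\theta(o|q) \right] - \mathbb{E}_{o \sim \pi_\theta(\cdot|q)} \left[ \nabla_\theta \log \pi_\theta(o|q) \right] \\ &= \mathbb{E}_{o \sim \pi_\theta(\cdot|q)} \left[ R(o) \nabla_\theta \log \pi_\theta(o|q) \right] \end{aligned}$$

The second term vanishes because $\mathbb{E}_{o \sim \pi_\theta} [\nabla_\theta \log \pi_\theta(o|q)] = \nabla_\theta \sum_o \pi_\theta(o|q) = 0$, so the log-ratio can be treated as a fixed reward when computing the policy gradient.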

  • Implementation Detail: For pure KD, set rl_loss_weight to 0. The implementation then estimates the RKL gradient using importance sampling: it computes the reward as teacher_logp - logprobs ($R_{i,t}$) and negates the resulting objective so that minimizing the loss maximizes $J_{RKL}$ (see areal/trainer/ppo/actor.py).
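The objective $J_{RKL}$ can be sketched in plain Python as follows. This is an illustration of the formula only, not the code in areal/trainer/ppo/actor.py: the function name and the nested-list layout (one list of per-token log-probabilities per response) are assumptions, and in an autograd framework the reward would be detached so that gradients flow only through the importance ratio.

```python
import math

def rkl_surrogate_loss(logprobs, old_logprobs, teacher_logp):
    """Negative of J_RKL for one group of G sampled responses.

    Each argument is a list of responses, and each response is a list of
    per-token log-probabilities under the current student, the student
    that generated the samples, and the teacher, respectively.
    """
    total = 0.0
    for lp, old_lp, t_lp in zip(logprobs, old_logprobs, teacher_logp):
        seq_sum = 0.0
        for cur, old, teach in zip(lp, old_lp, t_lp):
            ratio = math.exp(cur - old)   # pi_theta / pi_theta_old
            reward = teach - cur          # R_{i,t}, treated as a constant
            seq_sum += ratio * reward
        total += seq_sum / len(lp)        # 1/|o_i| length normalization
    # Negate so that minimizing the loss maximizes J_RKL.
    return -total / len(logprobs)

# Two responses of different lengths; the student has not been updated yet,
# so old_logprobs equals logprobs and every importance ratio is 1.
logprobs = [[-1.0, -2.0], [-0.5]]
teacher_logp = [[-0.5, -1.0], [-1.5]]
loss = rkl_surrogate_loss(logprobs, logprobs, teacher_logp)
```

In this toy group the teacher prefers the tokens of the first response more than the student does (positive rewards) and the token of the second response less (negative reward).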

Combination of GRPO and KD

We implement the KD+RL approach using a joint loss strategy.

Joint Loss:

This strategy augments the GRPO objective with an auxiliary KL loss term. To maintain consistency with the on-policy nature of GRPO, it utilizes the Reverse KL (RKL) [1]: $$J_{KDRL}(\theta) = J_{GRPO}(\theta) - \beta D_{KL}(\pi_\theta \parallel \pi_T)$$

The gradient $\nabla_\theta J_{KDRL}(\theta)$ provides an unbiased estimate of $\nabla_\theta J_{GRPO}( \theta) + \beta \cdot \nabla_\theta J_{RKL}(\theta)$.

  • Implementation Detail: In the joint loss case (rl_loss_weight > 0), the RKL is treated as a direct penalty. Minimizing the term logprobs - teacher_logp is mathematically equivalent to minimizing the Reverse KL objective $D_{KL}(\pi_\theta \parallel \pi_T)$ when sampling from the student distribution $\pi_\theta$. In the code, this is implemented as: loss = rl_loss_weight * loss + distill_loss_weight * rkl_penalty
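The joint loss can be illustrated with a small sketch. Only the final weighted sum mirrors the expression quoted above; the helper names, the flat per-token lists, and the mean reduction over tokens are assumptions made for illustration and may differ from the actual AReaL code.

```python
def rkl_penalty(logprobs, teacher_logp):
    """Mean of (logprobs - teacher_logp) over tokens sampled from the student.

    On student samples this is a Monte Carlo estimate of
    D_KL(student || teacher).
    """
    diffs = [lp - tlp for lp, tlp in zip(logprobs, teacher_logp)]
    return sum(diffs) / len(diffs)

def joint_loss(rl_loss, logprobs, teacher_logp,
               rl_loss_weight, distill_loss_weight):
    """Weighted sum of the RL loss and the reverse-KL penalty."""
    penalty = rkl_penalty(logprobs, teacher_logp)
    return rl_loss_weight * rl_loss + distill_loss_weight * penalty

# Made-up values: a scalar GRPO loss and per-token log-probabilities.
example = joint_loss(
    rl_loss=2.0,
    logprobs=[-1.0, -2.0],
    teacher_logp=[-2.0, -3.0],
    rl_loss_weight=1.0,
    distill_loss_weight=0.005,
)
```

With these weights the penalty contributes only a small correction to the RL loss, which matches the role of $\beta$ in $J_{KDRL}$.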

Running the example

Add a teacher configuration to your YAML config file:

teacher:
  backend: fsdp:d1p1t4
  rl_loss_weight: 1.0
  distill_loss_weight: 0.005
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: Qwen/Qwen3-32B
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer: null
  scheduling_spec: ${actor.scheduling_spec}

Example command using the local scheduler:

python3 examples/math/gsm8k_rl.py --config examples/distillation/gsm8k_grpo_distill.yaml scheduler.type=local experiment_name=gsm8k-grpo-distillation trial_name=trial0

Result

On-policy knowledge distillation + RL reward plot for Qwen2.5-14B-Instruct (teacher) and Qwen3-0.6B (student), trained using FSDP and vLLM.

[Figure: reward plot]

References

[1] Xu H, Zhu Q, Deng H, Li J, Hou L, Wang Y, Shang L, Xu R, Mi F. KDRL: Post-Training Reasoning LLMs via Unified Knowledge Distillation and Reinforcement Learning.