[fsdp,vllm,trainer,algo] feat: On-Policy Distillation#4897
JacobHelwig wants to merge 167 commits into verl-project:main
Conversation
Code Review
This pull request appears to be a work-in-progress for adding on-policy distillation support. The only change present is a minor whitespace modification in the README.md file. As this is a stylistic change with no functional impact, and I am configured to only report issues of high or critical severity, I have no specific comments on the current state of the pull request. I look forward to reviewing more substantial changes as they are added.
Why is the reward so small (almost 0) before step 40?
Training only explicitly optimizes the distillation loss, not rewards:
Any increase in the logged rewards (GSM8K accuracy) is an indirect result of minimizing the distillation loss. In this case, the reason that the base model has Pass@1 ≈ 0 is that the default GSM8K answer formatting ( ...
```
reward_model.reward_manager=remote \
custom_reward_function.path=tests/experimental/reward_loop/reward_fn.py \
custom_reward_function.name=compute_score_math_verify \
trainer.val_only=True
```

The results are:

```
(TaskRunner pid=904198) ("Initial validation metrics: {'val-aux/openai/gsm8k/reward/mean@1': "
(TaskRunner pid=904198) "np.float64(0.31766489764973466), 'val-core/openai/gsm8k/acc/mean@1': "
(TaskRunner pid=904198) "np.float64(0.31766489764973466), 'val-aux/num_turns/min': np.int32(2), "
(TaskRunner pid=904198) "'val-aux/num_turns/max': np.int32(2), 'val-aux/num_turns/mean': "
(TaskRunner pid=904198) 'np.float64(2.0)}')
```

The formatting is only a few tokens, so it does not contribute much to the distillation loss. The distillation loss initially focuses on minimizing other discrepancies between the teacher and student distributions before targeting formatting, which is why early steps of training show 0% accuracy under the stricter parser.
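To make the strict-vs-lenient parsing difference concrete, here is a small sketch with our own simplified regexes (these are illustrations, not the actual reward functions in the repo):

```python
import re

def strict_gsm8k_answer(text: str):
    # GSM8K's canonical format puts the final answer after "#### ".
    m = re.search(r"####\s*(-?[\d,\.]+)", text)
    return m.group(1).replace(",", "") if m else None

def lenient_answer(text: str):
    # Lenient fallback (hypothetical): take the last number in the text.
    nums = re.findall(r"-?\d[\d,]*\.?\d*", text)
    return nums[-1].replace(",", "") if nums else None

resp = "She sells 16 - 3 - 4 = 9 eggs, earning 9 * 2 = 18 dollars."
assert strict_gsm8k_answer(resp) is None   # unparseable format -> scored 0
assert lenient_answer(resp) == "18"        # correct answer is recoverable
```

A base model producing the response above is right about the arithmetic but scores 0 under the strict parser, which is exactly the Pass@1 ≈ 0 effect described.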
/gemini review
Code Review
This pull request introduces on-policy distillation (OPD) as a new feature, which is a significant addition to the library. The implementation is comprehensive, adding new configurations, loss functions, and utility functions, and integrating them into the existing PPO trainer pipeline. My review has identified several critical and high-severity issues related to correctness, robustness, and efficiency that should be addressed. These include a bug that could cause a crash when not using use_remove_padding, an invalid default configuration value, and incorrect handling of zero-length prompts. Additionally, there are some type hint inaccuracies and unnecessary tensor cloning that impact code quality and performance.
verl/trainer/distillation/utils.py (outdated):

```python
def compute_topk_distillation_inputs(logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor, config: DistillationConfig):
    """TODO: Docstring"""
    # Gather inputs for top-k distillation losses.
    logits = logits.squeeze(0)
```
The use of logits.squeeze(0) assumes that the logits tensor always has a batch dimension of size 1. This holds true when use_remove_padding is enabled, as the input is reshaped to (1, total_tokens, ...). However, when use_remove_padding is disabled, logits will have a shape of (batch_size, seq_len, vocab_size). If batch_size is greater than 1, squeeze(0) will raise an error. This will prevent distillation from working correctly in this configuration.
A more robust approach would be to reshape the tensor to flatten the batch and sequence dimensions if they exist, rather than squeezing a specific dimension.
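As a concrete illustration of this suggestion (a sketch under the review's assumptions; flatten_logits is a hypothetical helper, not part of the PR), flattening the leading dimensions handles both layouts:

```python
import torch

def flatten_logits(logits: torch.Tensor) -> torch.Tensor:
    # Collapse all leading batch/sequence dims into a single token dim,
    # instead of assuming a batch dimension of size 1 with squeeze(0).
    return logits.reshape(-1, logits.shape[-1])

# rmpad layout: (1, total_tokens, vocab)
rmpad = torch.randn(1, 32, 100)
# padded layout: (batch, seq, vocab) with batch > 1, where squeeze(0)
# would leave the batch dim in place and break downstream shapes
padded = torch.randn(4, 8, 100)

assert flatten_logits(rmpad).shape == (32, 100)
assert flatten_logits(padded).shape == (32, 100)
```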
```python
sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
```
The assertion assert not prompt_lens.eq(0).any() will cause a crash if any of the prompts in a batch have a length of 0. While this might be an uncommon case, the code should handle it gracefully instead of crashing. The slicing logic values[seq_offset - resp_len - 1 : seq_offset - 1] relies on prompt_len > 0.
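To see why the slicing assumes a non-empty prompt, here is a small packed-values example (our own illustration; variable names follow the snippet above):

```python
import torch

# Two packed sequences: (prompt_len=3, resp_len=2) and (prompt_len=2, resp_len=3).
prompt_lens = torch.tensor([3, 2])
response_lens = torch.tensor([2, 3])
values = torch.arange(10)  # one value per packed token position

sequence_offsets = (prompt_lens + response_lens).cumsum(dim=0)  # [5, 10]

# values[seq_offset - resp_len - 1 : seq_offset - 1] picks the positions
# that predict each response token (shifted left by one, hence the -1).
chunks = [
    values[off - rlen - 1 : off - 1]
    for off, rlen in zip(sequence_offsets.tolist(), response_lens.tolist())
]
assert chunks[0].tolist() == [2, 3]
assert chunks[1].tolist() == [6, 7, 8]

# With prompt_len == 0, seq_offset == resp_len and the start index is -1,
# which Python treats as "last element", producing an empty slice instead of
# the intended positions -- hence the assertion in the reviewed code.
empty_prompt_slice = values[3 - 3 - 1 : 3 - 1]  # values[-1:2]
assert empty_prompt_slice.numel() == 0
```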
/gemini review
Code Review
This pull request introduces on-policy distillation capabilities, including top-k and KL estimator-based distillation losses. The changes are well-structured, with new configurations, loss functions, and utility modules. The integration into the existing PPO trainer and FSDP engine seems correct. I've identified one area for improvement regarding the use of __post_init__ in a class with multiple inheritance, which could be made more robust.
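The __post_init__ robustness point can be illustrated with a standalone sketch (our own example classes, not the PR's): with dataclass inheritance, only the most-derived __post_init__ runs, so each class should chain to the next one in the MRO:

```python
from dataclasses import dataclass

@dataclass
class LossMixin:
    loss_mode: str = "k3"

    def __post_init__(self):
        if self.loss_mode not in {"k3", "jsd_topk"}:
            raise ValueError(f"unknown loss_mode: {self.loss_mode}")
        # Cooperative chaining: call the next __post_init__ in the MRO, if any.
        parent_post = getattr(super(), "__post_init__", None)
        if parent_post is not None:
            parent_post()

@dataclass
class DistillConfigSketch(LossMixin):
    top_k: int = 8

    def __post_init__(self):
        if self.top_k <= 0:
            raise ValueError("top_k must be positive")
        super().__post_init__()  # run LossMixin's validation too

cfg = DistillConfigSketch()  # both validations run without error
```

Without the super() calls, a subclass's __post_init__ silently shadows the base class's validation, which is the fragility the review points at.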
/gemini review
Code Review
This pull request introduces a significant new feature: on-policy distillation. The implementation is well-designed, with a clear separation of concerns between configuration, data processing stages, and loss computation. The use of a registry for distillation losses is a good practice for extensibility. I appreciate the addition of unit tests for the new utility functions and the inclusion of an example script. The new validation checks in the configuration are also a great addition to prevent misuse. I have one high-severity finding in the example script where a variable is overwritten, which could lead to incorrect experiment execution.
```shell
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
```
The DISTILLATION_LOSS_MODE variable is defined twice. The second definition on line 26 will always overwrite the first one on line 25. This will cause the script to run with loss "k3" regardless of the intention to test jsd_topk. To fix this, you should comment out or remove one of the definitions.
Suggested change:

```diff
-DISTILLATION_LOSS_MODE="jsd_topk"
+#DISTILLATION_LOSS_MODE="jsd_topk"
 DISTILLATION_LOSS_MODE="k3"
```
/gemini review
Code Review
This pull request introduces on-policy distillation, a significant new feature. The implementation includes new configuration options, distillation loss functions (top-k and KL estimators), and integrates them into the FSDP engine and PPO trainer. The changes are extensive and well-structured, with new modules for distillation logic and tests for utility functions. My review found a couple of issues in the configuration and example script that need to be addressed. Overall, this is a solid contribution that adds valuable new capabilities.
```shell
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="reverse_kl_topk+"
```
The DISTILLATION_LOSS_MODE variable is being redefined on consecutive lines. This means only the last value, "reverse_kl_topk+", will be effective when the script is run. This is likely not the intended behavior and can lead to confusion or incorrect experiments. To make it easier to switch between different loss modes, you should comment out the inactive options.
Suggested change:

```diff
-DISTILLATION_LOSS_MODE="jsd_topk"
-DISTILLATION_LOSS_MODE="k3"
+# DISTILLATION_LOSS_MODE="jsd_topk"
+# DISTILLATION_LOSS_MODE="k3"
 DISTILLATION_LOSS_MODE="reverse_kl_topk+"
```
/gemini review
Code Review
This pull request introduces on-policy distillation to the FSDP engine, adding support for top-k and KL estimator-based distillation losses. The changes are extensive, touching configuration, training loop logic, worker implementations, and loss calculations. A critical bug was found in the distillation loss calculation where a variable was used without being initialized in all code paths, which would lead to a runtime error. The rest of the implementation, including the complex data flow for distillation and the new validation checks, appears solid.
```python
sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
```
This assertion correctly identifies that the slicing logic seq_offset - resp_len - 1 assumes prompt_len > 0. However, if a dataset contains empty prompts, this will crash the training. It would be more robust to handle the case of prompt_len == 0 gracefully within the slicing logic instead of asserting. If empty prompts are not expected and should be filtered, this check is fine, but handling it in the code would prevent unexpected failures with new datasets.
/gemini review
Code Review
This pull request introduces on-policy distillation capabilities, a significant and well-implemented feature. The changes are extensive, touching upon configuration, loss functions, worker implementations, and core training logic. The addition of top-k and KL estimator distillation losses is comprehensive. The code is well-structured, with new functionalities organized into a distillation module and supported by new tests and example scripts. My review identified one high-severity issue concerning a latent bug in a utility function that affects handling of empty prompts, which is currently mitigated by an assertion. Addressing this would improve the robustness of the implementation. Overall, this is a solid contribution.
```python
sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
```
The assertion assert not prompt_lens.eq(0).any() correctly prevents a bug in the slicing logic for prompts of length 0. However, this also disallows what could be a valid use case (empty prompts).

The slicing logic on line 146, values[seq_offset - resp_len - 1 : seq_offset - 1], is incorrect when prompt_len is 0, as the start index becomes -1, which leads to incorrect behavior.

While the assertion prevents a crash, it would be more robust to:

- Change the assertion to a ValueError with a clear error message for users.
- Ideally, fix the underlying slicing logic to correctly handle empty prompts if they are expected to be supported.
A more robust check would be:

```python
if torch.any(prompt_lens == 0):
    raise ValueError(
        "Prompts with length 0 are not supported by the current slicing logic. "
        "Please filter out empty prompts from your dataset or update the slicing logic."
    )
```
@wuxibin89 This PR does not support VLM due to the reliance on ...
What does this PR do?
Adds on-policy distillation to FSDP engine with top-k distillation loss and KL estimator distillation losses with supervised and PG-style updates. Teacher logprobs are computed using a vLLM teacher server.
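A minimal sketch of the top-k distillation loss idea (our own simplification, assuming a forward KL restricted to the teacher's top-k tokens with per-token clamping; this is not the PR's exact implementation):

```python
import torch
import torch.nn.functional as F

def topk_forward_kl(student_logits, teacher_logits, k=8, clamp_max=10.0):
    """Forward KL, sum p * (log p - log q), restricted to the teacher's top-k tokens."""
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    student_logp = F.log_softmax(student_logits, dim=-1)
    top_logp, top_idx = teacher_logp.topk(k, dim=-1)      # teacher's top-k per position
    top_student_logp = student_logp.gather(-1, top_idx)   # student logprobs at same token ids
    per_token = (top_logp.exp() * (top_logp - top_student_logp)).sum(dim=-1)
    # Clamp the per-token loss for stability (the PR notes clamping was needed).
    return per_token.clamp(max=clamp_max).mean()

logits = torch.randn(5, 50)
assert abs(topk_forward_kl(logits, logits).item()) < 1e-6  # identical distributions -> ~0
```

In practice the teacher logprobs would come from the vLLM teacher server rather than a locally computed log_softmax, but the loss shape is the same.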
Losses
Updates
Test
Tested with examples/on_policy_distillation_trainer/run_qwen_gsmk8k.sh.

Main results
These experiments compare 3 training runs with student model Qwen2.5-0.5B:
(Plots: GSM8K eval acc, GSM8K train acc, distillation loss.)
Top-k training stability
Clamping the top-k forward KL loss was needed for training stability. These experiments compare 3 types of clamping:
(Plots: distillation loss, GSM8K eval acc, GSM8K train acc.)
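The specific clamping variants compared are shown in the plots above; as a generic illustration of what clamping a per-token loss can mean (illustrative names and modes of our own, not the PR's actual three variants):

```python
import torch

def clamp_per_token_loss(per_token: torch.Tensor, mode: str = "clamp_value", max_val: float = 10.0):
    # Generic per-token clamping strategies for a distillation loss.
    if mode == "none":
        return per_token
    if mode == "clamp_value":   # cap each token's loss at max_val
        return per_token.clamp(max=max_val)
    if mode == "mask":          # zero out tokens whose loss exceeds max_val
        return torch.where(per_token <= max_val, per_token, torch.zeros_like(per_token))
    raise ValueError(f"unknown mode: {mode}")

x = torch.tensor([0.5, 3.0, 50.0])
assert clamp_per_token_loss(x, "clamp_value").tolist() == [0.5, 3.0, 10.0]
assert clamp_per_token_loss(x, "mask").tolist() == [0.5, 3.0, 0.0]
```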
Teacher server results
The initial implementation of OPD in this PR treated the teacher the same as the reference model: teacher logprobs were calculated using the ActorRolloutRef engine worker. Outside of this section, all results in the PR description use the old engine worker teacher with the supervised version of the update.

The following results are for the current version of the PR, which computes teacher logprobs using a vLLM worker teacher. Each uses Qwen2.5-0.5B as the student and Qwen2.5-3B-Instruct as the teacher, with the max value of the distillation loss clamped to 10.
While purple seems best, it is also generating responses that exceed the maximum response length of 512.
(Plots: distillation loss, GSM8K eval acc, GSM8K train acc, response length.)
Note on reverse KL
Initially, this PR included top-k reverse KL and top-k Jensen-Shannon divergences (JSD interpolates between forward and reverse KL). For the student distribution $q$ and teacher distribution $p$, the top-$k$ reverse KL is given by

$$\sum_{i \in \text{top-}k(p)} q_i \log \frac{q_i}{p_i}$$
Unfortunately, this was unstable. The reason is that one way to make this loss small is to make $q_i$ as small as possible for all $i$ in the top-$k$ set. This can be seen from the logs tracking the amount of mass captured in the top-$k$ probabilities.
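The degenerate minimum can be demonstrated numerically (a self-contained sketch; topk_reverse_kl is our illustration, not the PR's code):

```python
import torch

def topk_reverse_kl(student_logits, teacher_logits, k=2):
    # Reverse KL, sum_i q_i * (log q_i - log p_i), over the teacher's top-k tokens.
    q = torch.softmax(student_logits, dim=-1)
    p = torch.softmax(teacher_logits, dim=-1)
    top_p, idx = p.topk(k, dim=-1)
    top_q = q.gather(-1, idx)
    return (top_q * (top_q.log() - top_p.log())).sum(dim=-1)

teacher = torch.tensor([2.0, 1.0, 0.0, 0.0])
matched = teacher.clone()                          # student matches the teacher
degenerate = torch.tensor([0.0, 0.0, 0.0, 10.0])   # mass pushed outside the top-k

# Matching the teacher gives loss 0, but draining student mass from the top-k
# tokens drives the truncated loss *below* 0 -- the degenerate optimum.
assert topk_reverse_kl(matched, teacher).item() < 1e-6
assert topk_reverse_kl(degenerate, teacher).item() < topk_reverse_kl(matched, teacher).item()
```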
Ablation: performance with more lenient parser
Note that the only loss used is the distillation loss (no rewards for correctness on GSM8K). Any increase in the logged rewards (GSM8K accuracy) is an indirect result of minimizing the distillation loss. The reason that the base model has Pass@1 ≈ 0 is that the default GSM8K answer formatting (`#### 42`) is OOD for the model. The base model is answering the questions correctly, but using incorrect formatting, so none of the answers can be parsed. The base model can be evaluated using a reward function that is more lenient on formatting by adding the following to the script:

```
reward_model.reward_manager=remote \
custom_reward_function.path=tests/experimental/reward_loop/reward_fn.py \
custom_reward_function.name=compute_score_math_verify \
trainer.val_only=True
```

The results are:
Design & Code Changes
ppo_loss function