[vllm,trainer,algo] feat: Enable On-Policy Distillation for VLM #5592

JacobHelwig wants to merge 180 commits into verl-project:main
Conversation
Code Review
This pull request introduces a significant feature: on-policy distillation (OPD) for vLLM, with a focus on Vision Language Models. The core of the change is a new compute_logprobs method on the vLLMHttpServer that allows for scoring student-generated tokens with a teacher model, even for multi-modal inputs. The implementation includes a new distillation module with configurable loss functions, a TeacherModelManager for handling teacher model replicas, and updates across the training pipeline to integrate this new capability. The changes are extensive and well-thought-out. My review includes a couple of suggestions to improve the robustness and usability of the new functionality.
```diff
 assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"

 response_list = []
 # Skip padding dimensions after sequence dimensions, if any.
 skip_padding = (0, 0) * (values.ndim - 1)
 for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
     pad_size = max_response_len - resp_len
     # left-shift model output by one token for log_probs/values
-    response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (0, pad_size)))
+    response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (*skip_padding, 0, pad_size)))
```
This assertion prevents the function from handling empty prompts, which might be a valid use case. The indexing seq_offset - resp_len - 1 on line 127 is incorrect when prompt_len is 0, as there is no token preceding the response to compute the first log-probability.
To make this function more robust, consider removing the assertion and handling the empty prompt case explicitly. When the prompt is empty, you could pad for the missing first log-probability.
Suggested change:

```diff
-assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
 response_list = []
 # Skip padding dimensions after sequence dimensions, if any.
 skip_padding = (0, 0) * (values.ndim - 1)
-for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
+for i, (resp_len, seq_offset) in enumerate(zip(response_lens, sequence_offsets, strict=True)):
     pad_size = max_response_len - resp_len
-    # left-shift model output by one token for log_probs/values
-    response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (*skip_padding, 0, pad_size)))
+    if resp_len == 0:
+        shape = (0,) + values.shape[1:]
+        response_list.append(torch.zeros(shape, dtype=values.dtype, device=values.device))
+        continue
+    if prompt_lens[i] > 0:
+        # left-shift model output by one token for log_probs/values
+        item = values[seq_offset - resp_len - 1 : seq_offset - 1]
+        response_list.append(F.pad(item, (*skip_padding, 0, pad_size)))
+    else:  # empty prompt
+        # Handle empty prompt: logprob for the first token is not available.
+        item = values[seq_offset - resp_len : seq_offset - 1]
+        # Pad for the missing first logprob and for sequence length alignment.
+        response_list.append(F.pad(item, (*skip_padding, 1, pad_size)))
```
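The left-shift indexing can be sanity-checked on a toy packed tensor. This is a hedged sketch, not code from the PR: the values, offsets, and lengths below are made up for illustration, but the slicing `values[seq_offset - resp_len - 1 : seq_offset - 1]` and the right-padding to `max_response_len` mirror the logic above.

```python
import torch
import torch.nn.functional as F

# Toy packed layout: two sequences concatenated into one flat tensor.
# seq 0: prompt_len=2, resp_len=3 -> occupies positions 0..4, seq_offset=5
# seq 1: prompt_len=3, resp_len=2 -> occupies positions 5..7 (abbreviated), seq_offset=8
values = torch.arange(8, dtype=torch.float32)  # stand-in per-position values 0..7

max_response_len = 3
response_list = []
for resp_len, seq_offset in [(3, 5), (2, 8)]:
    pad_size = max_response_len - resp_len
    # left-shift by one: the value for response token t lives at position t-1
    item = values[seq_offset - resp_len - 1 : seq_offset - 1]
    response_list.append(F.pad(item, (0, pad_size)))

out = torch.stack(response_list)
# seq 0's response values come from positions 1..3 -> [1., 2., 3.]
# seq 1's response values come from positions 5..6, padded -> [5., 6., 0.]
```

The padding keeps every row at `max_response_len`, so the per-sequence slices can be stacked into a single batch tensor.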
```python
if temp != 1.0:
    raise NotImplementedError("vLLM doesn't support temperature for prompt logprobs")
```
Raising a NotImplementedError if the temperature is not 1.0 can lead to unexpected runtime failures, especially since the teacher's temperature is often inherited from the student's configuration. A user might not be aware of this vLLM limitation.
To improve usability, consider logging a warning and overriding the temperature to 1.0 instead of raising an error. This would prevent crashes while still informing the user about the limitation.
Suggested change:

```diff
 if temp != 1.0:
-    raise NotImplementedError("vLLM doesn't support temperature for prompt logprobs")
+    logger.warning("vLLM doesn't support temperature for prompt logprobs. Overriding to 1.0.")
+    temp = 1.0
```
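As background on why the temperature-1.0 logprobs can't simply be rescaled after the fact (a hedged sketch, not vLLM code; the logits here are arbitrary): temperature divides the logits before the softmax normalization, so temperature-adjusted log-probabilities are not a linear rescaling of the temperature-1.0 ones.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])  # arbitrary example logits
temp = 0.7

# Temperature is applied to logits *before* normalization...
lp_temp = torch.log_softmax(logits / temp, dim=-1)
# ...so dividing temperature-1.0 logprobs by temp does not recover them.
lp_one = torch.log_softmax(logits, dim=-1)

mismatch = not torch.allclose(lp_temp, lp_one / temp)
```

This is why the only safe options are to compute prompt logprobs at temperature 1.0 or to reject/override non-unit temperatures, as the suggestion above does.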
What does this PR do?
Adds support for OPD with VLM student and teacher.
Test
Tested with `examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh`.

[Plots: Geo3K eval accuracy, Geo3K train accuracy, and distillation loss]
Design & Code Changes
This PR is stacked on #4897. Here's the diff between the two branches: JacobHelwig/verl@jhelwig/onPolicyDistillation...JacobHelwig:verl:jhelwig/opdServer.
#4897 submits requests to the `vLLMHttpServer` via the `v1/completions` endpoint, which does not support multi-modal data. While `v1/chat/completions` does support multi-modal inputs, it requires text to be passed as raw text rather than token IDs, preventing exact scoring of student generations: the round trip `student gen IDs -> student gen text -> teacher input IDs` via `v1/chat/completions` tokenization will not always give `student gen IDs == teacher input IDs` (https://vllm.ai/blog/agent-lightning).

This PR instead follows a path similar to how rollout replicas directly call the `generate` method on the `vLLMHttpServer`. This enables multi-modal inputs while representing text as token IDs. Requests to the teacher server now call the newly added `compute_logprobs` method of `vLLMHttpServer`.
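The tokenization round-trip mismatch can be reproduced with a toy greedy longest-match tokenizer. This is a hypothetical illustration, not vLLM's tokenizer: the vocabulary and functions below are made up, but real BPE-style tokenizers behave analogously when a decoded string re-merges into different tokens than the model originally sampled.

```python
# Toy vocabulary: "ab" exists as a single merged token.
vocab = {"ab": 0, "a": 1, "b": 2}
inv = {v: k for k, v in vocab.items()}

def detokenize(ids):
    return "".join(inv[i] for i in ids)

def tokenize(text):
    # Greedy longest-match, as BPE-style tokenizers typically prefer merged pieces.
    ids, i = [], 0
    while i < len(text):
        for piece in sorted(vocab, key=len, reverse=True):
            if text.startswith(piece, i):
                ids.append(vocab[piece])
                i += len(piece)
                break
    return ids

student_ids = [1, 2]  # student sampled "a" then "b" as separate tokens
round_trip = tokenize(detokenize(student_ids))
# round_trip re-merges to [0] ("ab"), so teacher logprobs would be computed
# over different token positions than the student actually generated
```

Because `round_trip != student_ids`, scoring the decoded text through `v1/chat/completions` would give the teacher a different token sequence; passing token IDs directly to `compute_logprobs` avoids the mismatch.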