
[vllm,trainer,algo] feat: Enable On-Policy Distillation for VLM#5592

Draft
JacobHelwig wants to merge 180 commits into verl-project:main from JacobHelwig:jhelwig/opdServer

Conversation

@JacobHelwig (Collaborator)

What does this PR do?

Adds support for on-policy distillation (OPD) with a VLM student and teacher.

Test

Tested with examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh.

  • Data: Geometry3K
  • Student: Qwen3-VL-2B-Instruct
  • Teacher: Qwen3-VL-4B-Instruct
  • OPD algo: k1 KL estimator as reward with policy gradient loss
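The k1-estimator-as-reward setup in the last bullet can be sketched as follows. This is a minimal illustration, not the PR's actual code; the tensor names, shapes, and sign convention are assumptions. Response tokens are sampled from the student, so the per-token log-ratio is a single-sample (k1) estimate of the student-to-teacher KL, and its negation serves as a per-token reward for the policy-gradient loss.

```python
import torch

def k1_distill_reward(student_logprobs: torch.Tensor,
                      teacher_logprobs: torch.Tensor,
                      response_mask: torch.Tensor) -> torch.Tensor:
    """Per-token distillation reward from the k1 (single-sample) KL estimator.

    Since tokens are sampled from the student, log p_student - log p_teacher is
    an unbiased per-token estimate of KL(student || teacher). Using its negation
    as a reward drives the policy-gradient update to shrink that KL.
    All tensors: (batch, response_len); mask is 1 on real response tokens.
    """
    k1 = student_logprobs - teacher_logprobs
    return -k1 * response_mask

# Toy example: where the student is confident but the teacher disagrees,
# the reward is negative; where they agree, the reward is ~0.
student = torch.tensor([[-0.1, -0.2]])
teacher = torch.tensor([[-2.0, -0.2]])
mask = torch.ones_like(student)
reward = k1_distill_reward(student, teacher, mask)  # [[-1.9, 0.0]]
```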

Geo3K eval acc: (plot)

Geo3K train acc: (plot)

Distillation loss: (plot)

Design & Code Changes

This PR is stacked on #4897. Here's the diff between the two branches: JacobHelwig/verl@jhelwig/onPolicyDistillation...JacobHelwig:verl:jhelwig/opdServer.

#4897 submits requests to the vLLMHttpServer via the v1/completions endpoint, which does not support multi-modal data. v1/chat/completions does support multi-modal inputs, but it requires text to be passed as raw text rather than token IDs. This prevents exact scoring of student generations: going from student generation IDs to student generation text and then to teacher input IDs via v1/chat/completions tokenization will not always yield teacher input IDs equal to the original student generation IDs (https://vllm.ai/blog/agent-lightning).
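The root issue is that detokenize-then-retokenize is not the identity on token IDs. A self-contained toy illustration (the vocabulary and greedy tokenizer here are invented for demonstration; real BPE tokenizers exhibit the same effect with merged tokens):

```python
# Toy illustration of why decode -> re-encode need not be the identity on token IDs.
# A greedy longest-match tokenizer with a merged token "ab" re-encodes the text
# "ab" as the single "ab" token, even if the student happened to emit "a", "b".
VOCAB = {"a": 0, "b": 1, "ab": 2}
INV = {v: k for k, v in VOCAB.items()}

def decode(ids):
    return "".join(INV[i] for i in ids)

def encode(text):
    # Greedy longest-match tokenization.
    ids, i = [], 0
    while i < len(text):
        for length in (2, 1):
            piece = text[i:i + length]
            if piece in VOCAB:
                ids.append(VOCAB[piece])
                i += length
                break
    return ids

student_ids = [0, 1]                       # student sampled "a", "b" as two tokens
roundtrip = encode(decode(student_ids))    # [2]: "ab" as one token != student_ids
```

Because `roundtrip != student_ids`, teacher logprobs computed on the re-tokenized text would not align token-for-token with the student's sampled tokens.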

This PR instead follows a path similar to how rollout replicas directly call the generate method on the vLLMHttpServer. This enables multi-modal inputs while representing text as token IDs. Requests to the teacher server now call the newly-added compute_logprobs method of vLLMHttpServer.
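Conceptually, a teacher-side logprob endpoint has to gather, from the teacher's logits over the concatenated prompt+response token IDs, the log-probability of each response token under a one-token left shift (the logits at position t predict the token at position t+1). A hedged sketch of that core operation; the function name and shapes are illustrative, not the actual compute_logprobs signature:

```python
import torch
import torch.nn.functional as F

def response_logprobs(logits: torch.Tensor,
                      input_ids: torch.Tensor,
                      response_len: int) -> torch.Tensor:
    """Log-probs of the last `response_len` tokens of `input_ids` under `logits`.

    logits: (seq_len, vocab) teacher logits for the full prompt+response.
    The logits at position t predict the token at position t+1, hence the
    one-token left shift when scoring response tokens.
    """
    logprobs = F.log_softmax(logits, dim=-1)
    # Positions whose *next* token is a response token.
    pred_positions = logprobs[-response_len - 1 : -1]   # (response_len, vocab)
    targets = input_ids[-response_len:]                 # (response_len,)
    return pred_positions.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

# Toy check: a "model" that puts nearly all mass on token 3 assigns
# near-zero logprob (probability ~1) to response tokens that are 3.
logits = torch.full((4, 5), -10.0)
logits[:, 3] = 10.0
ids = torch.tensor([1, 3, 3, 3])
lp = response_logprobs(logits, ids, response_len=2)
```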

@gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request introduces a significant feature: on-policy distillation (OPD) for vLLM, with a focus on Vision Language Models. The core of the change is a new compute_logprobs method on the vLLMHttpServer that allows for scoring student-generated tokens with a teacher model, even for multi-modal inputs. The implementation includes a new distillation module with configurable loss functions, a TeacherModelManager for handling teacher model replicas, and updates across the training pipeline to integrate this new capability. The changes are extensive and well-thought-out. My review includes a couple of suggestions to improve the robustness and usability of the new functionality.

Comment on lines +119 to +127

```diff
 assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"

 response_list = []
 # Skip padding dimensions after sequence dimensions, if any.
 skip_padding = (0, 0) * (values.ndim - 1)
 for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
     pad_size = max_response_len - resp_len
     # left-shift model output by one token for log_probs/values
-    response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (0, pad_size)))
+    response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (*skip_padding, 0, pad_size)))
```


Severity: high

This assertion prevents the function from handling empty prompts, which might be a valid use case. The indexing seq_offset - resp_len - 1 on line 127 is incorrect when prompt_len is 0, as there is no token preceding the response to compute the first log-probability.

To make this function more robust, consider removing the assertion and handling the empty prompt case explicitly. When the prompt is empty, you could pad for the missing first log-probability.

Suggested change

```diff
-assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
 response_list = []
 # Skip padding dimensions after sequence dimensions, if any.
 skip_padding = (0, 0) * (values.ndim - 1)
-for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
+for i, (resp_len, seq_offset) in enumerate(zip(response_lens, sequence_offsets, strict=True)):
     pad_size = max_response_len - resp_len
-    # left-shift model output by one token for log_probs/values
-    response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (*skip_padding, 0, pad_size)))
+    if resp_len == 0:
+        # Pad to full response length so entries stack with the others.
+        shape = (max_response_len,) + values.shape[1:]
+        response_list.append(torch.zeros(shape, dtype=values.dtype, device=values.device))
+        continue
+    if prompt_lens[i] > 0:
+        # left-shift model output by one token for log_probs/values
+        item = values[seq_offset - resp_len - 1 : seq_offset - 1]
+        response_list.append(F.pad(item, (*skip_padding, 0, pad_size)))
+    else:  # empty prompt
+        # The logprob for the first response token is not available.
+        item = values[seq_offset - resp_len : seq_offset - 1]
+        # Pad for the missing first logprob and for sequence-length alignment.
+        response_list.append(F.pad(item, (*skip_padding, 1, pad_size)))
```

Comment on lines +649 to +650

```python
if temp != 1.0:
    raise NotImplementedError("vLLM doesn't support temperature for prompt logprobs")
```

Severity: high

Raising a NotImplementedError if the temperature is not 1.0 can lead to unexpected runtime failures, especially since the teacher's temperature is often inherited from the student's configuration. A user might not be aware of this vLLM limitation.

To improve usability, consider logging a warning and overriding the temperature to 1.0 instead of raising an error. This would prevent crashes while still informing the user about the limitation.

Suggested change

```diff
 if temp != 1.0:
-    raise NotImplementedError("vLLM doesn't support temperature for prompt logprobs")
+    logger.warning("vLLM doesn't support temperature for prompt logprobs. Overriding to 1.0.")
+    temp = 1.0
```

@JacobHelwig JacobHelwig changed the title [vllm,trainer,algo] feat: Enable On-Policy Distillation for vLLM [vllm,trainer,algo] feat: Enable On-Policy Distillation for VLM Mar 14, 2026
@mergify bot commented Mar 14, 2026

⚠️ The sha of the head commit of this PR conflicts with #5164. Mergify cannot evaluate rules on this PR. Once #5164 is merged or closed, Mergify will resume processing this PR. ⚠️
