Merged
Changes from all commits
53 commits
dcaacfe
Fix partial load problem, Add vlm support for trtllm rollout
SchumiDing Jan 31, 2026
0394ab5
Precommit check
SchumiDing Jan 31, 2026
0664ab1
Add check for if the model is vlm in trtllmhttpserver
SchumiDing Jan 31, 2026
bf71c9b
Support latest trtllm
SchumiDing Feb 2, 2026
f6e58b8
Support for qwen2.5 vl
SchumiDing Feb 2, 2026
7af6917
Add trtllm rollout test script
SchumiDing Feb 2, 2026
94c4eb0
Add test_trtllm_rollout workflow to test trtllm_rollout
SchumiDing Feb 2, 2026
25518fe
Add back mistakenly deleted file
SchumiDing Feb 2, 2026
fd007fb
Precommit check
SchumiDing Feb 2, 2026
659ec01
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 4, 2026
55b55dc
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
e2cc50b
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
ca17f8a
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 6, 2026
62af0f2
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 11, 2026
24a6620
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
6f055a2
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
d0b1d1d
fix readability problem of multimodal config
SchumiDing Feb 11, 2026
6b021f4
Remove need for multimodal server config
SchumiDing Feb 11, 2026
a7faa7b
Add vlm unit test into existing trtllm unit test
SchumiDing Feb 11, 2026
8519d36
add e2e script to train qwen2.5-vl with trtllm rollout
SchumiDing Feb 11, 2026
9acdcd6
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 12, 2026
5a145a5
Change import statement
SchumiDing Feb 12, 2026
3776338
remove reward config in e2e script
SchumiDing Feb 12, 2026
1706e71
For multimodal input to trtllm, decode with special tokens first
SchumiDing Feb 12, 2026
90837f3
revert typo
SchumiDing Feb 12, 2026
57506e2
revert typo
SchumiDing Feb 12, 2026
e193d0d
pre commit check
SchumiDing Feb 12, 2026
81050ce
Fix bugs
SchumiDing Feb 27, 2026
91d8c59
Update
SchumiDing Feb 27, 2026
60dd50b
Update
SchumiDing Feb 27, 2026
44ede00
Add
SchumiDing Mar 1, 2026
b816089
Pre commit check
SchumiDing Mar 2, 2026
fba748a
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Mar 2, 2026
f6977db
Add back CI paths
SchumiDing Mar 3, 2026
49fd2a2
Add fixed configuration in example script
SchumiDing Mar 3, 2026
29ba171
limit worker extension to vlm condition
SchumiDing Mar 3, 2026
dc6b4f4
Update
SchumiDing Mar 3, 2026
9b30172
Merge branch 'main' into vlm_trtllm_support
SchumiDing Mar 9, 2026
7c89354
Pre-commit check
SchumiDing Mar 9, 2026
fcd8275
Update
SchumiDing Mar 9, 2026
d8c0c35
Update examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh
SchumiDing Mar 9, 2026
194e00a
Update examples/grpo_trainer/run_qwen2_5_vl_3b_trtllm.sh
SchumiDing Mar 9, 2026
4eaa89d
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Mar 12, 2026
cb053d9
Fix some issues with qwenvl compatibility with tensorrt llm
SchumiDing Mar 13, 2026
b93307d
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Mar 13, 2026
d18d672
Merge branch 'vlm_trtllm_support' of https://github.com/SchumiDing/ve…
SchumiDing Mar 13, 2026
e6fa3db
fix wrong geo3k path
SchumiDing Mar 13, 2026
346a759
Pre-commit check
SchumiDing Mar 13, 2026
44be305
Merge commonly used function
SchumiDing Mar 13, 2026
07f1a65
Add install for Qwen_vl_utils
SchumiDing Mar 13, 2026
8446885
Update
SchumiDing Mar 13, 2026
9f391f3
fix
SchumiDing Mar 13, 2026
9844842
Add pip install mathruler for geo3k reward function
SchumiDing Mar 13, 2026
4 changes: 3 additions & 1 deletion .github/workflows/e2e_ppo_grpo_trainer_trtllm.yml
@@ -228,9 +228,11 @@ jobs:
         run: |
           pip3 install -r requirements-test.txt
           pip3 install --no-deps -e .
+          pip3 install qwen_vl_utils
+          pip3 install mathruler
       - name: Prepare GEO3K dataset
         run: |
-          python3 examples/data_preprocess/geo3k.py --local_dataset_path ${HOME}/models/hf_data/geo3k --local_save_dir ${PWD}/data/geo3k
+          python3 examples/data_preprocess/geo3k.py --local_dataset_path ${HOME}/models/hf_data/hiyouga/geometry3k --local_save_dir ${PWD}/data/geo3k
       - name: Running GEO3K E2E training tests with FSDP on 8 L20 GPUs (VLM)
         run: |
           ray stop --force
38 changes: 23 additions & 15 deletions verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
@@ -29,7 +29,7 @@
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
 from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter
-from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn
+from verl.workers.rollout.utils import get_max_position_embeddings, qwen2_5_vl_dedup_image_tokens, run_uvicorn

 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
@@ -204,13 +204,25 @@ async def launch_server(self):
         )

         self.llm = await AsyncLLM(**llm_kwargs)
-        trtllm_server = OpenAIServer(
-            generator=self.llm,
-            model=self.model_config.local_path,
-            tool_parser=None,
-            server_role=None,
-            metadata_server_cfg=None,
-        )
+        import inspect
+
+        init_params = inspect.signature(OpenAIServer.__init__).parameters
+        if "generator" in init_params:
+            trtllm_server = OpenAIServer(
+                generator=self.llm,
+                model=self.model_config.local_path,
+                tool_parser=None,
+                server_role=None,
+                metadata_server_cfg=None,
+            )
+        else:
+            trtllm_server = OpenAIServer(
+                llm=self.llm,
+                model=self.model_config.local_path,
+                tool_parser=None,
+                server_role=None,
+                metadata_server_cfg=None,
+            )

         app = trtllm_server.app
         self._server_port, self._server_task = await run_uvicorn(app, None, self._server_address)
@@ -234,7 +246,8 @@ async def generate(

         trt_llm_sampling_params = SamplingParams(**sampling_params)
         if self.is_vlm_model and (image_data or video_data):
-            org_prompt = self.llm.tokenizer.decode(prompt_ids)
+            deduped_ids = qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
+            org_prompt = self.llm.tokenizer.decode(deduped_ids)
             input_dict = {
                 "prompt": org_prompt,
                 "multi_modal_data": {},
@@ -395,12 +408,7 @@ async def launch_servers(self):
                     node_id=node_id,
                     soft=False,
                 ),
-                runtime_env={
-                    "env_vars": {
-                        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
-                        "NCCL_CUMEM_ENABLE": "0",
-                    }
-                },
+                runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", "NCCL_CUMEM_ENABLE": "0"}},
                 name=name,
                 max_concurrency=self.max_concurrency,
             ).remote(
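Note: the constructor-introspection branch above exists because the first keyword of `OpenAIServer.__init__` changed across TensorRT-LLM releases (`generator=` in older versions, `llm=` in newer ones). A minimal standalone sketch of the same pattern, where `DummyServer` and `make_server` are hypothetical names standing in for the real classes:

```python
# Sketch of the keyword-compatibility pattern, under the assumption that only
# the keyword name differs between server-class versions.
import inspect


class DummyServer:
    """Hypothetical stand-in for trtllm's OpenAIServer, using the newer `llm=` keyword."""

    def __init__(self, llm, model):
        self.llm = llm
        self.model = model


def make_server(server_cls, engine, model_path):
    # Inspect the constructor once, then pass the engine under whichever
    # keyword this version of the server class accepts.
    params = inspect.signature(server_cls.__init__).parameters
    kwarg = "generator" if "generator" in params else "llm"
    return server_cls(**{kwarg: engine, "model": model_path})


server = make_server(DummyServer, engine=object(), model_path="Qwen/Qwen2.5-VL-3B-Instruct")
assert server.model == "Qwen/Qwen2.5-VL-3B-Instruct"
```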
21 changes: 21 additions & 0 deletions verl/workers/rollout/utils.py
@@ -14,6 +14,7 @@
 import asyncio
 import logging

+import numpy as np
 import uvicorn
 from fastapi import FastAPI

@@ -80,3 +81,23 @@ async def ensure_async_iterator(iterable):
     else:
         for item in iterable:
             yield item
+
+
+def qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
+    """Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM replicates the
+    <|image_pad|> and <|video_pad|> tokens based on image_data.
+    For example,
+    ```
+    <|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|>
+    =>
+    <|vision_start|><|image_pad|><|vision_end|>
+    ```
+    """
+    if processor is not None and "Qwen2VLImageProcessor" in processor.image_processor.__class__.__name__:
+        prompt_ids = np.array(prompt_ids)
+        mask = np.ones(len(prompt_ids), dtype=bool)
+        is_value = (prompt_ids == processor.image_token_id) | (prompt_ids == processor.video_token_id)
+        mask[1:] &= ~(is_value[1:] & is_value[:-1])
+        return prompt_ids[mask].tolist()
+    else:
+        return prompt_ids
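For a concrete picture of what the new helper does, here is a toy run against a stub processor; the stub class, the `SimpleNamespace` wiring, and the token IDs are illustrative assumptions, not taken from the PR (the import assumes a verl checkout is installed):

```python
# Toy demonstration of qwen2_5_vl_dedup_image_tokens with a stub processor.
from types import SimpleNamespace

from verl.workers.rollout.utils import qwen2_5_vl_dedup_image_tokens

IMAGE_PAD, VIDEO_PAD = 151655, 151656  # assumed IDs, for illustration only


class Qwen2VLImageProcessorStub:
    # The helper keys off the image-processor class *name*, so the stub's
    # name must contain "Qwen2VLImageProcessor".
    pass


processor = SimpleNamespace(
    image_processor=Qwen2VLImageProcessorStub(),
    image_token_id=IMAGE_PAD,
    video_token_id=VIDEO_PAD,
)

# Three consecutive <|image_pad|> tokens collapse to one; other tokens are untouched.
prompt_ids = [100, IMAGE_PAD, IMAGE_PAD, IMAGE_PAD, 200]
print(qwen2_5_vl_dedup_image_tokens(prompt_ids, processor))  # -> [100, 151655, 200]
```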
33 changes: 2 additions & 31 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -20,7 +20,6 @@
 from pprint import pprint
 from typing import Any, Callable, Optional

-import numpy as np
 import ray
 import vllm.entrypoints.cli.serve
 from packaging import version
@@ -43,7 +42,7 @@
 from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
-from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn
+from verl.workers.rollout.utils import get_max_position_embeddings, qwen2_5_vl_dedup_image_tokens, run_uvicorn
 from verl.workers.rollout.vllm_rollout.utils import (
     VLLM_LORA_INT_ID,
     VLLM_LORA_NAME,
@@ -548,7 +547,7 @@ async def generate(
         sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None
         sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
         sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
-        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
+        prompt_ids = qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
         multi_modal_data = {}
         if image_data is not None:
             multi_modal_data["image"] = image_data
@@ -940,31 +939,3 @@ async def abort_request(self, request_id: str) -> dict[str, Any]:
                 return r

         return {"aborted": False, "request_id": request_id, "error": "Request not found on any server"}
-
-
-def _qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
-    """Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM replicates the
-    <|image_pad|> and <|video_pad|> tokens based on image_data.
-
-    For example,
-    ```
-    <|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|>
-    =>
-    <|vision_start|><|image_pad|><|vision_end|>
-    ```
-    """
-    if processor is not None and "Qwen2VLImageProcessor" in processor.image_processor.__class__.__name__:
-        prompt_ids = np.array(prompt_ids)
-
-        # Create a mask where True indicates elements to keep
-        mask = np.ones(len(prompt_ids), dtype=bool)
-
-        # Find where the array equals the value
-        is_value = (prompt_ids == processor.image_token_id) | (prompt_ids == processor.video_token_id)
-
-        # Find consecutive duplicates by checking if previous element is also the value
-        mask[1:] &= ~(is_value[1:] & is_value[:-1])
-
-        return prompt_ids[mask].tolist()
-    else:
-        return prompt_ids