diff --git a/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml b/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml
index 7f7314d6a1d..61a19d43419 100644
--- a/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml
+++ b/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml
@@ -228,9 +228,11 @@ jobs:
         run: |
           pip3 install -r requirements-test.txt
           pip3 install --no-deps -e .
+          pip3 install qwen_vl_utils
+          pip3 install mathruler
       - name: Prepare GEO3K dataset
         run: |
-          python3 examples/data_preprocess/geo3k.py --local_dataset_path ${HOME}/models/hf_data/geo3k --local_save_dir ${PWD}/data/geo3k
+          python3 examples/data_preprocess/geo3k.py --local_dataset_path ${HOME}/models/hf_data/hiyouga/geometry3k --local_save_dir ${PWD}/data/geo3k
       - name: Running GEO3K E2E training tests with FSDP on 8 L20 GPUs (VLM)
         run: |
           ray stop --force
diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
index 32f0a5be786..7b2c5ff2594 100644
--- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
+++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
@@ -29,7 +29,7 @@
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
 from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter
-from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn
+from verl.workers.rollout.utils import get_max_position_embeddings, qwen2_5_vl_dedup_image_tokens, run_uvicorn
 
 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
@@ -204,13 +204,25 @@ async def launch_server(self):
         )
         self.llm = await AsyncLLM(**llm_kwargs)
 
-        trtllm_server = OpenAIServer(
-            generator=self.llm,
-            model=self.model_config.local_path,
-            tool_parser=None,
-            server_role=None,
-            metadata_server_cfg=None,
-        )
+        import inspect
+
+        init_params = inspect.signature(OpenAIServer.__init__).parameters
+        if "generator" in init_params:
+            trtllm_server = OpenAIServer(
+                generator=self.llm,
+                model=self.model_config.local_path,
+                tool_parser=None,
+                server_role=None,
+                metadata_server_cfg=None,
+            )
+        else:
+            trtllm_server = OpenAIServer(
+                llm=self.llm,
+                model=self.model_config.local_path,
+                tool_parser=None,
+                server_role=None,
+                metadata_server_cfg=None,
+            )
 
         app = trtllm_server.app
         self._server_port, self._server_task = await run_uvicorn(app, None, self._server_address)
@@ -234,7 +246,8 @@ async def generate(
         trt_llm_sampling_params = SamplingParams(**sampling_params)
 
         if self.is_vlm_model and (image_data or video_data):
-            org_prompt = self.llm.tokenizer.decode(prompt_ids)
+            deduped_ids = qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
+            org_prompt = self.llm.tokenizer.decode(deduped_ids)
             input_dict = {
                 "prompt": org_prompt,
                 "multi_modal_data": {},
@@ -395,12 +408,7 @@ async def launch_servers(self):
                     node_id=node_id,
                     soft=False,
                 ),
-                runtime_env={
-                    "env_vars": {
-                        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
-                        "NCCL_CUMEM_ENABLE": "0",
-                    }
-                },
+                runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", "NCCL_CUMEM_ENABLE": "0"}},
                 name=name,
                 max_concurrency=self.max_concurrency,
             ).remote(
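The `inspect` branch above guards against the `OpenAIServer` constructor renaming its engine argument (`generator` vs. `llm`) across TensorRT-LLM releases. A minimal sketch of the same probe-the-signature pattern, using made-up `ServerV1`/`ServerV2` stand-ins rather than the real TensorRT-LLM classes:

```
import inspect


class ServerV1:
    """Stand-in for an older release whose constructor takes `generator`."""

    def __init__(self, generator, model):
        self.engine, self.model = generator, model


class ServerV2:
    """Stand-in for a newer release that renamed the argument to `llm`."""

    def __init__(self, llm, model):
        self.engine, self.model = llm, model


def build_server(server_cls, engine, model):
    # Probe the constructor signature at runtime, as the patch does for
    # OpenAIServer, instead of pinning a single library version.
    params = inspect.signature(server_cls.__init__).parameters
    key = "generator" if "generator" in params else "llm"
    return server_cls(model=model, **{key: engine})


assert build_server(ServerV1, "eng", "m").engine == "eng"
assert build_server(ServerV2, "eng", "m").engine == "eng"
```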
diff --git a/verl/workers/rollout/utils.py b/verl/workers/rollout/utils.py
index 69c688dfa24..efff4af2852 100644
--- a/verl/workers/rollout/utils.py
+++ b/verl/workers/rollout/utils.py
@@ -14,6 +14,7 @@
 import asyncio
 import logging
 
+import numpy as np
 import uvicorn
 from fastapi import FastAPI
 
@@ -80,3 +81,23 @@ async def ensure_async_iterator(iterable):
     else:
         for item in iterable:
             yield item
+
+
+def qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
+    """Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM will replicate the
+    <|image_pad|> and <|video_pad|> token by image_data.
+    For example,
+    ```
+    <|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|>
+    =>
+    <|vision_start|><|image_pad|><|vision_end|>
+    ```
+    """
+    if processor is not None and "Qwen2VLImageProcessor" in processor.image_processor.__class__.__name__:
+        prompt_ids = np.array(prompt_ids)
+        mask = np.ones(len(prompt_ids), dtype=bool)
+        is_value = (prompt_ids == processor.image_token_id) | (prompt_ids == processor.video_token_id)
+        mask[1:] &= ~(is_value[1:] & is_value[:-1])
+        return prompt_ids[mask].tolist()
+    else:
+        return prompt_ids
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index 0b776ac5e55..b58d96bcd6b 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -20,7 +20,6 @@
 from pprint import pprint
 from typing import Any, Callable, Optional
 
-import numpy as np
 import ray
 import vllm.entrypoints.cli.serve
 from packaging import version
@@ -43,7 +42,7 @@
 from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
-from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn
+from verl.workers.rollout.utils import get_max_position_embeddings, qwen2_5_vl_dedup_image_tokens, run_uvicorn
 from verl.workers.rollout.vllm_rollout.utils import (
     VLLM_LORA_INT_ID,
     VLLM_LORA_NAME,
@@ -548,7 +547,7 @@ async def generate(
         sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None
         sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
         sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
-        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
+        prompt_ids = qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
         multi_modal_data = {}
         if image_data is not None:
             multi_modal_data["image"] = image_data
@@ -940,31 +939,3 @@ async def abort_request(self, request_id: str) -> dict[str, Any]:
                 return r
 
         return {"aborted": False, "request_id": request_id, "error": "Request not found on any server"}
-
-
-def _qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
-    """Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM will replicate the
-    <|image_pad|> and <|video_pad|> token by image_data.
-
-    For example,
-    ```
-    <|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|>
-    =>
-    <|vision_start|><|image_pad|><|vision_end|>
-    ```
-    """
-    if processor is not None and "Qwen2VLImageProcessor" in processor.image_processor.__class__.__name__:
-        prompt_ids = np.array(prompt_ids)
-
-        # Create a mask where True indicates elements to keep
-        mask = np.ones(len(prompt_ids), dtype=bool)
-
-        # Find where the array equals the value
-        is_value = (prompt_ids == processor.image_token_id) | (prompt_ids == processor.video_token_id)
-
-        # Find consecutive duplicates by checking if previous element is also the value
-        mask[1:] &= ~(is_value[1:] & is_value[:-1])
-
-        return prompt_ids[mask].tolist()
-    else:
-        return prompt_ids
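The helper hoisted into `verl/workers/rollout/utils.py` collapses each run of consecutive pad tokens to a single token using a boolean mask over the shifted array. A self-contained toy version of the same mask logic, with hard-coded ids standing in for `processor.image_token_id`/`processor.video_token_id` (151655/151656 are assumed here purely for illustration; any distinct ints show the behavior):

```
import numpy as np

# Assumed ids mirroring Qwen2.5-VL's <|image_pad|>/<|video_pad|>.
IMG, VID = 151655, 151656


def dedup(prompt_ids):
    ids = np.array(prompt_ids)
    mask = np.ones(len(ids), dtype=bool)
    is_pad = (ids == IMG) | (ids == VID)
    # Drop any pad token whose immediate predecessor is also a pad token,
    # so each run of pads collapses to one surviving token.
    mask[1:] &= ~(is_pad[1:] & is_pad[:-1])
    return ids[mask].tolist()


# <|vision_start|> IMG IMG IMG <|vision_end|> -> <|vision_start|> IMG <|vision_end|>
print(dedup([100, IMG, IMG, IMG, 200]))  # [100, 151655, 200]
```

The collapse is safe because both rollout engines re-expand the single pad token from the attached `image_data`; passing the fully expanded ids through would make the engine duplicate them.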