18 changes: 16 additions & 2 deletions verl/experimental/agent_loop/agent_loop.py
@@ -96,6 +96,7 @@ async def generate(
prompt_ids: list[int],
sampling_params: dict[str, Any],
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> TokenOutput:
"""Generate tokens from prompt ids.

@@ -113,6 +114,7 @@ async def generate(
prompt_ids=prompt_ids,
sampling_params=sampling_params,
image_data=image_data,
video_data=video_data,
)
return output

@@ -505,16 +507,28 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
multi_modal_inputs = None
if self.processor is not None:
images = getattr(output, "multi_modal_data", {}).get("image", None)
videos = getattr(output, "multi_modal_data", {}).get("video", None)
if videos is not None:
videos, video_metadatas = zip(*videos, strict=False)
videos, video_metadatas = list(videos), list(video_metadatas)
videos_kwargs = {"video_metadata": video_metadatas, "do_sample_frames": False}
else:
videos_kwargs = {}
Comment on lines +510 to +516 (Contributor review, severity: high)

This block of code for processing video data is nearly identical to the one in verl/experimental/agent_loop/tool_agent_loop.py at lines 223-228. Duplicating this logic increases maintenance overhead and the risk of introducing inconsistencies if one is updated and the other is not. Consider refactoring this into a shared helper function to promote code reuse and simplify future modifications. For example, a function like _prepare_video_kwargs(videos) could encapsulate this logic.
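A minimal sketch of such a helper, assuming the reviewer's suggested name `_prepare_video_kwargs` and that callers hold `videos` as a list of `(frames, metadata)` pairs; neither the name nor its placement is part of this PR:

```python
from typing import Any, Optional


def _prepare_video_kwargs(videos: Optional[list[tuple[Any, Any]]]) -> tuple[Optional[list[Any]], dict[str, Any]]:
    """Split (frames, metadata) pairs and build the processor kwargs for video inputs."""
    if videos is None:
        return None, {}
    frames, metadatas = zip(*videos, strict=False)
    # Frames are already sampled upstream, so pass metadata and disable re-sampling.
    return list(frames), {"video_metadata": list(metadatas), "do_sample_frames": False}
```

Both call sites could then reduce to `videos, videos_kwargs = _prepare_video_kwargs(videos)` before invoking the processor.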

current_text = self.tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True)
multi_modal_inputs = self.processor(text=[current_text], images=images, return_tensors="pt")
multi_modal_inputs = self.processor(
text=[current_text], images=images, videos=videos, return_tensors="pt", do_resize=False, **videos_kwargs
)
multi_modal_inputs.pop("input_ids", None)
multi_modal_inputs.pop("attention_mask", None)

# We must use dict(multi_modal_inputs) to convert BatchFeature values to a new dict
# because np.array() only keeps the keys for BatchFeature.
multi_modal_inputs = dict(multi_modal_inputs.convert_to_tensors("pt"))
if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
from verl.models.transformers.qwen2_vl import get_rope_index
if "Qwen3VLProcessor" in self.processor.__class__.__name__:
from verl.models.transformers.qwen3_vl import get_rope_index
else:
from verl.models.transformers.qwen2_vl import get_rope_index

image_grid_thw = multi_modal_inputs.get("image_grid_thw")
video_grid_thw = multi_modal_inputs.get("video_grid_thw")
22 changes: 20 additions & 2 deletions verl/experimental/agent_loop/tool_agent_loop.py
@@ -59,6 +59,7 @@ def __init__(
self,
messages: list[dict[str, Any]],
image_data: Any,
video_data: Any,
metrics: dict[str, Any],
request_id: str,
tools_kwargs: dict[str, Any],
@@ -67,6 +68,7 @@
):
self.messages = messages
self.image_data = image_data
self.video_data = video_data
self.metrics = metrics
self.request_id = request_id
self.tools_kwargs = tools_kwargs
@@ -134,6 +136,7 @@ def __init__(
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
image_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("image", None))
video_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("video", None))
metrics = {}
request_id = uuid4().hex
tools_kwargs = kwargs.get("tools_kwargs", {})
@@ -157,6 +160,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
agent_data = AgentData(
messages=messages,
image_data=image_data,
video_data=video_data,
metrics=metrics,
request_id=request_id,
tools_kwargs=tools_kwargs,
@@ -182,7 +186,11 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
# Finalize output
response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]
prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]
multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {}
multi_modal_data = {}
if agent_data.image_data is not None:
multi_modal_data["image"] = agent_data.image_data
if agent_data.video_data is not None:
multi_modal_data["video"] = agent_data.video_data
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[: self.response_length],
@@ -211,7 +219,16 @@ async def _handle_pending_state(self, agent_data: AgentData, sampling_params: di
**self.apply_chat_template_kwargs,
),
)
model_inputs = self.processor(text=[raw_prompt], images=agent_data.image_data, return_tensors="pt")
images, videos = agent_data.image_data, agent_data.video_data
if videos is not None:
videos, video_metadatas = zip(*videos, strict=False)
videos, video_metadatas = list(videos), list(video_metadatas)
videos_kwargs = {"video_metadata": video_metadatas, "do_sample_frames": False}
else:
videos_kwargs = {}
model_inputs = self.processor(
text=[raw_prompt], images=images, videos=videos, return_tensors="pt", do_resize=False, **videos_kwargs
)
agent_data.prompt_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
else:
agent_data.prompt_ids = await self.loop.run_in_executor(
@@ -238,6 +255,7 @@ async def _handle_generating_state(
prompt_ids=agent_data.prompt_ids,
sampling_params=sampling_params,
image_data=agent_data.image_data,
video_data=agent_data.video_data,
)

agent_data.assistant_turns += 1
71 changes: 63 additions & 8 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -459,6 +459,7 @@ async def generate(
sampling_params: dict[str, Any],
request_id: str,
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> TokenOutput:
"""Generate sequence with token-in-token-out."""
# TODO(@wuxibin): switch to `/generate` http endpoint once multi-modal support ready.
@@ -476,10 +477,16 @@ async def generate(
sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None
sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
prompt = TokensPrompt(
prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None
)
if "Qwen3VLProcessor" in self.model_config.processor.__class__.__name__:
prompt_ids = _qwen3_vl_dedup_vision_tokens(prompt_ids, self.model_config.processor, video_data)
else:
prompt_ids = _qwen2_5_vl_dedup_vision_tokens(prompt_ids, self.model_config.processor)
multi_modal_data = {}
if image_data is not None:
multi_modal_data["image"] = image_data
if video_data is not None:
multi_modal_data["video"] = video_data
prompt = TokensPrompt(prompt_token_ids=prompt_ids, multi_modal_data=multi_modal_data)
Contributor review comment (severity: high)

The previous implementation passed None for multi_modal_data when no image data was present. The new logic passes an empty dictionary {} when both image_data and video_data are None. This change in behavior (from None to {}) might be unintended and could lead to unexpected issues downstream in the vLLM engine, which expects Optional[Dict]. To maintain the original behavior, consider passing None if the multi_modal_data dictionary is empty.

Suggested change
prompt = TokensPrompt(prompt_token_ids=prompt_ids, multi_modal_data=multi_modal_data)
prompt = TokensPrompt(prompt_token_ids=prompt_ids, multi_modal_data=multi_modal_data if multi_modal_data else None)


# Add lora request
lora_request = None
@@ -792,15 +799,19 @@ async def abort_request(self, request_id: str) -> dict[str, Any]:
return {"aborted": False, "request_id": request_id, "error": "Request not found on any server"}


def _qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
"""Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM will replicate the
<|image_pad|> token by image_data.
def _qwen2_5_vl_dedup_vision_tokens(prompt_ids: list[int], processor):
"""Deduplicate consecutive vision tokens (image or video) in prompt_ids for Qwen2.5-VL,
since vLLM will replicate the padding tokens by vision data.

For example,
```
<|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|>
=>
<|vision_start|><|image_pad|><|vision_end|>

<|vision_start|><|video_pad|>...<|vision_end|>
=>
<|vision_start|><|video_pad|><|vision_end|>
```
"""
if processor is not None and "Qwen2VLImageProcessor" in processor.image_processor.__class__.__name__:
@@ -810,11 +821,55 @@ def _qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):
mask = np.ones(len(prompt_ids), dtype=bool)

# Find where the array equals the value
is_value = prompt_ids == processor.image_token_id
is_value = (prompt_ids == processor.image_token_id) | (prompt_ids == processor.video_token_id)
Contributor review comment (severity: high)

Accessing processor.video_token_id directly could raise an AttributeError if the processor does not support video and thus lacks this attribute. To make this code more robust, you should check for the existence of video_token_id before using it.

A safer implementation would be:

is_value = prompt_ids == processor.image_token_id
if hasattr(processor, "video_token_id"):
    is_value |= (prompt_ids == processor.video_token_id)


# Find consecutive duplicates by checking if previous element is also the value
mask[1:] &= ~(is_value[1:] & is_value[:-1])

return prompt_ids[mask].tolist()
else:
return prompt_ids


def _qwen3_vl_dedup_vision_tokens(prompt_ids: list[int], processor, video_data: Optional[list[Any]] = None):
"""Deduplicate consecutive vision tokens (image or video) in prompt_ids for Qwen3-VL,
since vLLM will replicate the padding tokens by vision data.

For example,
```
<|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|>
=>
<|vision_start|><|image_pad|><|vision_end|>

<0.1 seconds><|vision_start|><|video_pad|>...<|vision_end|>
...<11.3 seconds><|vision_start|><|video_pad|>...<|vision_end|>
=>
<|vision_start|><|video_pad|><|vision_end|>
```
"""

# dedup video placeholder
video_frames = []
if video_data is not None:
for video in video_data:
frame = video[0].shape[0] // 2
video_frames.append(frame)

import re

single_frame_pattern = r"<[\d.]+ seconds><\|vision_start\|>(?:<\|video_pad\|>)+<\|vision_end\|>"
prompt = processor.tokenizer.decode(prompt_ids)
current_prompt = prompt
for num_frames in video_frames:
# Match exactly num_frames repetitions of the single frame pattern
video_sequence_pattern = f"(?:{single_frame_pattern}){{{num_frames}}}"

current_prompt, count = re.subn(
video_sequence_pattern, "<|vision_start|><|video_pad|><|vision_end|>", current_prompt, count=1
)
if count != 1:
logger.warning(f"Expected to deduplicate {num_frames} frames, but found {count} matches.")

prompt_ids = processor.tokenizer.encode(current_prompt)

return _qwen2_5_vl_dedup_vision_tokens(prompt_ids, processor)
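
For illustration, a minimal self-contained check of the regex collapse used above, using a made-up two-frame prompt string modeled on the docstring example (the timestamps and frame count are assumptions, not taken from real data):

```python
import re

single_frame_pattern = r"<[\d.]+ seconds><\|vision_start\|>(?:<\|video_pad\|>)+<\|vision_end\|>"
# Hypothetical decoded prompt containing a single two-frame video.
prompt = (
    "<0.1 seconds><|vision_start|><|video_pad|><|video_pad|><|vision_end|>"
    "<0.6 seconds><|vision_start|><|video_pad|><|video_pad|><|vision_end|>"
)
num_frames = 2
# Collapse exactly num_frames per-frame sequences into one video placeholder.
collapsed, count = re.subn(
    f"(?:{single_frame_pattern}){{{num_frames}}}",
    "<|vision_start|><|video_pad|><|vision_end|>",
    prompt,
    count=1,
)
assert count == 1
assert collapsed == "<|vision_start|><|video_pad|><|vision_end|>"
```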