
Conversation

wuxibin89 (Collaborator) commented Dec 11, 2025

What does this PR do?

Fix RL model engine for VLM.

Qwen/Qwen3-VL-30B-A3B-Instruct fsdp vs megatron on geo3k:
[image: reward curve comparison]

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request addresses an issue with handling multi_modal_inputs for VLM model engines by correctly transposing them into a NonTensorStack. It also includes a good refactoring in verl/workers/engine_workers.py to centralize device placement logic by moving inference results to the CPU and removing redundant .cpu() calls. My review includes one suggestion in verl/protocol.py to improve the memory efficiency of key collection, which is an important consideration for large-scale models.
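The review above describes transposing per-sample multi_modal_inputs into a batch-level NonTensorStack. A minimal sketch of that transpose idea, in plain Python (the function name is hypothetical; the real code wraps the result in tensordict's NonTensorStack rather than returning a dict of lists):

```python
def transpose_multi_modal_inputs(samples):
    """Illustrative transpose: a batch-length list of per-sample
    multi_modal_inputs dicts becomes one dict of batch-length lists.
    Samples missing a key contribute None at their position."""
    keys = {k for s in samples for k in s}
    return {k: [s.get(k) for s in samples] for k in keys}

# Hypothetical per-sample inputs; values stand in for real tensors.
samples = [
    {"pixel_values": "img0"},
    {"pixel_values": "img1", "image_grid_thw": "g1"},
]
out = transpose_multi_modal_inputs(samples)
# out == {"pixel_values": ["img0", "img1"], "image_grid_thw": [None, "g1"]}
```

The batch-level grouping keeps all multi-modal data addressable under one key instead of scattering per-sample dicts through the batch.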

wuxibin89 (Collaborator, Author) commented:

/gemini review

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces several important fixes and refactorings to improve support for Vision Language Models (VLMs), particularly around handling multi_modal_inputs and variable-length sequences. The core change is to group multi-modal data under a single multi_modal_inputs key, which is a cleaner data structure. The PR correctly disables pin_memory in dataloaders to prevent crashes with NestedTensors, and adds a necessary workaround for chunking TensorDicts containing 3D jagged tensors. My review identifies one critical issue in the updated data collation logic that could lead to a KeyError when processing mixed batches of VLM and text-only data. A code suggestion is provided to fix this.

Comment on lines +68 to +73:

    if isinstance(batch[0][key], torch.Tensor):
        tensors = [item[key] for item in batch]
        final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
    else:
        tensors = [NonTensorData(item.get(key)) for item in batch]
        final_batch[key] = torch.stack(tensors, dim=0)
gemini-code-assist bot commented — critical

The check isinstance(batch[0][key], torch.Tensor) is not robust and can raise a KeyError. The tensor_keys set is the union of keys across all samples in the batch, so a key (e.g., multi_modal_inputs) may be present in some samples but absent from batch[0]; accessing batch[0][key] then crashes. This is likely whenever a batch mixes vision-language and text-only samples.

Checking for the key's existence in batch[0] before checking its type will prevent this crash.

Suggested change:

    -   if isinstance(batch[0][key], torch.Tensor):
    +   if key in batch[0] and isinstance(batch[0][key], torch.Tensor):
            tensors = [item[key] for item in batch]
            final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
        else:
            tensors = [NonTensorData(item.get(key)) for item in batch]
            final_batch[key] = torch.stack(tensors, dim=0)

@wuxibin89 wuxibin89 merged commit fdf0046 into volcengine:main Dec 16, 2025
77 of 81 checks passed