[trainer] fix: model engine vlm multi_modal_inputs to NonTensorStack #4492
Conversation
Code Review
This pull request addresses an issue with handling `multi_modal_inputs` for VLM model engines by correctly transposing them into a `NonTensorStack`. It also includes a good refactoring in `verl/workers/engine_workers.py` that centralizes device placement logic by moving inference results to the CPU and removing redundant `.cpu()` calls. My review includes one suggestion in `verl/protocol.py` to improve the memory efficiency of key collection, which is an important consideration for large-scale models.
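For context, a minimal sketch of the `NonTensorStack` transposition involved, assuming tensordict's public API; the sample keys (`pixel_values`, `image_grid_thw`) are illustrative, not verl's exact schema:

```python
import torch
from tensordict import NonTensorData

# Hypothetical per-sample multi-modal inputs (the keys are illustrative,
# not verl's actual schema).
samples = [
    {"pixel_values": torch.randn(4, 3, 14, 14), "image_grid_thw": torch.tensor([[1, 2, 2]])},
    {"pixel_values": torch.randn(16, 3, 14, 14), "image_grid_thw": torch.tensor([[1, 4, 4]])},
]

# Stacking NonTensorData elements yields a NonTensorStack with a leading
# batch dimension, so the ragged per-sample dicts can live in a batched
# TensorDict and be indexed/chunked alongside ordinary tensors.
multi_modal_inputs = torch.stack([NonTensorData(s) for s in samples], dim=0)
print(type(multi_modal_inputs).__name__)  # NonTensorStack
```

This is the same `torch.stack`-over-`NonTensorData` pattern that the collation code quoted below relies on for non-tensor values.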
Force-pushed from 8358145 to 954033c.
/gemini review
Code Review
This pull request introduces several important fixes and refactorings to improve support for Vision Language Models (VLMs), particularly around handling `multi_modal_inputs` and variable-length sequences. The core change is to group multi-modal data under a single `multi_modal_inputs` key, which is a cleaner data structure. The PR correctly disables `pin_memory` in dataloaders to prevent crashes with `NestedTensor`s, and adds a necessary workaround for chunking `TensorDict`s containing 3D jagged tensors. My review identifies one critical issue in the updated data collation logic that could lead to a `KeyError` when processing mixed batches of VLM and text-only data. A code suggestion is provided to fix this.
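As a rough illustration of the `pin_memory` point, here is a generic `DataLoader` whose collate emits jagged `NestedTensor`s; the dataset and key names are made up for the example, not verl's config:

```python
import torch
from torch.utils.data import DataLoader

def collate(batch):
    # Variable-length samples -> one jagged NestedTensor per batch.
    return torch.nested.as_nested_tensor(
        [item["input_ids"] for item in batch], layout=torch.jagged
    )

dataset = [{"input_ids": torch.randint(0, 100, (n,))} for n in (5, 9, 3)]

# pin_memory=True would try to pin the collated output; jagged
# NestedTensors do not reliably support that path, so pinning stays off.
loader = DataLoader(dataset, batch_size=3, collate_fn=collate, pin_memory=False)
for nt in loader:
    print(nt.is_nested, [t.shape for t in nt.unbind()])
```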
```python
if isinstance(batch[0][key], torch.Tensor):
    tensors = [item[key] for item in batch]
    final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
else:
    tensors = [NonTensorData(item.get(key)) for item in batch]
    final_batch[key] = torch.stack(tensors, dim=0)
```
The logic `isinstance(batch[0][key], torch.Tensor)` is not robust and can lead to a `KeyError`. The `tensor_keys` set is a union of keys from all samples in the batch. If a key (e.g., `multi_modal_inputs`) is present in some samples but not in `batch[0]`, accessing `batch[0][key]` will crash. This is likely to happen when a batch mixes vision-language and text-only data.

Checking for the key's existence in `batch[0]` before checking its type will prevent this crash.
Suggested change:

```diff
-if isinstance(batch[0][key], torch.Tensor):
+if key in batch[0] and isinstance(batch[0][key], torch.Tensor):
     tensors = [item[key] for item in batch]
     final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
 else:
     tensors = [NonTensorData(item.get(key)) for item in batch]
     final_batch[key] = torch.stack(tensors, dim=0)
```
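To see why the guard matters, here is a self-contained sketch of the suggested collate running on a hypothetical mixed batch; the key names mirror the snippet above, and `NonTensorData` comes from tensordict:

```python
import torch
from tensordict import NonTensorData

# Hypothetical mixed batch: sample 0 is text-only, sample 1 carries
# multi_modal_inputs. Key names mirror the suggestion above.
batch = [
    {"input_ids": torch.tensor([1, 2, 3])},
    {"input_ids": torch.tensor([4, 5]),
     "multi_modal_inputs": {"pixel_values": torch.randn(3, 224, 224)}},
]
tensor_keys = set().union(*(item.keys() for item in batch))

final_batch = {}
for key in tensor_keys:
    # "multi_modal_inputs" is absent from batch[0], so the guard routes it
    # to the else branch, where item.get(key) returns None for the
    # text-only sample instead of raising KeyError.
    if key in batch[0] and isinstance(batch[0][key], torch.Tensor):
        tensors = [item[key] for item in batch]
        final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
    else:
        final_batch[key] = torch.stack(
            [NonTensorData(item.get(key)) for item in batch], dim=0
        )
```

Without the guard, `multi_modal_inputs` would hit `batch[0][key]` and raise `KeyError` whenever the first sample in the batch happened to be text-only.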
What does this PR do?
Fixes the RL model engine for VLMs.
Comparison of Qwen/Qwen3-VL-30B-A3B-Instruct with FSDP vs. Megatron on geo3k:
