[data] feat: major refactor RLHFDataset for multi-modal data #4759
Code Review
This pull request introduces a major refactoring of RLHFDataset to better handle multi-modal data, specifically by moving the responsibility of fetching and processing image and video data from the DataLoader to the AgentLoopWorker. This change effectively reduces communication overhead and mitigates a potential single-point bottleneck in the controller, which is a significant architectural improvement. The changes are consistently applied across agent loops, tests, and utility functions.
My review has identified a couple of areas for improvement related to maintainability and dependency management. Specifically, the reliance on an external qwen_vl_utils module and the use of monkey-patching to extend processor functionality could be made more robust. Addressing these points would further enhance the quality and long-term maintainability of the codebase.
```python
from qwen_vl_utils import process_vision_info
```

```python
if "dataframe" in state:
    del state["dataframe"]
return state
```

```python
images, videos = process_vision_info(messages, image_patch_size=image_patch_size, return_video_metadata=True)
return images, videos
```
The import `from qwen_vl_utils import process_vision_info` introduces a dependency on `qwen_vl_utils`, which does not seem to be part of the verl package or a standard library. This can lead to portability and dependency management issues, as it relies on the module being present in the Python path, which is fragile.
To improve maintainability and ensure the project is self-contained, consider one of the following approaches:
- Vendor the utility: copy the necessary code from `qwen_vl_utils` into the verl project, for example under `verl/utils/vision/`.
- Formal dependency: if `qwen_vl_utils` is available as a package, add it as a formal dependency in `setup.py` or `requirements.txt`.
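A minimal sketch of how the two approaches could be combined with a guarded lookup: prefer the installed package and fall back to a vendored copy. The helper name `resolve_attr` and the vendored path `verl.utils.vision.qwen_vl_utils` are illustrative assumptions, not part of this PR.

```python
import importlib


def resolve_attr(candidates, attr_name):
    """Try each module path in order; return the named attribute from the
    first module that imports successfully. Raise ImportError if none do."""
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except ImportError:
            continue
    raise ImportError(f"{attr_name} not found in any of: {candidates}")


# Hypothetical usage: prefer the installed qwen_vl_utils package, else a
# vendored copy (path is an assumption for illustration):
# process_vision_info = resolve_attr(
#     ["qwen_vl_utils", "verl.utils.vision.qwen_vl_utils"], "process_vision_info"
# )
```

This keeps a single import site, so switching between the vendored copy and the PyPI package later requires no call-site changes.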
```python
# Bind vlm model's get_rope_index method to processor
processor.config = config
match processor.__class__.__name__:
    case "Qwen2VLProcessor":
        from transformers.models.qwen2_vl import Qwen2VLModel

        processor.get_rope_index = types.MethodType(Qwen2VLModel.get_rope_index, processor)
    case "Qwen2_5_VLProcessor":
        from transformers.models.qwen2_5_vl import Qwen2_5_VLModel

        processor.get_rope_index = types.MethodType(Qwen2_5_VLModel.get_rope_index, processor)
    case "Qwen3VLProcessor":
        from transformers.models.qwen3_vl import Qwen3VLModel

        processor.get_rope_index = types.MethodType(Qwen3VLModel.get_rope_index, processor)
    case "Glm4vImageProcessor":
        from transformers.models.glm4v import Glm4vModel

        processor.get_rope_index = types.MethodType(Glm4vModel.get_rope_index, processor)
    case _:
        raise ValueError(f"Unsupported processor type: {processor.__class__.__name__}")
```
This code uses monkey-patching to dynamically add the get_rope_index method to processor instances based on their class name. While this provides flexibility, it can make the code harder to understand, maintain, and debug, as the methods are not part of the original class definition. This creates an implicit contract that is not obvious to developers who are not familiar with this specific piece of code.
For better maintainability and code clarity, consider using a more explicit design pattern, such as:
- Wrapper classes: create a wrapper class for each processor type that encapsulates the processor and adds the model-specific logic. This would make the relationship between the processor and the added functionality explicit.
- Factory function: a factory function could return a specialized object or a tuple of `(processor, rope_index_function)` based on the processor type.
What does this PR do?
Refactor RLHFDataset for multi-modal data:
- `__getitem__` returns text messages only, and agent_loop workers fetch image and video data from storage.
- Avoids the single-controller CPU/memory bottleneck on large-scale datasets.
- agent_loop workers run on each node, avoiding a single-point bottleneck.
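The dataset/worker split described above can be sketched as follows. All class, field, and function names here are illustrative assumptions, not the PR's actual code:

```python
# Dataset side: __getitem__ stays lightweight, returning text messages plus
# references (paths/URLs) to media -- no image or video bytes are loaded here,
# so the single controller never decodes media.
class RLHFDatasetSketch:
    def __init__(self, rows):
        # each row: {"messages": [...], "images": [path_or_url, ...]}
        self.rows = rows

    def __getitem__(self, idx):
        row = self.rows[idx]
        return {"messages": row["messages"], "image_refs": row.get("images", [])}


# Worker side: each agent_loop worker resolves the references locally, so
# fetching and decoding are spread across nodes instead of one controller.
def agent_loop_fetch(sample, fetch_fn):
    return [fetch_fn(ref) for ref in sample["image_refs"]]
```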
TODO