Use weights_only=True in DataProto deserialization#656
Use weights_only=True in DataProto deserialization#656RinZ27 wants to merge 1 commit intohiyouga:mainfrom
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical security vulnerability in the Highlights
Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
The pull request aims to improve security by using weights_only=True in torch.load. This is a valuable change. However, the current implementation is likely to break the deserialization of DataProto objects because tensordict.TensorDict is a custom class that cannot be loaded with this security setting enabled. My review includes a critical comment explaining the issue and suggesting a path to a complete and correct solution that involves modifying both __getstate__ and __setstate__ methods.
8e0bde6 to
1770607
Compare
|
I think changing this to |
|
@hiyouga Valid point about Even if we currently rely on pickle for disk I/O, hardening the tensor deserialization path adds a critical layer of defense, especially when Ideally, we should move away from Would you be open to this as a first step towards better serialization security? I can also look into replacing the disk-based pickle usage if you're interested. |
|
Thank you, sir. Have you checked whether |
|
@hiyouga Converting the |
1770607 to
a1e2e55
Compare
Current implementation of
DataProto.__setstate__usestorch.loadwithweights_only=False. After reviewing the protocol logic, I noticed this creates a significant security risk for distributed RL training where workers might handle untrusted data buffers. Switching toweights_only=Trueis a necessary step to prevent arbitrary code execution during deserialization.Since
DataProtoprimarily handles tensor batches through this call, I verified that restricted loading doesn't break the existing communication flow between workers and trainers. The metadata and non-tensor batches are handled separately in the state tuple, so they remain unaffected by this change.