Skip to content

Commit 36fd650

Browse files
lukehindspre-commit-ci[bot]ericwb
authored
Pytorch Load / Save Plugin (#1114)
* Pytorch Load / Save Plugin This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load` with untrusted data can lead to arbitrary code execution, and improper use of `torch.save` might expose sensitive data or lead to data corruption. Signed-off-by: Luke Hinds <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add missing save check Signed-off-by: Luke Hinds <[email protected]> * Review fixes from 8b92a02 Signed-off-by: Luke Hinds <[email protected]> * Fix tox issues Signed-off-by: Luke Hinds <[email protected]> * Review fixes Signed-off-by: Luke Hinds <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_functional.py * Update bandit/plugins/pytorch_load_save.py Co-authored-by: Eric Brown <[email protected]> * Update bandit/plugins/pytorch_load_save.py Co-authored-by: Eric Brown <[email protected]> * Update doc/source/plugins/b704_pytorch_load_save.rst Co-authored-by: Eric Brown <[email protected]> * Update bandit/plugins/pytorch_load_save.py Co-authored-by: Eric Brown <[email protected]> --------- Signed-off-by: Luke Hinds <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Brown <[email protected]>
1 parent 4ac55df commit 36fd650

File tree

5 files changed

+109
-0
lines changed

5 files changed

+109
-0
lines changed

bandit/plugins/pytorch_load_save.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) 2024 Stacklok, Inc.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
r"""
5+
==========================================
6+
B614: Test for unsafe PyTorch load or save
7+
==========================================
8+
9+
This plugin checks for the use of `torch.load` and `torch.save`. Using
10+
`torch.load` with untrusted data can lead to arbitrary code execution, and
11+
improper use of `torch.save` might expose sensitive data or lead to data
12+
corruption. A safe alternative is to use `torch.load` with the `safetensors`
13+
library from hugingface, which provides a safe deserialization mechanism.
14+
15+
:Example:
16+
17+
.. code-block:: none
18+
19+
>> Issue: Use of unsafe PyTorch load or save
20+
Severity: Medium Confidence: High
21+
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
22+
Location: examples/pytorch_load_save.py:8
23+
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
24+
8 another_model.load_state_dict(torch.load('model_weights.pth',
25+
map_location='cpu'))
26+
9
27+
10 print("Model loaded successfully!")
28+
29+
.. seealso::
30+
31+
- https://cwe.mitre.org/data/definitions/94.html
32+
- https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
33+
- https://github.com/huggingface/safetensors
34+
35+
.. versionadded:: 1.7.10
36+
37+
"""
38+
import bandit
39+
from bandit.core import issue
40+
from bandit.core import test_properties as test
41+
42+
43+
@test.checks("Call")
44+
@test.test_id("B614")
45+
def pytorch_load_save(context):
46+
"""
47+
This plugin checks for the use of `torch.load` and `torch.save`. Using
48+
`torch.load` with untrusted data can lead to arbitrary code execution,
49+
and improper use of `torch.save` might expose sensitive data or lead
50+
to data corruption.
51+
"""
52+
imported = context.is_module_imported_exact("torch")
53+
qualname = context.call_function_name_qual
54+
if not imported and isinstance(qualname, str):
55+
return
56+
57+
qualname_list = qualname.split(".")
58+
func = qualname_list[-1]
59+
if all(
60+
[
61+
"torch" in qualname_list,
62+
func in ["load", "save"],
63+
not context.check_call_arg_value("map_location", "cpu"),
64+
]
65+
):
66+
return bandit.Issue(
67+
severity=bandit.MEDIUM,
68+
confidence=bandit.HIGH,
69+
text="Use of unsafe PyTorch load or save",
70+
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
71+
lineno=context.get_lineno_for_call_arg("load"),
72+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
-----------------------
2+
B614: pytorch_load_save
3+
-----------------------
4+
5+
.. automodule:: bandit.plugins.pytorch_load_save

examples/pytorch_load_save.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
import torchvision.models as models
3+
4+
# Example of saving a model
5+
model = models.resnet18(pretrained=True)
6+
torch.save(model.state_dict(), 'model_weights.pth')
7+
8+
# Example of loading the model weights in an insecure way
9+
loaded_model = models.resnet18()
10+
loaded_model.load_state_dict(torch.load('model_weights.pth'))
11+
12+
# Save the model
13+
torch.save(loaded_model.state_dict(), 'model_weights.pth')
14+
15+
# Another example using torch.load with more parameters
16+
another_model = models.resnet18()
17+
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
18+
19+
# Save the model
20+
torch.save(another_model.state_dict(), 'model_weights.pth')
21+

setup.cfg

+3
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ bandit.plugins =
152152
#bandit/plugins/tarfile_unsafe_members.py
153153
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members
154154

155+
#bandit/plugins/pytorch_load_save.py
156+
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save
157+
155158
# bandit/plugins/trojansource.py
156159
trojansource = bandit.plugins.trojansource:trojansource
157160

tests/functional/test_functional.py

+8
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,14 @@ def test_tarfile_unsafe_members(self):
930930
}
931931
self.check_example("tarfile_extractall.py", expect)
932932

933+
def test_pytorch_load_save(self):
934+
"""Test insecure usage of torch.load and torch.save."""
935+
expect = {
936+
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0},
937+
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4},
938+
}
939+
self.check_example("pytorch_load_save.py", expect)
940+
933941
def test_trojansource(self):
934942
expect = {
935943
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 1},

0 commit comments

Comments
 (0)