Commit ca93021

adrinjalali and E-Aho authored
FEAT Implement .skops file audit functionality
* FEAT Audit before loading a skops file
* WIP
* numpy loaders
* scipy and src issue
* sklearn
* make tests pass
* remove pickle.py
* fix a few issues
* add missing files
* add get_untrusted_types and docs
* minor fix
* add more tests
* add a smoke test, failing though
* implement safety for functions
* add missing Tree children for audit
* add missing SparseMatrixNode children for audit
* tests pass
* remove safety tree code
* minor test
* fix ids in test
* move type ignore
* add more tests and some docs
* more comments
* fix recursive dump and get_untrusted_set
* Ben's comments
* address comments: sentinel, contextmanager, sorted
* move all children to the children attribute
* add complex pipeline test
* apply Ben's suggestions
* Card object should pass trusted to load
* Update skops/io/_dispatch.py

Co-authored-by: Erin Aho <[email protected]>
1 parent 757b940 commit ca93021

16 files changed, with 1121 additions and 333 deletions.

docs/persistence.rst

+32 -15
@@ -49,20 +49,37 @@ The code snippet below illustrates how to use :func:`skops.io.dump` and
     clf.fit(X_train, y_train)
     dump(clf, "my-logistic-regression.skops")
     # ...
-    loaded = load("my-logistic-regression.skops")
+    loaded = load("my-logistic-regression.skops", trusted=True)
     loaded.predict(X_test)
 
     # in memory
     from skops.io import dumps, loads
     serialized = dumps(clf)
-    loaded = loads(serialized)
+    loaded = loads(serialized, trusted=True)
 
-At the moment, we support the vast majority of sklearn estimators. This includes
-complex use cases such as :class:`sklearn.pipeline.Pipeline`,
+Note that you should only load files with ``trusted=True`` if you trust the
+source. Otherwise you can get a list of untrusted types present in the dump
+using :func:`skops.io.get_untrusted_types`:
+
+.. code:: python
+
+    from skops.io import get_untrusted_types
+    unknown_types = get_untrusted_types(file="my-logistic-regression.skops")
+    print(unknown_types)
+
+Once you check the list and you validate that everything in the list is safe,
+you can load the file with ``trusted=unknown_types``:
+
+.. code:: python
+
+    loaded = load("my-logistic-regression.skops", trusted=unknown_types)
+
+At the moment, we support the vast majority of sklearn estimators. This
+includes complex use cases such as :class:`sklearn.pipeline.Pipeline`,
 :class:`sklearn.model_selection.GridSearchCV`, classes using Cython code, such
-as :class:`sklearn.tree.DecisionTreeClassifier`, and more. If you discover an sklearn
-estimator that does not work, please open an issue on the skops `GitHub page
-<https://github.com/skops-dev/skops/issues>`_ and let us know.
+as :class:`sklearn.tree.DecisionTreeClassifier`, and more. If you discover an
+sklearn estimator that does not work, please open an issue on the skops `GitHub
+page <https://github.com/skops-dev/skops/issues>`_ and let us know.
 
 In contrast to ``pickle``, skops cannot persist arbitrary Python code. This
 means if you have custom functions (say, a custom function to be used with
@@ -74,16 +91,16 @@ Roadmap
 -------
 
 Currently, it is still possible to run insecure code when using skops
-persistence. For example, it's possible to load a save file that evaluates arbitrary
-code using :func:`eval`. However, we have concrete plans on how to mitigate
-this, so please stay updated.
+persistence. For example, it's possible to load a save file that evaluates
+arbitrary code using :func:`eval`. However, we have concrete plans on how to
+mitigate this, so please stay updated.
 
 On top of trying to support persisting all relevant sklearn objects, we plan on
-making persistence extensible for other libraries. As a user, this means that if
-you trust a certain library, you will be able to tell skops to load code from
-that library. As a library author, there will be a clear path of what needs to
-be done to add secure persistence to your library, such that skops can save and
-load code from your library.
+making persistence extensible for other libraries. As a user, this means that
+if you trust a certain library, you will be able to tell skops to load code
+from that library. As a library author, there will be a clear path of what
+needs to be done to add secure persistence to your library, such that skops can
+save and load code from your library.
 
 To follow what features are currently planned, filter for the `"persistence"
 label <https://github.com/skops-dev/skops/labels/persistence>`_ in our GitHub
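
The documentation added in this file asks the user to check the list returned by `get_untrusted_types` before passing it to `load`. A minimal sketch of that review step is given below. It is not part of this commit: `split_by_allowlist`, the allowed prefixes, and the example type names are all hypothetical, and the helper works on plain strings so it runs without skops installed.

```python
def split_by_allowlist(untrusted_types, allowed_prefixes):
    """Partition fully qualified type names into (accepted, rejected).

    Hypothetical helper: a name is accepted if it equals an allowed prefix
    or lives in a module below it (prefix followed by a dot).
    """
    accepted, rejected = [], []
    for name in sorted(untrusted_types):
        if any(name == p or name.startswith(p + ".") for p in allowed_prefixes):
            accepted.append(name)
        else:
            rejected.append(name)
    return accepted, rejected


# Made-up input, shaped like a list of "module.TypeName" strings.
reported = ["mylib.transformers.Scaler", "os.system"]
accepted, rejected = split_by_allowlist(reported, allowed_prefixes=["mylib"])
print(accepted)
print(rejected)
```

Under this sketch, one would only go on to call `load(..., trusted=accepted)` when `rejected` is empty. Anything left in `rejected` needs a manual look first.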

skops/card/_model_card.py

+18 -7
@@ -165,20 +165,25 @@ def metadata_from_config(config_path: Union[str, Path]) -> ModelCardData:
     return card_data
 
 
-def _load_model(model: Any) -> Any:
-    """Loads the mddel if provided a file path, if already a model instance return it
-    unmodified.
+def _load_model(model: Any, trusted=False) -> Any:
+    """Return a model instance.
+
+    Loads the model if provided a file path, if already a model instance return
+    it unmodified.
 
     Parameters
     ----------
     model : pathlib.Path, str, or sklearn estimator
         Path/str or the actual model instance. if a Path or str, loads the model.
 
+    trusted : bool, default=False
+        Passed to :func:`skops.io.load` if the model is a file path and it's
+        a `skops` file.
+
     Returns
     -------
     model : object
         Model instance.
-
     """
 
     if not isinstance(model, (Path, str)):
@@ -190,11 +195,11 @@ def _load_model(model: Any) -> Any:
 
     try:
         if zipfile.is_zipfile(model_path):
-            model = load(model_path)
+            model = load(model_path, trusted=trusted)
         else:
             model = joblib.load(model_path)
     except Exception as ex:
-        msg = f'An "{type(ex).__name__}" occured during model loading.'
+        msg = f'An "{type(ex).__name__}" occurred during model loading.'
         raise RuntimeError(msg) from ex
 
     return model
@@ -227,6 +232,10 @@ class Card:
         of the ``config.json`` file, which itself is created by
         :func:`skops.hub_utils.init`.
 
+    trusted: bool, default=False
+        Passed to :func:`skops.io.load` if the model is a file path and it's
+        a `skops` file.
+
     Attributes
     ----------
     model: estimator object
@@ -294,13 +303,15 @@ def __init__(
         model: Any,
         model_diagram: bool = True,
         metadata: Optional[ModelCardData] = None,
+        trusted: bool = False,
     ) -> None:
         self.model = model
         self.model_diagram = model_diagram
         self._eval_results = {}  # type: ignore
         self._template_sections: dict[str, str] = {}
         self._extra_sections: list[tuple[str, Any]] = []
         self.metadata = metadata or ModelCardData()
+        self.trusted = trusted
 
     def get_model(self) -> Any:
         """Returns sklearn estimator object if ``Path``/``str``
@@ -311,7 +322,7 @@ def get_model(self) -> Any:
         model : Object
             Model instance.
         """
-        model = _load_model(self.model)
+        model = _load_model(self.model, self.trusted)
         return model
 
     def add(self, **kwargs: str) -> "Card":

skops/card/tests/test_card.py

+5 -5
@@ -40,10 +40,10 @@ def save_model_to_file(model_instance, suffix):
 def test_load_model(suffix):
     model0 = LinearRegression(n_jobs=123)
     _, save_file = save_model_to_file(model0, suffix)
-    loaded_model_str = _load_model(save_file)
+    loaded_model_str = _load_model(save_file, trusted=True)
     save_file_path = Path(save_file)
-    loaded_model_path = _load_model(save_file_path)
-    loaded_model_instance = _load_model(model0)
+    loaded_model_path = _load_model(save_file_path, trusted=True)
+    loaded_model_instance = _load_model(model0, trusted=True)
 
     assert loaded_model_str.n_jobs == 123
     assert loaded_model_path.n_jobs == 123
@@ -431,7 +431,7 @@ def test_with_metadata(self, card: Card, meth):
 
 class TestCardModelAttribute:
     def path_to_card(self, path):
-        card = Card(model=path)
+        card = Card(model=path, trusted=True)
         card.add(
             model_description="A description",
             model_card_authors="Jane Doe",
@@ -470,7 +470,7 @@ def test_load_model_exception(self, meth, suffix):
 
         os.close(file_handle)
 
-        with pytest.raises(Exception, match="occured during model loading."):
+        with pytest.raises(Exception, match="occurred during model loading."):
             card = Card(file_name)
             meth(card)
 

skops/io/__init__.py

+2 -2
@@ -1,3 +1,3 @@
-from ._persist import dump, dumps, load, loads
+from ._persist import dump, dumps, get_untrusted_types, load, loads
 
-__all__ = ["dumps", "load", "loads", "dump"]
+__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types"]

skops/io/_audit.py

+61
@@ -0,0 +1,61 @@
+from skops.io.exceptions import UntrustedTypesFoundException
+
+
+def check_type(module_name, type_name, trusted):
+    """Check if a type is safe to load.
+
+    A type is safe to load only if it's present in the trusted list.
+
+    Parameters
+    ----------
+    module_name : str
+        The module name of the type.
+
+    type_name : str
+        The class name of the type.
+
+    trusted : bool, or list of str
+        If ``True``, the tree is considered safe. Otherwise trusted has to be
+        a list of trusted types.
+
+    Returns
+    -------
+    is_safe : bool
+        True if the type is safe, False otherwise.
+    """
+    if trusted is True:
+        return True
+    return module_name + "." + type_name in trusted
+
+
+def audit_tree(tree, trusted):
+    """Audit a tree of nodes.
+
+    A tree is safe if it only contains trusted types. Audit is skipped if
+    trusted is ``True``.
+
+    Parameters
+    ----------
+    tree : skops.io._dispatch.Node
+        The tree to audit.
+
+    trusted : bool, or list of str
+        If ``True``, the tree is considered safe. Otherwise trusted has to be
+        a list of trusted types names.
+
+        An entry in the list is typically of the form
+        ``skops.io._utils.get_module(obj) + "." + obj.__class__.__name__``.
+
+    Raises
+    ------
+    UntrustedTypesFoundException
+        If the tree contains an untrusted type.
+    """
+    if trusted is True:
+        return
+
+    unsafe = tree.get_unsafe_set()
+    if isinstance(trusted, (list, set)):
+        unsafe -= set(trusted)
+    if unsafe:
+        raise UntrustedTypesFoundException(unsafe)
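
To make the control flow of `audit_tree` easier to follow, here is a minimal standalone sketch of the same idea: collect the unsafe type names from a tree, subtract the caller's trusted list, and raise if anything is left. The `Node` class, the `default_trusted` argument, and the `UntrustedTypesFound` exception are hypothetical stand-ins written for this illustration. They are not the skops implementations, whose `get_unsafe_set` is not shown in this part of the diff.

```python
class UntrustedTypesFound(Exception):
    """Hypothetical stand-in for skops' UntrustedTypesFoundException."""


class Node:
    """Hypothetical stand-in for a node in the loaded object tree."""

    def __init__(self, module_name, type_name, children=()):
        self.module_name = module_name
        self.type_name = type_name
        self.children = list(children)

    def get_unsafe_set(self, default_trusted=frozenset()):
        """Recursively collect names that are not trusted by default."""
        unsafe = set()
        name = f"{self.module_name}.{self.type_name}"
        if name not in default_trusted:
            unsafe.add(name)
        for child in self.children:
            unsafe |= child.get_unsafe_set(default_trusted)
        return unsafe


def audit_tree(tree, trusted, default_trusted=frozenset()):
    """Raise if the tree holds a type the caller has not trusted."""
    if trusted is True:
        return  # audit skipped entirely
    unsafe = tree.get_unsafe_set(default_trusted)
    if isinstance(trusted, (list, set)):
        unsafe -= set(trusted)
    if unsafe:
        raise UntrustedTypesFound(sorted(unsafe))


tree = Node(
    "sklearn.pipeline",
    "Pipeline",
    children=[Node("builtins", "list"), Node("mylib", "Custom")],
)
default = frozenset({"builtins.list"})

try:
    audit_tree(tree, trusted=False, default_trusted=default)
    leftover = []
except UntrustedTypesFound as exc:
    leftover = exc.args[0]
print(leftover)
```

In this sketch, passing the reported names back as `trusted` makes the same call succeed, which mirrors the `get_untrusted_types` followed by `load(..., trusted=unknown_types)` workflow described in the documentation change.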
