Skip to content

Commit 70185a9

Browse files
ENH: Added CLI command to update skops files (#343)
Co-authored-by: Benjamin Bossan <[email protected]>
1 parent 107904c commit 70185a9

File tree

7 files changed

+415
-2
lines changed

7 files changed

+415
-2
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ node_modules
117117
# Vim
118118
*.swp
119119

120+
# MacOS
121+
.DS_Store
120122

121123
exports
122124
trash

docs/changes.rst

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ v0.8
1313
----
1414
- Adds the ability to set the :attr:`.Section.folded` property when using :meth:`.Card.add`.
1515
:pr:`361` by :user:`Thomas Lazarus <lazarust>`.
16+
- Add the CLI command to update Skops files to the latest Skops persistence format.
17+
(:func:`.cli._update.main`). :pr:`333` by :user:`Edoardo Abati <EdAbati>`
1618

1719
v0.7
1820
----

docs/persistence.rst

+23-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,13 @@ for more details.
130130
Command Line Interface
131131
######################
132132

133-
Skops has a command line interface to convert scikit-learn models persisted with
134-
``Pickle`` to ``Skops`` files.
133+
Skops has a command line interface to:
134+
135+
- convert scikit-learn models persisted with ``Pickle`` to ``Skops`` files.
136+
- update ``Skops`` files to the latest version.
137+
138+
``skops convert``
139+
~~~~~~~~~~~~~~~~~
135140

136141
To convert a file from the command line, use the ``skops convert`` entrypoint.
137142

@@ -151,6 +156,22 @@ For example, to convert all ``.pkl`` files in the current directory:
151156
Further help for the different supported options can be found by calling
152157
``skops convert --help`` in a terminal.
153158

159+
``skops update``
160+
~~~~~~~~~~~~~~~~
161+
162+
To update a ``Skops`` file from the command line, use the ``skops update`` command.
163+
Skops will check the protocol version of the file to determine if it needs to be updated to the current version.
164+
165+
The command below is an example of how to create an updated version of a file
166+
``my_model.skops`` and save it as ``my_model-updated.skops``:
167+
168+
.. code-block:: console
169+
170+
skops update my_model.skops -o my_model-updated.skops
171+
172+
Further help for the different supported options can be found by calling
173+
``skops update --help`` in a terminal.
174+
154175
Visualization
155176
#############
156177

skops/cli/_update.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import json
5+
import logging
6+
import shutil
7+
import tempfile
8+
import zipfile
9+
from pathlib import Path
10+
11+
from skops.cli._utils import get_log_level
12+
from skops.io import dump, load
13+
from skops.io._protocol import PROTOCOL
14+
15+
16+
def _update_file(
17+
input_file: str | Path,
18+
output_file: str | Path | None = None,
19+
inplace: bool = False,
20+
logger: logging.Logger = logging.getLogger(),
21+
) -> None:
22+
"""Function that is called by ``skops update`` entrypoint.
23+
24+
Loads a skops model from the input path, updates it to the current skops format, and
25+
saves to an output file. It will overwrite the input file if `inplace` is True.
26+
27+
Parameters
28+
----------
29+
input_file : str, or Path
30+
Path of input skops model to load.
31+
32+
output_file : str, or Path, default=None
33+
Path to save the updated skops model to.
34+
35+
inplace : bool, default=False
36+
Whether to update and overwrite the input file in place.
37+
38+
logger : logging.Logger, default=logging.getLogger()
39+
Logger to use for logging.
40+
"""
41+
if inplace:
42+
if output_file is None:
43+
output_file = input_file
44+
else:
45+
raise ValueError(
46+
"Cannot specify both an output file path and the inplace flag. Please"
47+
" choose whether you want to create a new file or overwrite the input"
48+
" file."
49+
)
50+
51+
input_model = load(input_file, trusted=True)
52+
with zipfile.ZipFile(input_file, "r") as zip_file:
53+
input_file_schema = json.loads(zip_file.read("schema.json"))
54+
55+
if input_file_schema["protocol"] == PROTOCOL:
56+
logger.warning(
57+
"File was not updated because already up to date with the current protocol:"
58+
f" {PROTOCOL}"
59+
)
60+
return None
61+
62+
if input_file_schema["protocol"] > PROTOCOL:
63+
logger.warning(
64+
"File cannot be updated because its protocol is more recent than the "
65+
f"current protocol: {PROTOCOL}"
66+
)
67+
return None
68+
69+
if output_file is None:
70+
logger.warning(
71+
f"File can be updated to the current protocol: {PROTOCOL}. Please"
72+
" specify an output file path or use the `inplace` flag to create the"
73+
" updated Skops file."
74+
)
75+
return None
76+
77+
with tempfile.TemporaryDirectory() as tmp_dir:
78+
tmp_output_file = Path(tmp_dir) / f"{output_file}.tmp"
79+
dump(input_model, tmp_output_file)
80+
shutil.move(str(tmp_output_file), str(output_file))
81+
logger.info(f"Updated skops file written to {output_file}")
82+
83+
84+
def format_parser(
85+
parser: argparse.ArgumentParser | None = None,
86+
) -> argparse.ArgumentParser:
87+
"""Adds arguments and help to parent CLI parser for the `update` method."""
88+
89+
if not parser: # used in tests
90+
parser = argparse.ArgumentParser()
91+
92+
parser_subgroup = parser.add_argument_group("update")
93+
parser_subgroup.add_argument("input", help="Path to an input file to update.")
94+
95+
parser_subgroup.add_argument(
96+
"-o",
97+
"--output-file",
98+
help="Specify the output file name for the updated skops file.",
99+
default=None,
100+
)
101+
parser_subgroup.add_argument(
102+
"--inplace",
103+
help="Update and overwrite the input file in place.",
104+
action="store_true",
105+
)
106+
parser_subgroup.add_argument(
107+
"-v",
108+
"--verbose",
109+
help=(
110+
"Increases verbosity of logging. Can be used multiple times to increase "
111+
"verbosity further."
112+
),
113+
action="count",
114+
dest="loglevel",
115+
default=0,
116+
)
117+
return parser
118+
119+
120+
def main(
    parsed_args: argparse.Namespace,
    logger: logging.Logger = logging.getLogger(),
) -> None:
    """Run the `update` command from already parsed CLI arguments.

    Parameters
    ----------
    parsed_args : argparse.Namespace
        Parsed arguments; the attributes ``input``, ``output_file``,
        ``inplace`` and ``loglevel`` are read.

    logger : logging.Logger, default=logging.getLogger()
        Logger whose level is set from ``parsed_args.loglevel`` and which is
        passed on to the update routine.
    """
    # An empty or missing output file name means "no output file".
    raw_output = parsed_args.output_file
    destination = Path(raw_output) if raw_output else None
    source = Path(parsed_args.input)
    overwrite = parsed_args.inplace

    logging.basicConfig(format="%(levelname)-8s: %(message)s")
    # ``get_log_level`` translates the -v count into a logging level.
    level = get_log_level(parsed_args.loglevel)
    logger.setLevel(level=level)

    _update_file(
        input_file=source,
        output_file=destination,
        inplace=overwrite,
        logger=logger,
    )

skops/cli/entrypoint.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22

33
import skops.cli._convert
4+
import skops.cli._update
45

56

67
def main_cli(command_line_args=None):
@@ -32,6 +33,10 @@ def main_cli(command_line_args=None):
3233
"method": skops.cli._convert.main,
3334
"format_parser": skops.cli._convert.format_parser,
3435
},
36+
"update": {
37+
"method": skops.cli._update.main,
38+
"format_parser": skops.cli._update.format_parser,
39+
},
3540
}
3641

3742
for func_name, values in function_map.items():

skops/cli/tests/test_entrypoint.py

+20
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,23 @@ def test_convert_works_as_expected(
3939
)
4040

4141
assert caplog.at_level(logging.WARNING)
42+
43+
@mock.patch("skops.cli._update._update_file")
def test_update_works_as_expected(
    self,
    update_file_mock: mock.MagicMock,
):
    """
    Check that the CLI dispatches to the update routine when 'update' is the
    first argument, and that the parsed arguments are forwarded as paths.
    """
    source_name = "abc.skops"
    target_name = "abc-new.skops"

    main_cli(["update", source_name, "-o", target_name])

    # ``_update_file`` is patched, so no file is actually read or written.
    expected_kwargs = {
        "input_file": pathlib.Path(source_name),
        "output_file": pathlib.Path(target_name),
        "inplace": False,
        "logger": mock.ANY,
    }
    update_file_mock.assert_called_once_with(**expected_kwargs)

0 commit comments

Comments
 (0)