Skip to content

Commit faddb3b

Browse files
authored
feat: load model from file (#98)
1 parent 89c4424 commit faddb3b

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

src/vowpal_wabbit_next/workspace.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
args (List[str]): VowpalWabbit command line options for configuring the model. An overall list can be found `here <https://vowpalwabbit.org/docs/vowpal_wabbit/python/latest/command_line_args.html>`_. Options which affect the driver are not supported. For example:
136136
`--sort_features`, `--ngram`, `--feature_limit`, `--ignore`, `--extra_metrics`, `--dump_json_weights_experimental`
137137
model_data (Optional[bytes], optional): Bytes of a VW model to be loaded.
138-
record_invert_hash (bool, optional): If true, the invert hash will be recorded for each example. This is required to use :py:meth:`vowpal_wabbit_next.Workspace.json_weights`. This will slow down parsing and learn/predict.
138+
record_feature_names (bool, optional): If true, the invert hash will be recorded for each example. This is required to use :py:meth:`vowpal_wabbit_next.Workspace.json_weights`. This will slow down parsing and learn/predict.
139139
record_metrics (bool, optional): If true, reduction metrics will be enabled and can be fetched with :py:attr:`vowpal_wabbit_next.Workspace.metrics`
140140
enable_debug_tree (bool, optional): If true, debug information in the form of the computation tree will be emitted by :py:meth:`~vowpal_wabbit_next.learn_one`, :py:meth:`~vowpal_wabbit_next.predict_one` and :py:meth:`~vowpal_wabbit_next.predict_then_learn_one`. This will affect performance negatively. See :py:class:`~vowpal_wabbit_next.DebugNode` for more information.
141141
@@ -335,6 +335,51 @@ def serialize(self) -> bytes:
335335
"""
336336
return self._workspace.serialize()
337337

338+
@staticmethod
339+
def load_from_file(
340+
file_path: Union[str, os.PathLike[Any]],
341+
args: List[str] = [],
342+
*,
343+
record_feature_names: bool = False,
344+
record_metrics: bool = False,
345+
enable_debug_tree: bool = False,
346+
) -> Workspace[Any]:
347+
"""Load a VW model from a file.
348+
349+
Args:
350+
file_path (Union[str, os.PathLike[Any]]): Path to file containing serialized model
351+
args (List[str]): VowpalWabbit command line options for configuring the model. An overall list can be found `here <https://vowpalwabbit.org/docs/vowpal_wabbit/python/latest/command_line_args.html>`_. Options which affect the driver are not supported. For example:
352+
`--sort_features`, `--ngram`, `--feature_limit`, `--ignore`, `--extra_metrics`, `--dump_json_weights_experimental`
353+
record_feature_names (bool, optional): If true, the invert hash will be recorded for each example. This is required to use :py:meth:`vowpal_wabbit_next.Workspace.json_weights`. This will slow down parsing and learn/predict.
354+
record_metrics (bool, optional): If true, reduction metrics will be enabled and can be fetched with :py:attr:`vowpal_wabbit_next.Workspace.metrics`
355+
enable_debug_tree (bool, optional): If true, debug information in the form of the computation tree will be emitted by :py:meth:`~vowpal_wabbit_next.learn_one`, :py:meth:`~vowpal_wabbit_next.predict_one` and :py:meth:`~vowpal_wabbit_next.predict_then_learn_one`. This will affect performance negatively. See :py:class:`~vowpal_wabbit_next.DebugNode` for more information.
356+
357+
.. warning::
358+
This is an experimental feature.
359+
360+
Returns:
361+
Workspace[Any]: Workspace with the loaded model
362+
"""
363+
with open(file_path, "rb") as f:
364+
model_data = f.read()
365+
366+
if enable_debug_tree:
367+
return Workspace[Literal[True]](
368+
args,
369+
model_data=model_data,
370+
record_feature_names=record_feature_names,
371+
record_metrics=record_metrics,
372+
enable_debug_tree=True,
373+
)
374+
else:
375+
return Workspace[Literal[False]](
376+
args,
377+
model_data=model_data,
378+
record_feature_names=record_feature_names,
379+
record_metrics=record_metrics,
380+
enable_debug_tree=False,
381+
)
382+
338383
def serialize_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
339384
"""Serialize the current workspace as a VW model to a file."""
340385
return self._workspace.serialize_to_file(os.fspath(file_path))

tests/test_serialization.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ def test_serialize_to_file_and_load() -> None:
4343
model.serialize_to_file(model_path)
4444

4545
try:
46-
with open(model_path, "rb") as f:
47-
model2 = vw.Workspace(model_data=f.read())
48-
46+
model2 = vw.Workspace.load_from_file(model_path)
4947
parser2 = vw.TextFormatParser(model)
5048
test_example2 = parser2.parse_line(test_example_input)
5149
pred2 = model2.predict_one(test_example2)

0 commit comments

Comments
 (0)