Skip to content

Commit 22d00ae

Browse files
committed
[experiment](feat) Add util for composing model from stored code and model dump
1 parent 95b6a43 commit 22d00ae

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

pydentification/experiment/storage/compose.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import sys
5+
import zipfile
56
from pathlib import Path
67
from typing import Callable
78

@@ -20,15 +21,37 @@ def _import_function_from_path(module_path: str, function_name: str) -> Callable
2021
return function
2122

2223

23-
def compose_model(path: str | Path, name: str = "model_fn", parameters: str | Path | None = None):
24+
def _safe_unzip(path: Path):
25+
if path.with_suffix("").exists():
26+
raise FileExistsError(f"Can't overwrite {path.stem}!")
27+
28+
with zipfile.ZipFile(path, "r") as zip:
29+
zip.extractall(str(path.parent))
30+
31+
32+
def compose_model(
33+
path: str | Path,
34+
name: str = "model_fn",
35+
parameters: str | Path | None = None,
36+
source: str | Path | None = None,
37+
):
2438
"""
2539
Compose model from dump, which will contain model generating function, JSON with its parameters and source code
2640
for module definitions (ZIP of entire `pydentification`).
2741
2842
:param path: filesystem Path to the model generating function, which will be imported by `import_function_from_path`
2943
:param name: name of the function to be imported, default is `model_fn`
3044
:param parameters: filesystem Path to the JSON file with parameters, if None, empty dictionary will be used
45+
:param source: filesystem Path to the ZIP file with source code
46+
if None imports are attempted from the current working directory.
3147
"""
48+
if isinstance(source, str):
49+
source = Path(source)
50+
51+
if source is not None:
52+
_safe_unzip(source)
53+
sys.path.append(str(source.parent))
54+
3255
model_fn = _import_function_from_path(path, name)
3356

3457
if parameters is not None:
@@ -37,4 +60,8 @@ def compose_model(path: str | Path, name: str = "model_fn", parameters: str | Pa
3760
else:
3861
parameters = {}
3962

40-
return model_fn(parameters)
63+
model = model_fn(**parameters)
64+
65+
if source is not None:
66+
sys.path.remove(str(source.parent))
67+
return model

0 commit comments

Comments
 (0)