Skip to content

Commit 6277e83

Browse files
committed
[experiment](feat) Refactor as context manager
1 parent 22d00ae commit 6277e83

File tree

1 file changed

+37
-17
lines changed

1 file changed

+37
-17
lines changed

pydentification/experiment/storage/compose.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,39 @@
11
import importlib.util
22
import json
33
import os
4+
import shutil
45
import sys
56
import zipfile
67
from pathlib import Path
7-
from typing import Callable
8+
from typing import Any, Callable
9+
10+
11+
class ReplaceSourceCode:
12+
"""
13+
ContextManager over-writing imports with given `path` to ZIP with source code created by
14+
`pydentification.experiment.storage.code.save_code_snapshot`.
15+
16+
Code is extracted to a temporary directory and added to the `sys.path` for the duration of the context and removed
17+
afterward on exit. The source code needs to be unique directory to avoid conflicts with other imports.
18+
"""
19+
20+
def __init__(self, path: Path):
21+
self.path = path
22+
self.source_path = path.with_suffix("")
23+
24+
def __enter__(self):
25+
if self.source_path.exists():
26+
raise FileExistsError(f"Can't overwrite {self.source_path.stem}!")
27+
28+
with zipfile.ZipFile(self.path, "r") as zip:
29+
zip.extractall(str(self.source_path))
30+
31+
sys.path.append(str(self.source_path))
32+
return self
33+
34+
def __exit__(self, exc_type, exc_val, exc_tb):
35+
sys.path.remove(str(self.source_path))
36+
shutil.rmtree(self.source_path)
837

938

1039
def _import_function_from_path(module_path: str, function_name: str) -> Callable:
@@ -21,12 +50,9 @@ def _import_function_from_path(module_path: str, function_name: str) -> Callable
2150
return function
2251

2352

24-
def _safe_unzip(path: Path):
25-
if path.with_suffix("").exists():
26-
raise FileExistsError(f"Can't overwrite {path.stem}!")
27-
28-
with zipfile.ZipFile(path, "r") as zip:
29-
zip.extractall(str(path.parent))
53+
def _load_model_and_parameters(path: str | Path, name: str, parameters: dict[str, Any]) -> Any:
54+
model_fn = _import_function_from_path(path, name)
55+
return model_fn(**parameters)
3056

3157

3258
def compose_model(
@@ -48,20 +74,14 @@ def compose_model(
4874
if isinstance(source, str):
4975
source = Path(source)
5076

51-
if source is not None:
52-
_safe_unzip(source)
53-
sys.path.append(str(source.parent))
54-
55-
model_fn = _import_function_from_path(path, name)
56-
5777
if parameters is not None:
5878
with open(parameters, "r") as f:
5979
parameters = json.load(f)
6080
else:
6181
parameters = {}
6282

63-
model = model_fn(**parameters)
64-
6583
if source is not None:
66-
sys.path.remove(str(source.parent))
67-
return model
84+
with ReplaceSourceCode(source):
85+
return _load_model_and_parameters(path, name, parameters)
86+
else:
87+
return _load_model_and_parameters(path, name, parameters)

0 commit comments

Comments
 (0)