11import importlib .util
22import json
33import os
4+ import shutil
45import sys
56import zipfile
67from pathlib import Path
7- from typing import Callable
8+ from typing import Any , Callable
9+
10+
11+ class ReplaceSourceCode :
12+ """
13+ ContextManager over-writing imports with given `path` to ZIP with source code created by
14+ `pydentification.experiment.storage.code.save_code_snapshot`.
15+
16+ Code is extracted to a temporary directory and added to the `sys.path` for the duration of the context and removed
17+ afterward on exit. The source code needs to be unique directory to avoid conflicts with other imports.
18+ """
19+
20+ def __init__ (self , path : Path ):
21+ self .path = path
22+ self .source_path = path .with_suffix ("" )
23+
24+ def __enter__ (self ):
25+ if self .source_path .exists ():
26+ raise FileExistsError (f"Can't overwrite { self .source_path .stem } !" )
27+
28+ with zipfile .ZipFile (self .path , "r" ) as zip :
29+ zip .extractall (str (self .source_path ))
30+
31+ sys .path .append (str (self .source_path ))
32+ return self
33+
34+ def __exit__ (self , exc_type , exc_val , exc_tb ):
35+ sys .path .remove (str (self .source_path ))
36+ shutil .rmtree (self .source_path )
837
938
1039def _import_function_from_path (module_path : str , function_name : str ) -> Callable :
@@ -21,12 +50,9 @@ def _import_function_from_path(module_path: str, function_name: str) -> Callable
2150 return function
2251
2352
24- def _safe_unzip (path : Path ):
25- if path .with_suffix ("" ).exists ():
26- raise FileExistsError (f"Can't overwrite { path .stem } !" )
27-
28- with zipfile .ZipFile (path , "r" ) as zip :
29- zip .extractall (str (path .parent ))
53+ def _load_model_and_parameters (path : str | Path , name : str , parameters : dict [str , Any ]) -> Any :
54+ model_fn = _import_function_from_path (path , name )
55+ return model_fn (** parameters )
3056
3157
3258def compose_model (
@@ -48,20 +74,14 @@ def compose_model(
4874 if isinstance (source , str ):
4975 source = Path (source )
5076
51- if source is not None :
52- _safe_unzip (source )
53- sys .path .append (str (source .parent ))
54-
55- model_fn = _import_function_from_path (path , name )
56-
5777 if parameters is not None :
5878 with open (parameters , "r" ) as f :
5979 parameters = json .load (f )
6080 else :
6181 parameters = {}
6282
63- model = model_fn (** parameters )
64-
6583 if source is not None :
66- sys .path .remove (str (source .parent ))
67- return model
84+ with ReplaceSourceCode (source ):
85+ return _load_model_and_parameters (path , name , parameters )
86+ else :
87+ return _load_model_and_parameters (path , name , parameters )
0 commit comments