Skip to content

Autoload shared libraries #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
11 changes: 7 additions & 4 deletions conifer/backends/cpp/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ class CPPModel(ModelBase):
def __init__(self, ensembleDict, config, metadata=None):
super(CPPModel, self).__init__(ensembleDict, config, metadata)
self.config = CPPConfig(config)

def load_shared_library(self, model_json, shared_library):
import importlib
spec = importlib.util.spec_from_file_location(os.path.basename(shared_library).split(".so")[0], shared_library)
self.bridge = importlib.util.module_from_spec(spec).BDT(model_json)
spec.loader.exec_module(self.bridge)

@copydocstring(ModelBase.write)
def write(self):
Expand Down Expand Up @@ -99,10 +105,7 @@ def compile(self):

try:
logger.debug(f'Importing conifer_bridge_{self._stamp} from conifer_bridge_{self._stamp}.so')
import importlib.util
spec = importlib.util.spec_from_file_location(f'conifer_bridge_{self._stamp}', f'./conifer_bridge_{self._stamp}.so')
self.bridge = importlib.util.module_from_spec(spec).BDT(f"{cfg.project_name}.json")
spec.loader.exec_module(self.bridge)
self.load_shared_library(f"{cfg.project_name}.json", f"./conifer_bridge_{self._stamp}.so")
except ImportError:
os.chdir(curr_dir)
raise Exception("Can't import pybind11 bridge, is it compiled?")
Expand Down
11 changes: 7 additions & 4 deletions conifer/backends/xilinxhls/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,12 @@ def decision_function(self, X, trees=False):
y = y.reshape(y.shape[0])
return y

def load_shared_library(self, model_json, shared_library):
import importlib
spec = importlib.util.spec_from_file_location(os.path.basename(shared_library).split(".so")[0], shared_library)
self.bridge = importlib.util.module_from_spec(spec)
spec.loader.exec_module(self.bridge)

@copydocstring(ModelBase.compile)
def compile(self):
self.write()
Expand All @@ -534,10 +540,7 @@ def compile(self):

try:
logger.debug(f'Importing conifer_bridge_{self._stamp} from conifer_bridge_{self._stamp}.so')
import importlib.util
spec = importlib.util.spec_from_file_location(f'conifer_bridge_{self._stamp}', f'./conifer_bridge_{self._stamp}.so')
self.bridge = importlib.util.module_from_spec(spec)
spec.loader.exec_module(self.bridge)
self.load_shared_library(f"{cfg.project_name}.json", f"./conifer_bridge_{self._stamp}.so")
except ImportError:
os.chdir(curr_dir)
raise Exception("Can't import pybind11 bridge, is it compiled?")
Expand Down
34 changes: 33 additions & 1 deletion conifer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,9 @@ def _profile(self, what : Literal["scores", "thresholds"], ax=None):

return ax

def load_shared_library(self, model_json, shared_library):
pass

class ModelMetaData:
def __init__(self):
self.version = version
Expand Down Expand Up @@ -532,7 +535,7 @@ def make_model(ensembleDict, config=None):
backend = get_backend(backend)
return backend.make_model(ensembleDict, config)

def load_model(filename, new_config=None):
def load_model(filename, new_config=None, shared_library=True):
'''
Load a Model from JSON file

Expand All @@ -542,6 +545,14 @@ def load_model(filename, new_config=None):
filename to load from
new_config: dictionary (optional)
if provided, override the configuration specified in the JSON file
shared_library: string|bool (optional)
If True, the shared library will be looked for in the same directory as the JSON file, using the timestamp of the last metadata entry available
If False, the shared library will not be loaded
If a string, it could be:
- path to the shared library to load
- path to the directory where to look for the .so file, using the timestamp of the last metadata entry available

No shared library will be loaded if a new configuration is provided
'''
with open(filename, 'r') as json_file:
js = json.load(json_file)
Expand All @@ -561,4 +572,25 @@ def load_model(filename, new_config=None):

model = make_model(js, config)
model._metadata = metadata + model._metadata

if new_config is None and shared_library is not False:
shared_library_path=None
if isinstance(shared_library, str) and shared_library.endswith(".so"):
shared_library_path=shared_library
else:
from glob import glob
shared_library_dirpath=os.path.abspath(os.path.dirname(filename)) if shared_library is True else os.path.abspath(shared_library)
timestamps=[int(md._to_dict()["time"]) for md in model._metadata[-2::-1]]
so_files=glob(os.path.join(shared_library_dirpath, 'conifer_bridge_*.so'))
so_files=[os.path.basename(so_file) for so_file in so_files]
for timestamp in timestamps:
if f"conifer_bridge_{timestamp}.so" in so_files:
shared_library_path=os.path.join(shared_library_dirpath, f'conifer_bridge_{timestamp}.so')
break

try:
model.load_shared_library(filename, shared_library_path)
except Exception:
print("An existing shared library was either not found or could not be loaded. Run model.compile()")

return model
24 changes: 24 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import conifer
import json
import os

'''
Test conifer's model saving and loading functionality by loading some models and checking the predictions
Expand All @@ -16,6 +17,29 @@ def test_hls_save_load(hls_convert, train_skl):
y_hls_0, y_hls_1 = util.predict_skl(orig_model, X, y, load_model)
np.testing.assert_array_equal(y_hls_0, y_hls_1)

def test_hls_reload_last_shared_library(hls_convert, train_skl):
clf, X, y = train_skl
initial_model = conifer.model.load_model(f'{hls_convert.config.output_dir}/{hls_convert.config.project_name}.json', shared_library = False)
initial_model.config.output_dir += '_loaded'
initial_model.compile()
# Re-load without recompiling to check if the shared library is loaded correctly
reload_model = conifer.model.load_model(f'{hls_convert.config.output_dir}_loaded/{hls_convert.config.project_name}.json', shared_library=True)
y_hls, y_hls_reload = util.predict_skl(initial_model, X, y, reload_model)
np.testing.assert_array_equal(y_hls, y_hls_reload)
assert os.path.basename(initial_model.bridge.__file__) == os.path.basename(reload_model.bridge.__file__), "Loaded two different shared libraries"

def test_hls_reload_manual_shared_library(hls_convert, train_skl):
clf, X, y = train_skl
initial_model = conifer.model.load_model(f'{hls_convert.config.output_dir}/{hls_convert.config.project_name}.json', shared_library = False)
initial_model.config.output_dir += '_loaded'
initial_model.compile()
so_path = os.path.basename(initial_model.bridge.__file__) # manually get the shared library path
# Re-load without recompiling to check if the shared library is loaded correctly
reload_model = conifer.model.load_model(f'{hls_convert.config.output_dir}_loaded/{hls_convert.config.project_name}.json', shared_library=so_path) # pass the shared library path manually
y_hls, y_hls_reload = util.predict_skl(initial_model, X, y, reload_model)
np.testing.assert_array_equal(y_hls, y_hls_reload)
assert os.path.basename(initial_model.bridge.__file__) == os.path.basename(reload_model.bridge.__file__), "Loaded two different shared libraries"

def test_hdl_save_load(vhdl_convert, train_skl):
orig_model = vhdl_convert
clf, X, y = train_skl
Expand Down