Skip to content

Commit fd17836

Browse files
authored
Merge pull request #2 from edahelsinki/threading
Threading
2 parents 2433f8c + fdccc21 commit fd17836

File tree

8 files changed

+213
-54
lines changed

8 files changed

+213
-54
lines changed

.github/workflows/python-pytest.yml

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
name: Test Python Package
55

66
on:
7+
push:
8+
branches: [ master ]
79
pull_request:
810
branches: [ master ]
911

setup.cfg

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = slise
3-
version = 2.0.0
3+
version = 2.1.0
44
author = Anton Björklund
55
author_email = [email protected]
66
description = The SLISE algorithm for robust regression and explanations of black box models
@@ -35,3 +35,6 @@ install_requires =
3535
exclude =
3636
examples
3737
tests
38+
39+
[options.extras_require]
40+
tbb = tbb

slise/__init__.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,27 @@
2121
constraints used to generate the data, e.g., the laws of physics).
2222
2323
24-
More in-depth details about the algorithm can be found in the paper:
24+
More in-depth details about the algorithm can be found in the papers:
2525
2626
Björklund A., Henelius A., Oikarinen E., Kallonen K., Puolamäki K.
2727
Sparse Robust Regression for Explaining Classifiers.
2828
Discovery Science (DS 2019).
2929
Lecture Notes in Computer Science, vol 11828, Springer.
3030
https://doi.org/10.1007/978-3-030-33778-0_27
3131
32+
Björklund A., Henelius A., Oikarinen E., Kallonen K., Puolamäki K.
33+
Robust regression via error tolerance.
34+
Data Mining and Knowledge Discovery (2022).
35+
https://doi.org/10.1007/s10618-022-00819-2
36+
3237
"""
3338

34-
from slise.slise import SliseRegression, regression, SliseExplainer, explain
39+
from slise.slise import (
40+
SliseRegression,
41+
regression,
42+
SliseExplainer,
43+
explain,
44+
SliseWarning,
45+
)
3546
from slise.utils import limited_logit as logit
3647
from slise.data import normalise_robust

slise/optimisation.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
from lbfgs import LBFGSError, fmin_lbfgs
9-
from numba import jit
9+
from numba import jit, get_num_threads, set_num_threads, threading_layer
1010
from scipy.optimize import brentq
1111

1212
from slise.utils import (
@@ -375,6 +375,33 @@ def debug_log(
375375
)
376376

377377

378+
def set_threads(num: int = -1) -> int:
379+
"""Set the number of numba threads
380+
381+
Args:
382+
num (int, optional): The number of threads. Defaults to -1.
383+
384+
Returns:
385+
int: The old number of theads (or -1 if unchanged).
386+
"""
387+
if num > 0:
388+
old = get_num_threads()
389+
set_num_threads(num)
390+
return old
391+
return -1
392+
393+
394+
def check_threading_layer():
395+
"""Check which numba threading_layer is active, and warn if it is "workqueue".
396+
"""
397+
loss_residuals(np.ones(1), np.ones(1), 1)
398+
if threading_layer() == "workqueue":
399+
warn(
400+
'Using `numba.threading_layer()=="workqueue"` can be devastatingly slow! See https://numba.pydata.org/numba-doc/latest/user/threading-layer.html for alternatives.',
401+
SliseWarning,
402+
)
403+
404+
378405
def graduated_optimisation(
379406
alpha: np.ndarray,
380407
X: np.ndarray,

0 commit comments

Comments
 (0)