Skip to content

Commit 2cf3ca2

Browse files
Anton Björklund (Aggrathon)
authored and committed
fix explanation prediction logit bugs
1 parent 4b533e3 commit 2cf3ca2

File tree

3 files changed

+58
-42
lines changed

3 files changed

+58
-42
lines changed

slise/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def fill_density(ax, X, x, n):
396396
if np.var(X) == 0:
397397
X = np.random.normal(X[0], 1e-8, len(X))
398398
kde1 = gaussian_kde(X, 0.2)
399-
if np.any(subset):
399+
if np.sum(subset) > 1:
400400
kde2 = gaussian_kde(X[subset], 0.2)
401401
else:
402402
kde2 = lambda x: x * 0

slise/slise.py

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from warnings import warn
99

1010
import numpy as np
11-
from matplotlib.pyplot import Figure
11+
from matplotlib.pyplot import Figure, yscale
1212
from scipy.special import expit as sigmoid
1313

1414
from slise.data import (
@@ -73,17 +73,17 @@ def regression(
7373
SliseRegression: Object containing the regression result.
7474
"""
7575
return SliseRegression(
76-
epsilon,
77-
lambda1,
78-
lambda2,
79-
intercept,
80-
normalise,
81-
initialisation,
82-
beta_max,
83-
max_approx,
84-
max_iterations,
85-
debug,
86-
).fit(X, Y, weight, init)
76+
epsilon=epsilon,
77+
lambda1=lambda1,
78+
lambda2=lambda2,
79+
intercept=intercept,
80+
normalise=normalise,
81+
initialisation=initialisation,
82+
beta_max=beta_max,
83+
max_approx=max_approx,
84+
max_iterations=max_iterations,
85+
debug=debug,
86+
).fit(X=X, Y=Y, weight=weight, init=init)
8787

8888

8989
def explain(
@@ -143,19 +143,19 @@ def explain(
143143
SliseExplainer: Object containing the explanation.
144144
"""
145145
return SliseExplainer(
146-
X,
147-
Y,
148-
epsilon,
149-
lambda1,
150-
lambda2,
151-
normalise,
152-
logit,
153-
initialisation,
154-
beta_max,
155-
max_approx,
156-
max_iterations,
157-
debug,
158-
).explain(x, y, weight, init)
146+
X=X,
147+
Y=Y,
148+
epsilon=epsilon,
149+
lambda1=lambda1,
150+
lambda2=lambda2,
151+
normalise=normalise,
152+
logit=logit,
153+
initialisation=initialisation,
154+
beta_max=beta_max,
155+
max_approx=max_approx,
156+
max_iterations=max_iterations,
157+
debug=debug,
158+
).explain(x=x, y=y, weight=weight, init=init)
159159

160160

161161
class SliseRegression:
@@ -276,9 +276,9 @@ def fit(
276276
alpha, beta = initialise_fixed(init, X, Y, self.epsilon, self._weight)
277277
# Optimisation
278278
alpha = graduated_optimisation(
279-
alpha,
280-
X,
281-
Y,
279+
alpha=alpha,
280+
X=X,
281+
Y=Y,
282282
epsilon=self.epsilon,
283283
beta=beta,
284284
lambda1=self.lambda1,
@@ -588,7 +588,10 @@ def __init__(
588588
if X.shape[1] == X2.shape[1]:
589589
x_cols = None
590590
X, x_center, x_scale = normalise_robust(X2)
591-
Y, y_center, y_scale = normalise_robust(Y)
591+
if logit:
592+
(y_center, y_scale) = (0, 1)
593+
else:
594+
Y, y_center, y_scale = normalise_robust(Y)
592595
self._scale = DataScaling(x_center, x_scale, y_center, y_scale, x_cols)
593596
else:
594597
self._scale = None
@@ -645,9 +648,9 @@ def explain(
645648
else:
646649
alpha, beta = initialise_fixed(init, X, Y, self.epsilon, self._weight)
647650
alpha = graduated_optimisation(
648-
alpha,
649-
X,
650-
Y,
651+
alpha=alpha,
652+
X=X,
653+
Y=Y,
651654
epsilon=self.epsilon,
652655
beta=beta,
653656
lambda1=self.lambda1,
@@ -663,8 +666,11 @@ def explain(
663666
)
664667
self._alpha = alpha
665668
if self._normalise:
669+
y = self._y
670+
if self._logit:
671+
y = limited_logit(y)
666672
alpha2 = self._scale.unscale_model(alpha)
667-
alpha2[0] = self._y - np.sum(self._x * alpha2[1:])
673+
alpha2[0] = y - np.sum(self._x * alpha2[1:])
668674
self._coefficients = alpha2
669675
else:
670676
self._coefficients = alpha
@@ -708,7 +714,7 @@ def predict(self, X: Union[np.ndarray, None] = None) -> np.ndarray:
708714
Y = mat_mul_inter(self._X, self._coefficients)
709715
else:
710716
Y = mat_mul_inter(X, self._coefficients)
711-
if self.scaler.logit:
717+
if self._logit:
712718
Y = sigmoid(Y)
713719
return Y
714720

tests/test_slise.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from warnings import catch_warnings
2+
23
import numpy as np
4+
from pytest import approx
35
from scipy.special import expit as sigmoid
4-
5-
from slise.optimisation import loss_smooth
6-
from slise.data import add_intercept_column, scale_same
6+
from slise import explain, regression
7+
from slise.data import add_intercept_column
78
from slise.initialisation import (
89
initialise_candidates,
910
initialise_candidates2,
@@ -12,7 +13,7 @@
1213
initialise_ols,
1314
initialise_zeros,
1415
)
15-
from slise import regression, explain
16+
from slise.optimisation import loss_smooth
1617
from slise.utils import mat_mul_inter
1718

1819
from .utils import *
@@ -170,6 +171,7 @@ def test_slise_reg():
170171

171172
def test_slise_exp():
172173
print("Testing slise explanation")
174+
np.random.seed(49)
173175
X, Y, mod = data_create2(100, 5)
174176
Y2 = sigmoid(Y)
175177
w = np.random.uniform(size=100) + 0.5
@@ -178,31 +180,39 @@ def test_slise_exp():
178180
reg = explain(X, Y, 0.1, x, y, lambda1=1e-4, lambda2=1e-4, normalise=True)
179181
reg.print()
180182
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
183+
assert y == approx(reg.predict(x))
181184
assert 1.0 >= reg.subset().mean() > 0.0
182-
reg = explain(X, Y, 0.1, 19, lambda1=0.01, lambda2=0.01, normalise=True)
185+
reg = explain(X, Y, 0.1, 17, lambda1=0.01, lambda2=0.01, normalise=True)
183186
reg.print()
184187
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
188+
assert Y[17] == approx(reg.predict(X[17]))
185189
assert 1.0 >= reg.subset().mean() > 0.0
186190
reg = explain(X, Y, 0.1, x, y, lambda1=0.01, lambda2=0.01, normalise=False)
187191
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
192+
assert y == approx(reg.predict(x))
188193
assert 1.0 >= reg.subset().mean() > 0.0
189194
reg = explain(X, Y, 0.1, x, y, lambda1=0, lambda2=0, normalise=False)
190195
reg.print()
191196
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
197+
assert y == approx(reg.predict(x))
192198
assert 1.0 >= reg.subset().mean() > 0.0
193-
reg = explain(X, Y, 0.1, 19, lambda1=0.01, lambda2=0.01, normalise=False)
199+
reg = explain(X, Y, 0.1, 18, lambda1=0.01, lambda2=0.01, normalise=False)
194200
reg.print()
195201
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
202+
assert Y[18] == approx(reg.predict(X[18]))
196203
assert 1.0 >= reg.subset().mean() > 0.0
197204
reg = explain(X, Y, 0.1, 19, lambda1=0, lambda2=0, normalise=False)
198205
reg.print()
199206
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
207+
assert Y[19] == approx(reg.predict(X[19]))
200208
assert 1.0 >= reg.subset().mean() > 0.0
201209
reg = explain(X, Y, 0.1, 19, lambda1=0.01, lambda2=0.01, weight=w, normalise=False)
202210
reg.print()
203211
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
212+
assert Y[19] == approx(reg.predict(X[19]))
204213
assert 1.0 >= reg.subset().mean() > 0.0
205-
reg = explain(X, Y2, 0.5, 19, weight=w, normalise=True, logit=True)
214+
reg = explain(X, Y2, 0.5, 20, weight=w, normalise=True, logit=True)
206215
reg.print()
207216
assert reg.score() <= 0, f"Slise loss should usually be <=0 ({reg.score():.2f})"
217+
assert Y2[20] == approx(reg.predict(X[20]))
208218
assert 1.0 >= reg.subset().mean() > 0.0

0 commit comments

Comments (0)