11from warnings import catch_warnings
2+
23import numpy as np
4+ from pytest import approx
35from scipy .special import expit as sigmoid
4-
5- from slise .optimisation import loss_smooth
6- from slise .data import add_intercept_column , scale_same
6+ from slise import explain , regression
7+ from slise .data import add_intercept_column
78from slise .initialisation import (
89 initialise_candidates ,
910 initialise_candidates2 ,
1213 initialise_ols ,
1314 initialise_zeros ,
1415)
15- from slise import regression , explain
16+ from slise . optimisation import loss_smooth
1617from slise .utils import mat_mul_inter
1718
1819from .utils import *
@@ -170,6 +171,7 @@ def test_slise_reg():
170171
171172def test_slise_exp ():
172173 print ("Testing slise explanation" )
174+ np .random .seed (49 )
173175 X , Y , mod = data_create2 (100 , 5 )
174176 Y2 = sigmoid (Y )
175177 w = np .random .uniform (size = 100 ) + 0.5
@@ -178,31 +180,39 @@ def test_slise_exp():
178180 reg = explain (X , Y , 0.1 , x , y , lambda1 = 1e-4 , lambda2 = 1e-4 , normalise = True )
179181 reg .print ()
180182 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
183+ assert y == approx (reg .predict (x ))
181184 assert 1.0 >= reg .subset ().mean () > 0.0
182- reg = explain (X , Y , 0.1 , 19 , lambda1 = 0.01 , lambda2 = 0.01 , normalise = True )
185+ reg = explain (X , Y , 0.1 , 17 , lambda1 = 0.01 , lambda2 = 0.01 , normalise = True )
183186 reg .print ()
184187 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
188+ assert Y [17 ] == approx (reg .predict (X [17 ]))
185189 assert 1.0 >= reg .subset ().mean () > 0.0
186190 reg = explain (X , Y , 0.1 , x , y , lambda1 = 0.01 , lambda2 = 0.01 , normalise = False )
187191 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
192+ assert y == approx (reg .predict (x ))
188193 assert 1.0 >= reg .subset ().mean () > 0.0
189194 reg = explain (X , Y , 0.1 , x , y , lambda1 = 0 , lambda2 = 0 , normalise = False )
190195 reg .print ()
191196 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
197+ assert y == approx (reg .predict (x ))
192198 assert 1.0 >= reg .subset ().mean () > 0.0
193- reg = explain (X , Y , 0.1 , 19 , lambda1 = 0.01 , lambda2 = 0.01 , normalise = False )
199+ reg = explain (X , Y , 0.1 , 18 , lambda1 = 0.01 , lambda2 = 0.01 , normalise = False )
194200 reg .print ()
195201 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
202+ assert Y [18 ] == approx (reg .predict (X [18 ]))
196203 assert 1.0 >= reg .subset ().mean () > 0.0
197204 reg = explain (X , Y , 0.1 , 19 , lambda1 = 0 , lambda2 = 0 , normalise = False )
198205 reg .print ()
199206 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
207+ assert Y [19 ] == approx (reg .predict (X [19 ]))
200208 assert 1.0 >= reg .subset ().mean () > 0.0
201209 reg = explain (X , Y , 0.1 , 19 , lambda1 = 0.01 , lambda2 = 0.01 , weight = w , normalise = False )
202210 reg .print ()
203211 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
212+ assert Y [19 ] == approx (reg .predict (X [19 ]))
204213 assert 1.0 >= reg .subset ().mean () > 0.0
205- reg = explain (X , Y2 , 0.5 , 19 , weight = w , normalise = True , logit = True )
214+ reg = explain (X , Y2 , 0.5 , 20 , weight = w , normalise = True , logit = True )
206215 reg .print ()
207216 assert reg .score () <= 0 , f"Slise loss should usually be <=0 ({ reg .score ():.2f} )"
217+ assert Y2 [20 ] == approx (reg .predict (X [20 ]))
208218 assert 1.0 >= reg .subset ().mean () > 0.0
0 commit comments