Skip to content

Commit 811fe13

Browse files
committed
add exact EBM test for regression, and multiclass
1 parent 976efe0 commit 811fe13

File tree

3 files changed

+49
-31
lines changed

3 files changed

+49
-31
lines changed

python/interpret-core/interpret-core.pyproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
<Compile Include="tests\glassbox\ebm\research\__init__.py" />
118118
<Compile Include="tests\glassbox\ebm\test_bin.py" />
119119
<Compile Include="tests\glassbox\ebm\test_ebm.py" />
120+
<Compile Include="tests\glassbox\ebm\test_ebm_exact.py" />
120121
<Compile Include="tests\glassbox\ebm\test_ebm_utils.py" />
121122
<Compile Include="tests\glassbox\ebm\test_merge_ebms.py" />
122123
<Compile Include="tests\glassbox\ebm\test_multiclass.py" />

python/interpret-core/tests/glassbox/ebm/test_ebm.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,34 +1243,3 @@ def test_replicatability_classification():
12431243
if total1 != total2:
12441244
assert total1 == total2
12451245
break
1246-
1247-
1248-
def test_identical_classification():
1249-
from interpret.develop import get_option, set_option
1250-
1251-
original = get_option("acceleration")
1252-
set_option("acceleration", 0)
1253-
1254-
for iteration in range(1):
1255-
total = 0.0
1256-
seed = 0
1257-
for i in range(10):
1258-
X, y, names, types = make_synthetic(
1259-
seed=seed, classes=2, output_type="float", n_samples=250
1260-
)
1261-
seed += 1
1262-
1263-
ebm = ExplainableBoostingClassifier(
1264-
names, types, random_state=seed, max_rounds=10
1265-
)
1266-
ebm.fit(X, y)
1267-
1268-
pred = ebm.eval_terms(X)
1269-
total += np.sum(pred)
1270-
1271-
expected = -3.941291737419306e-15
1272-
if total != expected:
1273-
assert total == expected
1274-
break
1275-
1276-
set_option("acceleration", original)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) 2023 The InterpretML Contributors
2+
# Distributed under the MIT software license
3+
4+
from interpret.glassbox import (
5+
ExplainableBoostingClassifier,
6+
ExplainableBoostingRegressor,
7+
)
8+
from interpret.utils import make_synthetic
9+
from interpret.develop import get_option, set_option
10+
from interpret.utils._native import Native
11+
12+
13+
def test_identical_classification():
14+
original = get_option("acceleration")
15+
set_option("acceleration", 0)
16+
17+
total = 0.0
18+
seed = 0
19+
for n_classes in range(Native.Task_Regression, 4):
20+
if n_classes < 2 and n_classes != Native.Task_Regression:
21+
continue
22+
23+
classes = None if n_classes == Native.Task_Regression else n_classes
24+
25+
for iteration in range(2):
26+
X, y, names, types = make_synthetic(
27+
seed=seed, classes=classes, output_type="float", n_samples=257
28+
)
29+
30+
ebm_type = (
31+
ExplainableBoostingClassifier
32+
if 0 <= n_classes
33+
else ExplainableBoostingRegressor
34+
)
35+
ebm = ebm_type(names, types, random_state=seed, max_rounds=10)
36+
ebm.fit(X, y)
37+
38+
pred = ebm._predict_score(X)
39+
total += sum(pred.flat)
40+
41+
seed += 1
42+
43+
expected = 604.4169336846871
44+
45+
if total != expected:
46+
assert total == expected
47+
48+
set_option("acceleration", original)

0 commit comments

Comments
 (0)