Skip to content

Commit ce2ede1

Browse files
committed
improve the exact EBM test
1 parent 9e49c46 commit ce2ede1

1 file changed

Lines changed: 14 additions & 5 deletions

File tree

python/interpret-core/tests/glassbox/ebm/test_ebm_exact.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,34 @@ def test_identical_classification():
2222

2323
classes = None if n_classes == Native.Task_Regression else n_classes
2424

25-
for iteration in range(2):
25+
for iteration in range(1):
26+
test_type = (
27+
"regression"
28+
if n_classes == Native.Task_Regression
29+
else str(n_classes) + " classes"
30+
)
31+
print(f"Exact test for {test_type}, iteration {iteration}.")
2632
X, y, names, types = make_synthetic(
27-
seed=seed, classes=classes, output_type="float", n_samples=257
33+
seed=seed,
34+
classes=classes,
35+
output_type="float",
36+
n_samples=257 + iteration,
2837
)
2938

3039
ebm_type = (
3140
ExplainableBoostingClassifier
3241
if 0 <= n_classes
3342
else ExplainableBoostingRegressor
3443
)
35-
ebm = ebm_type(names, types, random_state=seed, max_rounds=10)
44+
ebm = ebm_type(names, types, random_state=seed)
3645
ebm.fit(X, y)
3746

3847
pred = ebm._predict_score(X)
39-
total += sum(pred.flat)
48+
total += sum(pred.flat) # do not use numpy which could use SIMD for sum.
4049

4150
seed += 1
4251

43-
expected = 604.4169336846871
52+
expected = 345.57668871448516
4453

4554
if total != expected:
4655
assert total == expected

0 commit comments

Comments
 (0)