Skip to content

Commit 1b1789a

Browse files
fix multiclass test
1 parent c026fd5 commit 1b1789a

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/integration_tests/test_contributions_multiclass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ def test_rank_contributions_1(self):
9797
model.fit(self.x_train, self.y_train)
9898
explainer = shap.TreeExplainer(model)
9999
shap_values = explainer.shap_values(self.x_test)
100-
slist = [pd.DataFrame(data=tab, index=self.x_test.index, columns=self.x_test.columns) for tab in shap_values]
100+
slist = [
101+
pd.DataFrame(data=shap_values[:, :, i], index=self.x_test.index, columns=self.x_test.columns)
102+
for i in range(3)
103+
]
101104

102105
for i in range(3):
103106
s_ord, x_ord, s_dict = rank_contributions(slist[i], pd.DataFrame(data=self.x_test))

0 commit comments

Comments
 (0)