Skip to content

Commit 8aef41d

Browse files
Merge pull request #663 from guillaume-vignal/feature/add_error_column_in_datatable_for_webapp
Add `_error_` Column Support for Classification
2 parents 22d324d + c0d74f8 commit 8aef41d

File tree

5 files changed

+151
-51
lines changed

5 files changed

+151
-51
lines changed

shapash/explainer/smart_explainer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,9 @@ def compile(
328328
self.predict_proba()
329329

330330
self.y_target = check_y(self.x_init, y_target, y_name="y_target")
331-
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
331+
self.prediction_error = predict_error(
332+
self.y_target, self.y_pred, self._case, proba_values=self.proba_values, classes=self._classes
333+
)
332334

333335
self._get_contributions_from_backend_or_user(x, contributions)
334336
self.check_contributions()
@@ -536,14 +538,14 @@ def add(
536538
"""
537539
if y_pred is not None:
538540
self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
539-
if hasattr(self, "y_target"):
540-
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
541541
if proba_values is not None:
542542
self.proba_values = check_y(self.x_init, proba_values, y_name="proba_values")
543543
if y_target is not None:
544544
self.y_target = check_y(self.x_init, y_target, y_name="y_target")
545-
if hasattr(self, "y_pred"):
546-
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
545+
if hasattr(self, "y_target") and self.y_target is not None:
546+
self.prediction_error = predict_error(
547+
self.y_target, self.y_pred, self._case, proba_values=self.proba_values, classes=self._classes
548+
)
547549
if label_dict is not None:
548550
if isinstance(label_dict, dict) is False:
549551
raise ValueError(
@@ -1058,7 +1060,9 @@ def predict(self):
10581060
"""
10591061
self.y_pred = predict(self.model, self.x_encoded)
10601062
if hasattr(self, "y_target"):
1061-
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
1063+
self.prediction_error = predict_error(
1064+
self.y_target, self.y_pred, self._case, proba_values=self.proba_values, classes=self._classes
1065+
)
10621066

10631067
def to_pandas(
10641068
self,

shapash/utils/model.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def extract_features_model(model, model_attribute):
1111
"""
1212
Extract features of models if it's possible,
1313
If not extract the number features of model
14-
-------
14+
-------
1515
model: model object
1616
model used to check the different values of target estimate predict proba
1717
model_attribute: String or list
@@ -82,30 +82,76 @@ def predict(model, x_encoded):
8282
return y_pred
8383

8484

85-
def predict_error(y_target, y_pred, model_type, proba_values=None, classes=None):
    """
    Compute prediction errors for regression or classification.

    For regression:
        - If the target can be zero, absolute error is used:
            error = |y_true - y_pred|
        - Otherwise, relative error is used:
            error = |(y_true - y_pred) / y_true|

    For classification:
        - With probabilities, the error is computed as:
            error = |1 - P(true_class)|
          where P(true_class) is read from the probability column whose
          position in `classes` matches the true label.
        - Without probabilities, a 0/1 misclassification error is used.

    Parameters
    ----------
    y_target : pandas.DataFrame
        One-column DataFrame containing the ground truth labels.
    y_pred : pandas.DataFrame
        One-column DataFrame containing the predicted labels.
    model_type : str
        Either "regression" or "classification".
    proba_values : pandas.DataFrame, optional
        DataFrame of class probabilities returned by model.predict_proba().
        Each column corresponds to a class, in the same order as in `classes`.
        Rows are assumed to be aligned positionally with `y_target`.
    classes : list, optional
        Ordered list of class label codes (`model.classes_`), used to map the
        true label to the correct probability column.

    Returns
    -------
    pandas.DataFrame or None
        One-column DataFrame containing the prediction errors, named "_error_",
        or None when inputs are missing or `model_type` is unknown.

    Raises
    ------
    ValueError
        If a true label is not present in `classes`.
    """
    if y_target is None or y_pred is None:
        return None

    # ================= REGRESSION =================
    if model_type == "regression":
        # Absolute error when any target is zero (relative error would divide by 0)
        if (y_target == 0).any().iloc[0]:
            prediction_error = abs(y_target.values - y_pred.values)
        else:
            prediction_error = abs((y_target.values - y_pred.values) / y_target.values)

        return pd.DataFrame(prediction_error, index=y_target.index, columns=["_error_"])

    # ================= CLASSIFICATION =================
    elif model_type == "classification":
        if proba_values is None:
            # Fallback: 0/1 misclassification error when no probabilities are available
            prediction_error = (y_target.values != y_pred.values).astype(int)
            return pd.DataFrame(prediction_error, index=y_target.index, columns=["_error_"])

        # classes = order of model.classes_
        true_labels = y_target.iloc[:, 0]
        errors = []

        # Iterate by POSITION (enumerate), not by index label: `proba_values`
        # is looked up with .iloc, which is positional. Using the index label
        # here would break (or silently misalign) for any non-default index.
        for pos, label_code in enumerate(true_labels):
            try:
                col_index = classes.index(label_code)
            except ValueError as err:
                raise ValueError(f"Label_code {label_code} not found in classes list: {classes}") from err

            proba_true_class = proba_values.iloc[pos, col_index]
            errors.append(abs(1 - proba_true_class))

        return pd.DataFrame(errors, index=y_target.index, columns=["_error_"])

    # Unknown model_type: nothing to compute
    return None

shapash/webapp/smart_app.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ def __init__(self, explainer, settings: dict = None):
112112
self.predict_col = ["_predict_"]
113113
self.special_cols = ["_index_", "_predict_"]
114114
if self.explainer.y_target is not None:
115-
self.special_cols.append("_target_")
116-
if self.explainer._case == "regression":
117-
self.special_cols.append("_error_")
115+
self.special_cols.extend(["_target_", "_error_"])
118116
self.explainer.features_imp = self.explainer.state.compute_features_import(self.explainer.contributions)
119117
if self.explainer._case == "classification":
120118
self.label = self.explainer.check_label_name(len(self.explainer._classes) - 1, "num")[1]
@@ -181,8 +179,7 @@ def init_data(self, rows=None):
181179
self.dataframe = self.dataframe.join(
182180
self.explainer.y_target.rename(columns={self.explainer.y_target.columns[0]: "_target_"}),
183181
)
184-
if self.explainer._case == "regression":
185-
self.dataframe = self.dataframe.join(self.explainer.prediction_error)
182+
self.dataframe = self.dataframe.join(self.explainer.prediction_error)
186183

187184
if isinstance(self.explainer.columns_order, list):
188185
special_cols_remaining = [col for col in self.special_cols if col not in self.explainer.columns_order]

tests/unit_tests/utils/test_model.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,74 @@
55
from shapash.utils.model import predict_error
66

77
# Fixtures: one-column target/prediction frames used across the parametrized cases.
y1 = pd.DataFrame(data=np.array([1, 2, 3]), columns=["pred"])
y2 = pd.DataFrame(data=np.array([0, 2, 3]), columns=["pred"])
y3 = pd.DataFrame(data=np.array([2, 2, 3]), columns=["pred"])

# Expected "_error_" frames for the regression / 0-1 classification cases.
expected1 = pd.DataFrame(data=np.array([0.0, 0.0, 0.0]), columns=["_error_"])
expected2 = pd.DataFrame(data=np.array([1, 0, 0]), columns=["_error_"])
expected3 = pd.DataFrame(data=np.array([0, 0, 0]), columns=["_error_"])
# Expected "_error_" frames for the probability-based cases: |1 - P(true_class)|.
expected_proba1 = pd.DataFrame({"_error_": [0.9, 0.6, 0.4]})
expected_proba2 = pd.DataFrame({"_error_": [0.5, 0.4, 0.6]})

# Class probabilities: one column per class, ordered like `classes` below.
proba_values1 = pd.DataFrame(
    [[0.1, 0.7, 0.2],
     [0.3, 0.4, 0.3],
     [0.2, 0.2, 0.6]],
    columns=[1, 2, 3]
)

proba_values2 = pd.DataFrame(
    [
        [0.2, 0.5, 0.3],
        [0.1, 0.6, 0.3],
        [0.3, 0.3, 0.4],
    ],
    columns=[1, 2, 3]
)

# Ordered label codes, mirroring model.classes_.
classes = [1, 2, 3]


@pytest.mark.parametrize(
    "y_target, y_pred, model_type, proba_values, classes, expected",
    [
        # -------------------------------
        # Classification — invalid inputs
        # -------------------------------
        (None, None, "classification", None, None, None),
        (y1, None, "classification", None, None, None),
        (None, y1, "classification", None, None, None),

        # -------------------------------
        # Classification — simple 0/1 error
        # -------------------------------
        (y1, y1, "classification", None, None, expected3),
        (y2, y1, "classification", None, None, expected2),

        # -------------------------------
        # Classification — with proba
        # error = |1 - P(true_class)|
        # -------------------------------
        (y1, y1, "classification", proba_values1, classes, expected_proba1),
        (y3, y1, "classification", proba_values2, classes, expected_proba2),

        # -------------------------------
        # Regression — invalid inputs
        # -------------------------------
        (y1, None, "regression", None, None, None),
        (None, y1, "regression", None, None, None),

        # -------------------------------
        # Regression — working cases
        # -------------------------------
        (y1, y1, "regression", None, None, expected1),
        (y2, y1, "regression", None, None, expected2),
    ],
)
def test_predict_error_works(y_target, y_pred, model_type, proba_values, classes, expected):
    """Check predict_error across regression/classification, with and without probabilities."""
    result = predict_error(y_target, y_pred, model_type, proba_values, classes)

    # Missing inputs must yield None, never an empty frame.
    if expected is None:
        assert result is None
    else:
        # DataFrame comparison
        pd.testing.assert_frame_equal(result, expected)

tests/unit_tests/webapp/utils/test_callbacks.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(self, *args, **kwargs):
7777
}
7878
]
7979
}
80-
self.special_cols = ["_index_", "_predict_", "_target_"]
80+
self.special_cols = ["_index_", "_predict_", "_target_", "_error_"]
8181

8282
super().__init__(*args, **kwargs)
8383

@@ -87,6 +87,7 @@ def test_default_init_data(self):
8787
"_index_": [0, 1, 2, 3, 4],
8888
"_predict_": [0, 0, 0, 1, 1],
8989
"_target_": [0, 0, 0, 1, 1],
90+
"_error_": [0.0, 0.0, 0.0, 0.0, 0.0],
9091
"column1": [1, 2, 3, 4, 5],
9192
"column3": [1.1, 3.3, 2.2, 4.4, 5.5],
9293
"_column2": ["a", "b", "c", "d", "e"],
@@ -318,19 +319,20 @@ def test_get_id_card_features(self):
318319
selected_row = get_id_card_features(data, 3, self.special_cols, features_dict)
319320
expected_result = pd.DataFrame(
320321
{
321-
"feature_value": [3, 1, 1, 4, 4.4, "d", False, pd.Timestamp("2023-01-04")],
322+
"feature_value": [3, 1, 1, 0, 4, 4.4, "d", False, pd.Timestamp("2023-01-04")],
322323
"feature_name": [
323324
"_index_",
324325
"_predict_",
325326
"_target_",
327+
"_error_",
326328
"column1",
327329
"Useless col",
328330
"_Additional col",
329331
"_column4",
330332
"_column5",
331333
],
332334
},
333-
index=["_index_", "_predict_", "_target_", "column1", "column3", "_column2", "_column4", "_column5"],
335+
index=["_index_", "_predict_", "_target_", "_error_", "column1", "column3", "_column2", "_column4", "_column5"],
334336
)
335337
pd.testing.assert_frame_equal(selected_row, expected_result)
336338

@@ -343,19 +345,20 @@ def test_get_id_card_contrib(self):
343345
def test_create_id_card_data(self):
344346
selected_row = pd.DataFrame(
345347
{
346-
"feature_value": [3, 1, 1, 4, 4.4, "d", False, pd.Timestamp("2023-01-04")],
348+
"feature_value": [3, 1, 1, 0, 4, 4.4, "d", False, pd.Timestamp("2023-01-04")],
347349
"feature_name": [
348350
"_index_",
349351
"_predict_",
350352
"_target_",
353+
"_error_",
351354
"column1",
352355
"Useless col",
353356
"_Additional col",
354357
"_column4",
355358
"_column5",
356359
],
357360
},
358-
index=["_index_", "_predict_", "_target_", "column1", "column3", "_column2", "_column4", "_column5"],
361+
index=["_index_", "_predict_", "_target_", "_error_", "column1", "column3", "_column2", "_column4", "_column5"],
359362
)
360363

361364
selected_contrib = pd.DataFrame(
@@ -370,34 +373,35 @@ def test_create_id_card_data(self):
370373
)
371374
expected_result = pd.DataFrame(
372375
{
373-
"feature_value": [3, 1, 1, 4.4, 4, "d", False, pd.Timestamp("2023-01-04")],
376+
"feature_value": [3, 1, 1, 0, 4.4, 4, "d", False, pd.Timestamp("2023-01-04")],
374377
"feature_name": [
375378
"_index_",
376379
"_predict_",
377380
"_target_",
381+
"_error_",
378382
"Useless col",
379383
"column1",
380384
"_Additional col",
381385
"_column4",
382386
"_column5",
383387
],
384-
"feature_contrib": [np.nan, np.nan, np.nan, 0.0, -0.6, np.nan, np.nan, np.nan],
388+
"feature_contrib": [np.nan, np.nan, np.nan, np.nan, 0.0, -0.6, np.nan, np.nan, np.nan],
385389
},
386-
index=["_index_", "_predict_", "_target_", "column3", "column1", "_column2", "_column4", "_column5"],
390+
index=["_index_", "_predict_", "_target_", "_error_", "column3", "column1", "_column2", "_column4", "_column5"],
387391
)
388392
pd.testing.assert_frame_equal(selected_data, expected_result)
389393

390394
def test_create_id_card_layout(self):
391395
selected_data = pd.DataFrame(
392396
{
393-
"feature_value": [3, 1, 1, 4.4, 4, "d"],
394-
"feature_name": ["_index_", "_predict_", "_target_", "Useless col", "column1", "_Additional col"],
395-
"feature_contrib": [np.nan, np.nan, np.nan, 0.0, -0.6, np.nan],
397+
"feature_value": [3, 1, 1, 0, 4.4, 4, "d"],
398+
"feature_name": ["_index_", "_predict_", "_target_", "_error_", "Useless col", "column1", "_Additional col"],
399+
"feature_contrib": [np.nan, np.nan, np.nan, np.nan, 0.0, -0.6, np.nan],
396400
},
397-
index=["_index_", "_predict_", "_target_", "column3", "column1", "_column2"],
401+
index=["_index_", "_predict_", "_target_", "_error_","column3", "column1", "_column2"],
398402
)
399403
children = create_id_card_layout(selected_data, self.xpl.additional_features_dict)
400-
assert len(children) == 6
404+
assert len(children) == 7
401405

402406
def test_get_feature_filter_options(self):
403407
features_dict = copy.deepcopy(self.xpl.features_dict)
@@ -407,6 +411,7 @@ def test_get_feature_filter_options(self):
407411
"_index_",
408412
"_predict_",
409413
"_target_",
414+
"_error_",
410415
"Useless col",
411416
"_Additional col",
412417
"_column4",

0 commit comments

Comments
 (0)