Skip to content

Commit 22d324d

Browse files
Merge pull request #661 from guillaume-vignal/bugfeatures_dict
Synchronize `features_dict` with dataset columns
2 parents 4d2fb07 + bf1c933 commit 22d324d

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

shapash/explainer/smart_explainer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -716,10 +716,20 @@ def check_label_dict(self):
716716

717717
def check_features_dict(self):
718718
"""
719-
Check the features_dict and add the necessary keys if all the
720-
input X columns are not present
719+
Synchronize features_dict with dataset columns:
720+
- Remove features not present in dataset
721+
- Add missing dataset features to features_dict
721722
"""
722-
for feature in set(list(self.columns_dict.values())) - set(list(self.features_dict)):
723+
724+
dataset_features = set(self.columns_dict.values())
725+
current_features = set(self.features_dict.keys())
726+
727+
# Remove features not present in dataset
728+
for feature in current_features - dataset_features:
729+
self.features_dict.pop(feature, None)
730+
731+
# Add features present in dataset but missing in features_dict
732+
for feature in dataset_features - current_features:
723733
self.features_dict[feature] = feature
724734

725735
def _update_features_dict_with_groups(self, features_groups):

tests/unit_tests/explainer/test_smart_explainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,10 @@ def test_check_features_dict_1(self):
170170
"""
171171
Unit test check features dict 1
172172
"""
173-
xpl = SmartExplainer(self.model, features_dict={"Age": "Age (Years Old)"})
173+
xpl = SmartExplainer(self.model, features_dict={"Age": "Age (Years Old)", "Place": "Place of Residence"})
174174
xpl.columns_dict = {0: "Age", 1: "Education", 2: "Sex"}
175175
xpl.check_features_dict()
176+
assert len(xpl.features_dict) == 3
176177
assert xpl.features_dict["Age"] == "Age (Years Old)"
177178
assert xpl.features_dict["Education"] == "Education"
178179

0 commit comments

Comments
 (0)