Skip to content

Commit 3c8a621

Browse files
Merge pull request #589 from guillaume-vignal/feature/smartplotter_simplification
Refactor DataFrame Column Transformation to Avoid Future Warning
2 parents 4660ec7 + 93b1fcd commit 3c8a621

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

shapash/explainer/smart_plotter.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,6 @@ def contribution_plot(
431431
--------
432432
>>> xpl.plot.contribution_plot(0)
433433
"""
434-
435434
if self._explainer._case == "classification":
436435
label_num, _, label_value = self._explainer.check_label_name(label)
437436

@@ -505,8 +504,13 @@ def contribution_plot(
505504
else:
506505
feature_values = self._explainer.x_init.loc[list_ind, col_name]
507506

508-
if self.explainer.x_init[col_name].dtype == 'bool':
509-
feature_values = feature_values.astype(int)
507+
if isinstance(col_name, list):
508+
for el in col_name:
509+
if feature_values[el].dtype == "bool":
510+
feature_values[el] = feature_values[el].astype(int)
511+
else:
512+
if feature_values.dtype == "bool":
513+
feature_values = feature_values.astype(int)
510514

511515
if col_is_group:
512516
feature_values = project_feature_values_1d(
@@ -1131,13 +1135,9 @@ def interactions_plot(
11311135

11321136
# add break line to X label if necessary
11331137
max_len_by_row = max([round(50 / self._explainer.features_desc[feature_values1.columns.values[0]]), 8])
1134-
feature_values1.iloc[:, 0] = feature_values1.iloc[:, 0].apply(
1135-
add_line_break,
1136-
args=(
1137-
max_len_by_row,
1138-
120,
1139-
),
1140-
)
1138+
args = (max_len_by_row, 120)
1139+
feature_values_str = feature_values1.iloc[:, 0].apply(add_line_break, args=args)
1140+
feature_values1 = pd.DataFrame({feature_values1.columns[0]: feature_values_str})
11411141

11421142
# selecting the best plot : Scatter, Violin?
11431143
if col_value_count1 > violin_maxf:

shapash/plots/plot_contribution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def plot_scatter(
8585

8686
# add break line to X label if necessary
8787
args = (max_len_by_row, 120)
88-
feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply(add_line_break, args=args)
88+
feature_values_str = feature_values.iloc[:, 0].apply(add_line_break, args=args)
89+
feature_values = pd.DataFrame({column_name: feature_values_str})
8990

9091
if pred is not None:
9192
hv_text = [f"Id: {x}<br />Predict: {y}" for x, y in zip(feature_values.index, pred.values.flatten())]
@@ -267,7 +268,8 @@ def plot_violin(
267268

268269
# add break line to X label if necessary
269270
args = (max_len_by_row, 120)
270-
feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply(add_line_break, args=args)
271+
feature_values_str = feature_values.iloc[:, 0].apply(add_line_break, args=args)
272+
feature_values = pd.DataFrame({column_name: feature_values_str})
271273

272274
contributions = contributions.loc[feature_values.index]
273275
if pred is not None:

shapash/utils/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ def is_nested_list(object_param):
7373
return any(isinstance(elem, list) for elem in object_param)
7474

7575

76-
def add_line_break(text, nbchar, maxlen=150):
76+
def add_line_break(value, nbchar, maxlen=150):
7777
"""
7878
adding line break in string if necessary
7979
8080
Parameters
8181
----------
82-
text : string
83-
string to check in order to add line break
82+
value : string or oither type
83+
if string to check in order to add line break
8484
nbchar : int
8585
number of characters before line break
8686
maxlen : int
@@ -91,10 +91,10 @@ def add_line_break(text, nbchar, maxlen=150):
9191
string
9292
original text + line break
9393
"""
94-
if isinstance(text, str):
94+
if isinstance(value, str):
9595
length = 0
9696
tot_length = 0
97-
input_word = text.split()
97+
input_word = value.split()
9898
final_sep = []
9999
for w in input_word[:-1]:
100100
length = length + len(w)
@@ -113,7 +113,7 @@ def add_line_break(text, nbchar, maxlen=150):
113113
new_string = "".join(sum(zip(input_word, final_sep + [""]), ())[:-1]) + last_char
114114
return new_string
115115
else:
116-
return text
116+
return value
117117

118118

119119
def truncate_str(text, maxlen=40):

0 commit comments

Comments
 (0)