Skip to content

Commit 4cbedf9

Browse files
Added class_names optional arg to plot_cumulative_gain. reiinakano#109
1 parent c5faedb commit 4cbedf9

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

scikitplot/metrics.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
10721072

10731073
def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
10741074
ax=None, figsize=None, title_fontsize="large",
1075-
text_fontsize="medium"):
1075+
text_fontsize="medium", class_names = None):
10761076
"""Generates the Cumulative Gains Plot from labels and scores/probabilities
10771077
10781078
The cumulative gains chart is used to determine the effectiveness of a
@@ -1104,6 +1104,10 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11041104
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
11051105
Use e.g. "small", "medium", "large" or integer-values. Defaults to
11061106
"medium".
1107+
1108+
class_names (list of strings, optional): List of class names. Used for
1109+
the legend. Order should be synchronized with the order of classes
1110+
in y_probas.
11071111
11081112
Returns:
11091113
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
@@ -1126,6 +1130,7 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11261130
y_probas = np.array(y_probas)
11271131

11281132
classes = np.unique(y_true)
1133+
if class_names is None: class_names = classes
11291134
if len(classes) != 2:
11301135
raise ValueError('Cannot calculate Cumulative Gains for data with '
11311136
'{} category/ies'.format(len(classes)))
@@ -1141,8 +1146,8 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11411146

11421147
ax.set_title(title, fontsize=title_fontsize)
11431148

1144-
ax.plot(percentages, gains1, lw=3, label='Class {}'.format(classes[0]))
1145-
ax.plot(percentages, gains2, lw=3, label='Class {}'.format(classes[1]))
1149+
ax.plot(percentages, gains1, lw=3, label='Class {}'.format(class_names[0]))
1150+
ax.plot(percentages, gains2, lw=3, label='Class {}'.format(class_names[1]))
11461151

11471152
ax.set_xlim([0.0, 1.0])
11481153
ax.set_ylim([0.0, 1.0])
@@ -1160,7 +1165,7 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11601165

11611166
def plot_lift_curve(y_true, y_probas, title='Lift Curve',
11621167
ax=None, figsize=None, title_fontsize="large",
1163-
text_fontsize="medium"):
1168+
text_fontsize="medium", class_names = None):
11641169
"""Generates the Lift Curve from labels and scores/probabilities
11651170
11661171
The lift curve is used to determine the effectiveness of a
@@ -1192,6 +1197,10 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
11921197
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
11931198
Use e.g. "small", "medium", "large" or integer-values. Defaults to
11941199
"medium".
1200+
1201+
class_names (list of strings, optional): List of class names. Used for
1202+
the legend. Order should be synchronized with the order of classes
1203+
in y_probas.
11951204
11961205
Returns:
11971206
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
@@ -1214,6 +1223,7 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
12141223
y_probas = np.array(y_probas)
12151224

12161225
classes = np.unique(y_true)
1226+
if class_names is None: class_names = classes
12171227
if len(classes) != 2:
12181228
raise ValueError('Cannot calculate Lift Curve for data with '
12191229
'{} category/ies'.format(len(classes)))
@@ -1236,8 +1246,8 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
12361246

12371247
ax.set_title(title, fontsize=title_fontsize)
12381248

1239-
ax.plot(percentages, gains1, lw=3, label='Class {}'.format(classes[0]))
1240-
ax.plot(percentages, gains2, lw=3, label='Class {}'.format(classes[1]))
1249+
ax.plot(percentages, gains1, lw=3, label='Class {}'.format(class_names[0]))
1250+
ax.plot(percentages, gains2, lw=3, label='Class {}'.format(class_names[1]))
12411251

12421252
ax.plot([0, 1], [1, 1], 'k--', lw=2, label='Baseline')
12431253

scikitplot/tests/test_metrics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ def test_plot_calibration(self):
464464
def test_string_classes(self):
465465
plot_calibration_curve(
466466
convert_labels_into_string(self.y),
467-
[self.lr_probas, self.rf_probas]
467+
[self.lr_probas, self.rf_probas],
468+
pos_label='1', # Explicitly setting pos_label
468469
)
469470

470471
def test_cmap(self):

0 commit comments

Comments
 (0)