Skip to content

Commit bdf1116

Browse files
Change cumulative curve like roc_curve reiinakano#98
1 parent 9bf40b5 commit bdf1116

File tree

1 file changed

+131
-53
lines changed

1 file changed

+131
-53
lines changed

scikitplot/metrics.py

Lines changed: 131 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,13 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves',
336336
return ax
337337

338338

339-
def plot_roc(y_true, y_probas, title='ROC Curves',
340-
plot_micro=True, plot_macro=True, classes_to_plot=None,
341-
ax=None, figsize=None, cmap='nipy_spectral',
342-
title_fontsize="large", text_fontsize="medium",
343-
show_labels=True,):
339+
def plot_roc(
340+
y_true, y_probas, title='ROC Curves',
341+
plot_micro=True, plot_macro=True, classes_to_plot=None,
342+
ax=None, figsize=None, cmap='nipy_spectral',
343+
title_fontsize="large", text_fontsize="medium",
344+
show_labels=True, digits=3,
345+
):
344346
"""Generates the ROC curves from labels and predicted scores/probabilities
345347
346348
Args:
@@ -386,6 +388,9 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
386388
show_labels (boolean, optional): Shows the labels in the plot.
387389
Defaults to ``True``.
388390
391+
digits (int, optional): Number of digits for formatting output floating point values.
392+
Use e.g. 2 or 4. Defaults to 3.
393+
389394
Returns:
390395
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
391396
drawn.
@@ -428,8 +433,8 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
428433
roc_auc = auc(fpr_dict[i], tpr_dict[i])
429434
color = plt.cm.get_cmap(cmap)(float(i) / len(classes))
430435
ax.plot(fpr_dict[i], tpr_dict[i], lw=2, color=color,
431-
label='ROC curve of class {0} (area = {1:0.2f})'
432-
''.format(classes[i], roc_auc))
436+
label='ROC curve of class {0} (area = {1:.{digits}f})'
437+
''.format(classes[i], roc_auc, digits=digits))
433438

434439
if plot_micro:
435440
binarized_y_true = label_binarize(y_true, classes=classes)
@@ -440,7 +445,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
440445
roc_auc = auc(fpr, tpr)
441446
ax.plot(fpr, tpr,
442447
label='micro-average ROC curve '
443-
'(area = {0:0.2f})'.format(roc_auc),
448+
'(area = {0:.{digits}f})'.format(roc_auc, digits=digits),
444449
color='deeppink', linestyle=':', linewidth=4)
445450

446451
if plot_macro:
@@ -459,7 +464,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
459464

460465
ax.plot(all_fpr, mean_tpr,
461466
label='macro-average ROC curve '
462-
'(area = {0:0.2f})'.format(roc_auc),
467+
'(area = {0:.{digits}f})'.format(roc_auc, digits=digits),
463468
color='navy', linestyle=':', linewidth=4)
464469

465470
ax.plot([0, 1], [0, 1], 'k--', lw=2)
@@ -475,7 +480,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
475480

476481
def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
477482
ax=None, figsize=None, title_fontsize="large",
478-
text_fontsize="medium"):
483+
text_fontsize="medium", digits=3):
479484
"""Generates the KS Statistic plot from labels and scores/probabilities
480485
481486
Args:
@@ -503,6 +508,9 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
503508
Use e.g. "small", "medium", "large" or integer-values. Defaults to
504509
"medium".
505510
511+
digits (int, optional): Number of digits for formatting output floating point values.
512+
Use e.g. 2 or 4. Defaults to 3.
513+
506514
Returns:
507515
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
508516
drawn.
@@ -543,9 +551,10 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
543551
ax.plot(thresholds, pct2, lw=3, label='Class {}'.format(classes[1]))
544552
idx = np.where(thresholds == max_distance_at)[0][0]
545553
ax.axvline(max_distance_at, *sorted([pct1[idx], pct2[idx]]),
546-
label='KS Statistic: {:.3f} at {:.3f}'.format(ks_statistic,
547-
max_distance_at),
548-
linestyle=':', lw=3, color='black')
554+
label = 'KS Statistic: {:.{digits}f} at {:.{digits}f}'.format(
555+
ks_statistic, max_distance_at, digits=digits
556+
),
557+
linestyle = ':', lw=3, color='black')
549558

550559
ax.set_xlim([0.0, 1.0])
551560
ax.set_ylim([0.0, 1.0])
@@ -685,13 +694,16 @@ def plot_precision_recall_curve(y_true, y_probas,
685694
return ax
686695

687696

688-
def plot_precision_recall(y_true, y_probas,
689-
title='Precision-Recall Curve',
690-
plot_micro=True,
691-
classes_to_plot=None, ax=None,
692-
figsize=None, cmap='nipy_spectral',
693-
title_fontsize="large",
694-
text_fontsize="medium"):
697+
def plot_precision_recall(
698+
y_true, y_probas,
699+
title='Precision-Recall Curve',
700+
plot_micro=True,
701+
classes_to_plot=None, ax=None,
702+
figsize=None, cmap='nipy_spectral',
703+
title_fontsize="large",
704+
text_fontsize="medium",
705+
digits=3,
706+
):
695707
"""Generates the Precision Recall Curve from labels and probabilities
696708
697709
Args:
@@ -731,6 +743,9 @@ def plot_precision_recall(y_true, y_probas,
731743
Use e.g. "small", "medium", "large" or integer-values. Defaults to
732744
"medium".
733745
746+
digits (int, optional): Number of digits for formatting output floating point values.
747+
Use e.g. 2 or 4. Defaults to 3.
748+
734749
Returns:
735750
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
736751
drawn.
@@ -778,8 +793,9 @@ def plot_precision_recall(y_true, y_probas,
778793
color = plt.cm.get_cmap(cmap)(float(i) / len(classes))
779794
ax.plot(recall, precision, lw=2,
780795
label='Precision-recall curve of class {0} '
781-
'(area = {1:0.3f})'.format(classes[i],
782-
average_precision),
796+
'(area = {1:.{digits}f})'.format(classes[i],
797+
average_precision,
798+
digits=digits),
783799
color=color)
784800

785801
if plot_micro:
@@ -790,7 +806,7 @@ def plot_precision_recall(y_true, y_probas,
790806
average='micro')
791807
ax.plot(recall, precision,
792808
label='micro-average Precision-recall curve '
793-
'(area = {0:0.3f})'.format(average_precision),
809+
'(area = {0:.{digits}f})'.format(average_precision, digits=digits),
794810
color='navy', linestyle=':', linewidth=4)
795811

796812
ax.set_xlim([0.0, 1.0])
@@ -802,10 +818,12 @@ def plot_precision_recall(y_true, y_probas,
802818
return ax
803819

804820

805-
def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
806-
metric='euclidean', copy=True, ax=None, figsize=None,
807-
cmap='nipy_spectral', title_fontsize="large",
808-
text_fontsize="medium"):
821+
def plot_silhouette(
822+
X, cluster_labels, title='Silhouette Analysis',
823+
metric='euclidean', copy=True, ax=None, figsize=None,
824+
cmap='nipy_spectral', title_fontsize="large",
825+
text_fontsize="medium", digits=3,
826+
):
809827
"""Plots silhouette analysis of clusters provided.
810828
811829
Args:
@@ -847,6 +865,9 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
847865
Use e.g. "small", "medium", "large" or integer-values. Defaults to
848866
"medium".
849867
868+
digits (int, optional): Number of digits for formatting output floating point values.
869+
Use e.g. 2 or 4. Defaults to 3.
870+
850871
Returns:
851872
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
852873
drawn.
@@ -908,8 +929,10 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
908929

909930
y_lower = y_upper + 10
910931

911-
ax.axvline(x=silhouette_avg, color="red", linestyle="--",
912-
label='Silhouette score: {0:0.3f}'.format(silhouette_avg))
932+
ax.axvline(
933+
x=silhouette_avg, color="red", linestyle="--",
934+
label='Silhouette score: {0:.{digits}f}'.format(silhouette_avg, digits=2)
935+
)
913936

914937
ax.set_yticks([]) # Clear the y-axis labels / ticks
915938
ax.set_xticks(np.arange(-0.1, 1.0, 0.2))
@@ -920,11 +943,13 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
920943
return ax
921944

922945

923-
def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
924-
title='Calibration plots (Reliability Curves)',
925-
ax=None, figsize=None, cmap='nipy_spectral',
926-
title_fontsize="large", text_fontsize="medium",
927-
pos_label=None, strategy="uniform",):
946+
def plot_calibration_curve(
947+
y_true, probas_list, clf_names=None, n_bins=10,
948+
title='Calibration plots (Reliability Curves)',
949+
ax=None, figsize=None, cmap='nipy_spectral',
950+
title_fontsize="large", text_fontsize="medium",
951+
pos_label=None, strategy="uniform",
952+
):
928953
"""Plots calibration curves for a set of classifier probability estimates.
929954
930955
Plotting the calibration curves of a classifier is useful for determining
@@ -1073,9 +1098,13 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
10731098
return ax
10741099

10751100

1076-
def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1077-
ax=None, figsize=None, title_fontsize="large",
1078-
text_fontsize="medium", class_names = None):
1101+
def plot_cumulative_gain(
1102+
y_true, y_probas, title='Cumulative Gains Curve',
1103+
classes_to_plot=None, plot_micro=True, plot_macro=True,
1104+
ax=None, figsize=None, title_fontsize="large",
1105+
text_fontsize="medium", cmap='nipy_spectral',
1106+
class_names = None,
1107+
):
10791108
"""Generates the Cumulative Gains Plot from labels and scores/probabilities
10801109
10811110
The cumulative gains chart is used to determine the effectiveness of a
@@ -1093,6 +1122,17 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
10931122
title (string, optional): Title of the generated plot. Defaults to
10941123
"Cumulative Gains Curve".
10951124
1125+
classes_to_plot (list-like, optional): Classes for which the Cumulative Gain
1126+
curve should be plotted. e.g. [0, 'cold']. If given class does not exist,
1127+
it will be ignored. If ``None``, all classes will be plotted. Defaults to
1128+
``None``
1129+
1130+
plot_micro (boolean, optional): Plot the micro average ROC curve.
1131+
Defaults to ``True``.
1132+
1133+
plot_macro (boolean, optional): Plot the macro average ROC curve.
1134+
Defaults to ``True``.
1135+
10961136
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
10971137
plot the learning curve. If None, the plot is drawn on a new set of
10981138
axes.
@@ -1107,6 +1147,11 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11071147
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
11081148
Use e.g. "small", "medium", "large" or integer-values. Defaults to
11091149
"medium".
1150+
1151+
cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
1152+
Colormap used for plotting the projection. View Matplotlib Colormap
1153+
documentation for available options.
1154+
https://matplotlib.org/users/colormaps.html
11101155
11111156
class_names (list of strings, optional): List of class names. Used for
11121157
the legend. Order should be synchronized with the order of classes
@@ -1129,28 +1174,58 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11291174
:align: center
11301175
:alt: Cumulative Gains Plot
11311176
"""
1177+
if ax is None:
1178+
fig, ax = plt.subplots(1, 1, figsize=figsize)
1179+
ax.set_title(title, fontsize=title_fontsize)
1180+
11321181
y_true = np.array(y_true)
11331182
y_probas = np.array(y_probas)
1134-
11351183
classes = np.unique(y_true)
1136-
if class_names is None: class_names = classes
1137-
if len(classes) != 2:
1184+
1185+
if classes_to_plot is None:
1186+
classes_to_plot = classes
1187+
if class_names is None: class_names = classes_to_plot
1188+
1189+
if len(classes_to_plot) != 2:
11381190
raise ValueError('Cannot calculate Cumulative Gains for data with '
11391191
'{} category/ies'.format(len(classes)))
11401192

1141-
# Compute Cumulative Gain Curves
1142-
percentages, gains1 = cumulative_gain_curve(y_true, y_probas[:, 0],
1143-
classes[0])
1144-
percentages, gains2 = cumulative_gain_curve(y_true, y_probas[:, 1],
1145-
classes[1])
1193+
perc_dict = dict()
1194+
gain_dict = dict()
11461195

1147-
if ax is None:
1148-
fig, ax = plt.subplots(1, 1, figsize=figsize)
1196+
indices_to_plot = np.isin(classes, classes_to_plot)
1197+
# Loop for all classes to get different class gain
1198+
for i, to_plot in enumerate(indices_to_plot):
1199+
perc_dict[i], gain_dict[i] = cumulative_gain_curve(y_true, y_probas[:, i], pos_label=classes[i])
11491200

1150-
ax.set_title(title, fontsize=title_fontsize)
1201+
if to_plot:
1202+
color = plt.cm.get_cmap(cmap)(float(i) / len(classes))
1203+
ax.plot(perc_dict[i], gain_dict[i], lw=2, color=color,
1204+
label='Class {} Cumulative Gain curve'.format(class_names[i]))
11511205

1152-
ax.plot(percentages, gains1, lw=3, label='Class {}'.format(class_names[0]))
1153-
ax.plot(percentages, gains2, lw=3, label='Class {}'.format(class_names[1]))
1206+
# Whether or to plot macro or micro
1207+
if plot_micro:
1208+
binarized_y_true = label_binarize(y_true, classes=classes)
1209+
if len(classes) == 2:
1210+
binarized_y_true = np.hstack((1 - binarized_y_true, binarized_y_true))
1211+
1212+
perc, gain = cumulative_gain_curve(binarized_y_true.ravel(), y_probas.ravel())
1213+
ax.plot(perc, gain, label='micro-average Cumulative Gain curve',
1214+
color='deeppink', linestyle=':', linewidth=4)
1215+
1216+
if plot_macro:
1217+
# First aggregate all percentages
1218+
all_perc = np.unique(np.concatenate([perc_dict[x] for x in range(len(classes))]))
1219+
1220+
# Then interpolate all cumulative gain
1221+
mean_gain = np.zeros_like(all_perc)
1222+
for i in range(len(classes)):
1223+
mean_gain += np.interp(all_perc, perc_dict[i], gain_dict[i])
1224+
1225+
mean_gain /= len(classes)
1226+
1227+
ax.plot(all_perc, mean_gain, label='macro-average Cumulative Gain curve',
1228+
color='navy', linestyle=':', linewidth=4)
11541229

11551230
ax.set_xlim([0.0, 1.0])
11561231
ax.set_ylim([0.0, 1.0])
@@ -1159,16 +1234,19 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11591234

11601235
ax.set_xlabel('Percentage of sample', fontsize=text_fontsize)
11611236
ax.set_ylabel('Gain', fontsize=text_fontsize)
1237+
11621238
ax.tick_params(labelsize=text_fontsize)
11631239
ax.grid('on')
11641240
ax.legend(loc='lower right', fontsize=text_fontsize)
11651241

11661242
return ax
11671243

11681244

1169-
def plot_lift_curve(y_true, y_probas, title='Lift Curve',
1170-
ax=None, figsize=None, title_fontsize="large",
1171-
text_fontsize="medium", class_names = None):
1245+
def plot_lift_curve(
1246+
y_true, y_probas, title='Lift Curve',
1247+
ax=None, figsize=None, title_fontsize="large",
1248+
text_fontsize="medium", class_names = None
1249+
):
11721250
"""Generates the Lift Curve from labels and scores/probabilities
11731251
11741252
The lift curve is used to determine the effectiveness of a

0 commit comments

Comments
 (0)