@@ -336,11 +336,13 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves',
336
336
return ax
337
337
338
338
339
- def plot_roc (y_true , y_probas , title = 'ROC Curves' ,
340
- plot_micro = True , plot_macro = True , classes_to_plot = None ,
341
- ax = None , figsize = None , cmap = 'nipy_spectral' ,
342
- title_fontsize = "large" , text_fontsize = "medium" ,
343
- show_labels = True ,):
339
+ def plot_roc (
340
+ y_true , y_probas , title = 'ROC Curves' ,
341
+ plot_micro = True , plot_macro = True , classes_to_plot = None ,
342
+ ax = None , figsize = None , cmap = 'nipy_spectral' ,
343
+ title_fontsize = "large" , text_fontsize = "medium" ,
344
+ show_labels = True , digits = 3 ,
345
+ ):
344
346
"""Generates the ROC curves from labels and predicted scores/probabilities
345
347
346
348
Args:
@@ -386,6 +388,9 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
386
388
show_labels (boolean, optional): Shows the labels in the plot.
387
389
Defaults to ``True``.
388
390
391
+ digits (int, optional): Number of digits for formatting output floating point values.
392
+ Use e.g. 2 or 4. Defaults to 3.
393
+
389
394
Returns:
390
395
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
391
396
drawn.
@@ -428,8 +433,8 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
428
433
roc_auc = auc (fpr_dict [i ], tpr_dict [i ])
429
434
color = plt .cm .get_cmap (cmap )(float (i ) / len (classes ))
430
435
ax .plot (fpr_dict [i ], tpr_dict [i ], lw = 2 , color = color ,
431
- label = 'ROC curve of class {0} (area = {1:0.2f })'
432
- '' .format (classes [i ], roc_auc ))
436
+ label = 'ROC curve of class {0} (area = {1:.{digits}f })'
437
+ '' .format (classes [i ], roc_auc , digits = digits ))
433
438
434
439
if plot_micro :
435
440
binarized_y_true = label_binarize (y_true , classes = classes )
@@ -440,7 +445,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
440
445
roc_auc = auc (fpr , tpr )
441
446
ax .plot (fpr , tpr ,
442
447
label = 'micro-average ROC curve '
443
- '(area = {0:0.2f} )' .format (roc_auc ),
448
+ '(area = {0:.{digits}f} )' .format (roc_auc , digits = digits ),
444
449
color = 'deeppink' , linestyle = ':' , linewidth = 4 )
445
450
446
451
if plot_macro :
@@ -459,7 +464,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
459
464
460
465
ax .plot (all_fpr , mean_tpr ,
461
466
label = 'macro-average ROC curve '
462
- '(area = {0:0.2f} )' .format (roc_auc ),
467
+ '(area = {0:.{digits}f} )' .format (roc_auc , digits = digits ),
463
468
color = 'navy' , linestyle = ':' , linewidth = 4 )
464
469
465
470
ax .plot ([0 , 1 ], [0 , 1 ], 'k--' , lw = 2 )
@@ -475,7 +480,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
475
480
476
481
def plot_ks_statistic (y_true , y_probas , title = 'KS Statistic Plot' ,
477
482
ax = None , figsize = None , title_fontsize = "large" ,
478
- text_fontsize = "medium" ):
483
+ text_fontsize = "medium" , digits = 3 ):
479
484
"""Generates the KS Statistic plot from labels and scores/probabilities
480
485
481
486
Args:
@@ -503,6 +508,9 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
503
508
Use e.g. "small", "medium", "large" or integer-values. Defaults to
504
509
"medium".
505
510
511
+ digits (int, optional): Number of digits for formatting output floating point values.
512
+ Use e.g. 2 or 4. Defaults to 3.
513
+
506
514
Returns:
507
515
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
508
516
drawn.
@@ -543,9 +551,10 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
543
551
ax .plot (thresholds , pct2 , lw = 3 , label = 'Class {}' .format (classes [1 ]))
544
552
idx = np .where (thresholds == max_distance_at )[0 ][0 ]
545
553
ax .axvline (max_distance_at , * sorted ([pct1 [idx ], pct2 [idx ]]),
546
- label = 'KS Statistic: {:.3f} at {:.3f}' .format (ks_statistic ,
547
- max_distance_at ),
548
- linestyle = ':' , lw = 3 , color = 'black' )
554
+ label = 'KS Statistic: {:.{digits}f} at {:.{digits}f}' .format (
555
+ ks_statistic , max_distance_at , digits = digits
556
+ ),
557
+ linestyle = ':' , lw = 3 , color = 'black' )
549
558
550
559
ax .set_xlim ([0.0 , 1.0 ])
551
560
ax .set_ylim ([0.0 , 1.0 ])
@@ -685,13 +694,16 @@ def plot_precision_recall_curve(y_true, y_probas,
685
694
return ax
686
695
687
696
688
- def plot_precision_recall (y_true , y_probas ,
689
- title = 'Precision-Recall Curve' ,
690
- plot_micro = True ,
691
- classes_to_plot = None , ax = None ,
692
- figsize = None , cmap = 'nipy_spectral' ,
693
- title_fontsize = "large" ,
694
- text_fontsize = "medium" ):
697
+ def plot_precision_recall (
698
+ y_true , y_probas ,
699
+ title = 'Precision-Recall Curve' ,
700
+ plot_micro = True ,
701
+ classes_to_plot = None , ax = None ,
702
+ figsize = None , cmap = 'nipy_spectral' ,
703
+ title_fontsize = "large" ,
704
+ text_fontsize = "medium" ,
705
+ digits = 3 ,
706
+ ):
695
707
"""Generates the Precision Recall Curve from labels and probabilities
696
708
697
709
Args:
@@ -731,6 +743,9 @@ def plot_precision_recall(y_true, y_probas,
731
743
Use e.g. "small", "medium", "large" or integer-values. Defaults to
732
744
"medium".
733
745
746
+ digits (int, optional): Number of digits for formatting output floating point values.
747
+ Use e.g. 2 or 4. Defaults to 3.
748
+
734
749
Returns:
735
750
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
736
751
drawn.
@@ -778,8 +793,9 @@ def plot_precision_recall(y_true, y_probas,
778
793
color = plt .cm .get_cmap (cmap )(float (i ) / len (classes ))
779
794
ax .plot (recall , precision , lw = 2 ,
780
795
label = 'Precision-recall curve of class {0} '
781
- '(area = {1:0.3f})' .format (classes [i ],
782
- average_precision ),
796
+ '(area = {1:.{digits}f})' .format (classes [i ],
797
+ average_precision ,
798
+ digits = digits ),
783
799
color = color )
784
800
785
801
if plot_micro :
@@ -790,7 +806,7 @@ def plot_precision_recall(y_true, y_probas,
790
806
average = 'micro' )
791
807
ax .plot (recall , precision ,
792
808
label = 'micro-average Precision-recall curve '
793
- '(area = {0:0.3f} )' .format (average_precision ),
809
+ '(area = {0:.{digits}f} )' .format (average_precision , digits = digits ),
794
810
color = 'navy' , linestyle = ':' , linewidth = 4 )
795
811
796
812
ax .set_xlim ([0.0 , 1.0 ])
@@ -802,10 +818,12 @@ def plot_precision_recall(y_true, y_probas,
802
818
return ax
803
819
804
820
805
- def plot_silhouette (X , cluster_labels , title = 'Silhouette Analysis' ,
806
- metric = 'euclidean' , copy = True , ax = None , figsize = None ,
807
- cmap = 'nipy_spectral' , title_fontsize = "large" ,
808
- text_fontsize = "medium" ):
821
+ def plot_silhouette (
822
+ X , cluster_labels , title = 'Silhouette Analysis' ,
823
+ metric = 'euclidean' , copy = True , ax = None , figsize = None ,
824
+ cmap = 'nipy_spectral' , title_fontsize = "large" ,
825
+ text_fontsize = "medium" , digits = 3 ,
826
+ ):
809
827
"""Plots silhouette analysis of clusters provided.
810
828
811
829
Args:
@@ -847,6 +865,9 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
847
865
Use e.g. "small", "medium", "large" or integer-values. Defaults to
848
866
"medium".
849
867
868
+ digits (int, optional): Number of digits for formatting output floating point values.
869
+ Use e.g. 2 or 4. Defaults to 3.
870
+
850
871
Returns:
851
872
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
852
873
drawn.
@@ -908,8 +929,10 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
908
929
909
930
y_lower = y_upper + 10
910
931
911
- ax .axvline (x = silhouette_avg , color = "red" , linestyle = "--" ,
912
- label = 'Silhouette score: {0:0.3f}' .format (silhouette_avg ))
932
+ ax .axvline (
933
+ x = silhouette_avg , color = "red" , linestyle = "--" ,
934
+ label = 'Silhouette score: {0:.{digits}f}' .format (silhouette_avg , digits = 2 )
935
+ )
913
936
914
937
ax .set_yticks ([]) # Clear the y-axis labels / ticks
915
938
ax .set_xticks (np .arange (- 0.1 , 1.0 , 0.2 ))
@@ -920,11 +943,13 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
920
943
return ax
921
944
922
945
923
- def plot_calibration_curve (y_true , probas_list , clf_names = None , n_bins = 10 ,
924
- title = 'Calibration plots (Reliability Curves)' ,
925
- ax = None , figsize = None , cmap = 'nipy_spectral' ,
926
- title_fontsize = "large" , text_fontsize = "medium" ,
927
- pos_label = None , strategy = "uniform" ,):
946
+ def plot_calibration_curve (
947
+ y_true , probas_list , clf_names = None , n_bins = 10 ,
948
+ title = 'Calibration plots (Reliability Curves)' ,
949
+ ax = None , figsize = None , cmap = 'nipy_spectral' ,
950
+ title_fontsize = "large" , text_fontsize = "medium" ,
951
+ pos_label = None , strategy = "uniform" ,
952
+ ):
928
953
"""Plots calibration curves for a set of classifier probability estimates.
929
954
930
955
Plotting the calibration curves of a classifier is useful for determining
@@ -1073,9 +1098,13 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
1073
1098
return ax
1074
1099
1075
1100
1076
- def plot_cumulative_gain (y_true , y_probas , title = 'Cumulative Gains Curve' ,
1077
- ax = None , figsize = None , title_fontsize = "large" ,
1078
- text_fontsize = "medium" , class_names = None ):
1101
+ def plot_cumulative_gain (
1102
+ y_true , y_probas , title = 'Cumulative Gains Curve' ,
1103
+ classes_to_plot = None , plot_micro = True , plot_macro = True ,
1104
+ ax = None , figsize = None , title_fontsize = "large" ,
1105
+ text_fontsize = "medium" , cmap = 'nipy_spectral' ,
1106
+ class_names = None ,
1107
+ ):
1079
1108
"""Generates the Cumulative Gains Plot from labels and scores/probabilities
1080
1109
1081
1110
The cumulative gains chart is used to determine the effectiveness of a
@@ -1093,6 +1122,17 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1093
1122
title (string, optional): Title of the generated plot. Defaults to
1094
1123
"Cumulative Gains Curve".
1095
1124
1125
+ classes_to_plot (list-like, optional): Classes for which the Cumulative Gain
1126
+ curve should be plotted. e.g. [0, 'cold']. If given class does not exist,
1127
+ it will be ignored. If ``None``, all classes will be plotted. Defaults to
1128
+ ``None``
1129
+
1130
+ plot_micro (boolean, optional): Plot the micro average ROC curve.
1131
+ Defaults to ``True``.
1132
+
1133
+ plot_macro (boolean, optional): Plot the macro average ROC curve.
1134
+ Defaults to ``True``.
1135
+
1096
1136
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
1097
1137
plot the learning curve. If None, the plot is drawn on a new set of
1098
1138
axes.
@@ -1107,6 +1147,11 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1107
1147
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
1108
1148
Use e.g. "small", "medium", "large" or integer-values. Defaults to
1109
1149
"medium".
1150
+
1151
+ cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
1152
+ Colormap used for plotting the projection. View Matplotlib Colormap
1153
+ documentation for available options.
1154
+ https://matplotlib.org/users/colormaps.html
1110
1155
1111
1156
class_names (list of strings, optional): List of class names. Used for
1112
1157
the legend. Order should be synchronized with the order of classes
@@ -1129,28 +1174,58 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1129
1174
:align: center
1130
1175
:alt: Cumulative Gains Plot
1131
1176
"""
1177
+ if ax is None :
1178
+ fig , ax = plt .subplots (1 , 1 , figsize = figsize )
1179
+ ax .set_title (title , fontsize = title_fontsize )
1180
+
1132
1181
y_true = np .array (y_true )
1133
1182
y_probas = np .array (y_probas )
1134
-
1135
1183
classes = np .unique (y_true )
1136
- if class_names is None : class_names = classes
1137
- if len (classes ) != 2 :
1184
+
1185
+ if classes_to_plot is None :
1186
+ classes_to_plot = classes
1187
+ if class_names is None : class_names = classes_to_plot
1188
+
1189
+ if len (classes_to_plot ) != 2 :
1138
1190
raise ValueError ('Cannot calculate Cumulative Gains for data with '
1139
1191
'{} category/ies' .format (len (classes )))
1140
1192
1141
- # Compute Cumulative Gain Curves
1142
- percentages , gains1 = cumulative_gain_curve (y_true , y_probas [:, 0 ],
1143
- classes [0 ])
1144
- percentages , gains2 = cumulative_gain_curve (y_true , y_probas [:, 1 ],
1145
- classes [1 ])
1193
+ perc_dict = dict ()
1194
+ gain_dict = dict ()
1146
1195
1147
- if ax is None :
1148
- fig , ax = plt .subplots (1 , 1 , figsize = figsize )
1196
+ indices_to_plot = np .isin (classes , classes_to_plot )
1197
+ # Loop for all classes to get different class gain
1198
+ for i , to_plot in enumerate (indices_to_plot ):
1199
+ perc_dict [i ], gain_dict [i ] = cumulative_gain_curve (y_true , y_probas [:, i ], pos_label = classes [i ])
1149
1200
1150
- ax .set_title (title , fontsize = title_fontsize )
1201
+ if to_plot :
1202
+ color = plt .cm .get_cmap (cmap )(float (i ) / len (classes ))
1203
+ ax .plot (perc_dict [i ], gain_dict [i ], lw = 2 , color = color ,
1204
+ label = 'Class {} Cumulative Gain curve' .format (class_names [i ]))
1151
1205
1152
- ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (class_names [0 ]))
1153
- ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (class_names [1 ]))
1206
+ # Whether or to plot macro or micro
1207
+ if plot_micro :
1208
+ binarized_y_true = label_binarize (y_true , classes = classes )
1209
+ if len (classes ) == 2 :
1210
+ binarized_y_true = np .hstack ((1 - binarized_y_true , binarized_y_true ))
1211
+
1212
+ perc , gain = cumulative_gain_curve (binarized_y_true .ravel (), y_probas .ravel ())
1213
+ ax .plot (perc , gain , label = 'micro-average Cumulative Gain curve' ,
1214
+ color = 'deeppink' , linestyle = ':' , linewidth = 4 )
1215
+
1216
+ if plot_macro :
1217
+ # First aggregate all percentages
1218
+ all_perc = np .unique (np .concatenate ([perc_dict [x ] for x in range (len (classes ))]))
1219
+
1220
+ # Then interpolate all cumulative gain
1221
+ mean_gain = np .zeros_like (all_perc )
1222
+ for i in range (len (classes )):
1223
+ mean_gain += np .interp (all_perc , perc_dict [i ], gain_dict [i ])
1224
+
1225
+ mean_gain /= len (classes )
1226
+
1227
+ ax .plot (all_perc , mean_gain , label = 'macro-average Cumulative Gain curve' ,
1228
+ color = 'navy' , linestyle = ':' , linewidth = 4 )
1154
1229
1155
1230
ax .set_xlim ([0.0 , 1.0 ])
1156
1231
ax .set_ylim ([0.0 , 1.0 ])
@@ -1159,16 +1234,19 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1159
1234
1160
1235
ax .set_xlabel ('Percentage of sample' , fontsize = text_fontsize )
1161
1236
ax .set_ylabel ('Gain' , fontsize = text_fontsize )
1237
+
1162
1238
ax .tick_params (labelsize = text_fontsize )
1163
1239
ax .grid ('on' )
1164
1240
ax .legend (loc = 'lower right' , fontsize = text_fontsize )
1165
1241
1166
1242
return ax
1167
1243
1168
1244
1169
- def plot_lift_curve (y_true , y_probas , title = 'Lift Curve' ,
1170
- ax = None , figsize = None , title_fontsize = "large" ,
1171
- text_fontsize = "medium" , class_names = None ):
1245
+ def plot_lift_curve (
1246
+ y_true , y_probas , title = 'Lift Curve' ,
1247
+ ax = None , figsize = None , title_fontsize = "large" ,
1248
+ text_fontsize = "medium" , class_names = None
1249
+ ):
1172
1250
"""Generates the Lift Curve from labels and scores/probabilities
1173
1251
1174
1252
The lift curve is used to determine the effectiveness of a
0 commit comments