@@ -1072,7 +1072,7 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
1072
1072
1073
1073
def plot_cumulative_gain (y_true , y_probas , title = 'Cumulative Gains Curve' ,
1074
1074
ax = None , figsize = None , title_fontsize = "large" ,
1075
- text_fontsize = "medium" ):
1075
+ text_fontsize = "medium" , class_names = None ):
1076
1076
"""Generates the Cumulative Gains Plot from labels and scores/probabilities
1077
1077
1078
1078
The cumulative gains chart is used to determine the effectiveness of a
@@ -1104,6 +1104,10 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1104
1104
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
1105
1105
Use e.g. "small", "medium", "large" or integer-values. Defaults to
1106
1106
"medium".
1107
+
1108
+ class_names (list of strings, optional): List of class names. Used for
1109
+ the legend. Order should be synchronized with the order of classes
1110
+ in y_probas.
1107
1111
1108
1112
Returns:
1109
1113
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
@@ -1126,6 +1130,7 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1126
1130
y_probas = np .array (y_probas )
1127
1131
1128
1132
classes = np .unique (y_true )
1133
+ if class_names is None : class_names = classes
1129
1134
if len (classes ) != 2 :
1130
1135
raise ValueError ('Cannot calculate Cumulative Gains for data with '
1131
1136
'{} category/ies' .format (len (classes )))
@@ -1141,8 +1146,8 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1141
1146
1142
1147
ax .set_title (title , fontsize = title_fontsize )
1143
1148
1144
- ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (classes [0 ]))
1145
- ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (classes [1 ]))
1149
+ ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (class_names [0 ]))
1150
+ ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (class_names [1 ]))
1146
1151
1147
1152
ax .set_xlim ([0.0 , 1.0 ])
1148
1153
ax .set_ylim ([0.0 , 1.0 ])
@@ -1160,7 +1165,7 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
1160
1165
1161
1166
def plot_lift_curve (y_true , y_probas , title = 'Lift Curve' ,
1162
1167
ax = None , figsize = None , title_fontsize = "large" ,
1163
- text_fontsize = "medium" ):
1168
+ text_fontsize = "medium" , class_names = None ):
1164
1169
"""Generates the Lift Curve from labels and scores/probabilities
1165
1170
1166
1171
The lift curve is used to determine the effectiveness of a
@@ -1192,6 +1197,10 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
1192
1197
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
1193
1198
Use e.g. "small", "medium", "large" or integer-values. Defaults to
1194
1199
"medium".
1200
+
1201
+ class_names (list of strings, optional): List of class names. Used for
1202
+ the legend. Order should be synchronized with the order of classes
1203
+ in y_probas.
1195
1204
1196
1205
Returns:
1197
1206
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
@@ -1214,6 +1223,7 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
1214
1223
y_probas = np .array (y_probas )
1215
1224
1216
1225
classes = np .unique (y_true )
1226
+ if class_names is None : class_names = classes
1217
1227
if len (classes ) != 2 :
1218
1228
raise ValueError ('Cannot calculate Lift Curve for data with '
1219
1229
'{} category/ies' .format (len (classes )))
@@ -1236,8 +1246,8 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
1236
1246
1237
1247
ax .set_title (title , fontsize = title_fontsize )
1238
1248
1239
- ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (classes [0 ]))
1240
- ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (classes [1 ]))
1249
+ ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (class_names [0 ]))
1250
+ ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (class_names [1 ]))
1241
1251
1242
1252
ax .plot ([0 , 1 ], [1 , 1 ], 'k--' , lw = 2 , label = 'Baseline' )
1243
1253
0 commit comments