@@ -32,7 +32,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
32
32
pred_labels = None , title = None , normalize = False ,
33
33
hide_zeros = False , hide_counts = False , x_tick_rotation = 0 , ax = None ,
34
34
figsize = None , cmap = 'Blues' , title_fontsize = "large" ,
35
- text_fontsize = "medium" ):
35
+ text_fontsize = "medium" , show_colorbar = True ):
36
36
"""Generates confusion matrix plot from predictions and true labels
37
37
38
38
Args:
@@ -89,6 +89,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
89
89
Use e.g. "small", "medium", "large" or integer-values. Defaults to
90
90
"medium".
91
91
92
+ show_colorbar (bool, optional): If False, does not add colour bar.
93
+ Defaults to True.
94
+
92
95
Returns:
93
96
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
94
97
drawn.
@@ -151,7 +154,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
151
154
ax .set_title ('Confusion Matrix' , fontsize = title_fontsize )
152
155
153
156
image = ax .imshow (cm , interpolation = 'nearest' , cmap = plt .cm .get_cmap (cmap ))
154
- plt .colorbar (mappable = image )
157
+
158
+ if show_colorbar == True :
159
+ plt .colorbar (mappable = image )
160
+
155
161
x_tick_marks = np .arange (len (pred_classes ))
156
162
y_tick_marks = np .arange (len (true_classes ))
157
163
ax .set_xticks (x_tick_marks )
0 commit comments