Skip to content

Commit 21cd5c6

Browse files
make colorbar optional in plot_confusion_matrix() reiinakano#114
1 parent 75621df commit 21cd5c6

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

scikitplot/metrics.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
3232
pred_labels=None, title=None, normalize=False,
3333
hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None,
3434
figsize=None, cmap='Blues', title_fontsize="large",
35-
text_fontsize="medium"):
35+
text_fontsize="medium", show_colorbar=True):
3636
"""Generates confusion matrix plot from predictions and true labels
3737
3838
Args:
@@ -89,6 +89,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
8989
Use e.g. "small", "medium", "large" or integer-values. Defaults to
9090
"medium".
9191
92+
show_colorbar (bool, optional): If False, does not add colour bar.
93+
Defaults to True.
94+
9295
Returns:
9396
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
9497
drawn.
@@ -151,7 +154,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
151154
ax.set_title('Confusion Matrix', fontsize=title_fontsize)
152155

153156
image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.get_cmap(cmap))
154-
plt.colorbar(mappable=image)
157+
158+
if show_colorbar == True:
159+
plt.colorbar(mappable=image)
160+
155161
x_tick_marks = np.arange(len(pred_classes))
156162
y_tick_marks = np.arange(len(true_classes))
157163
ax.set_xticks(x_tick_marks)

0 commit comments

Comments
 (0)