-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path conf_mat.py
More file actions
34 lines (28 loc) · 943 Bytes
/
Copy path conf_mat.py
File metadata and controls
34 lines (28 loc) · 943 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Global matplotlib styling, applied at import time to every figure drawn
# below: large bold Lato text, thin axes frame, and Computer Modern ("cm")
# mathtext with an upright (roman) default math font.
plt.rcParams.update({
    'font.size': 22,
    'axes.linewidth': 1.0,
    'font.family': 'lato',
    'font.weight': 'bold',
    'mathtext.default': 'rm',
    'mathtext.fontset': 'cm',
})
def confusion_matrix(y_true, y_pred, n_class):
    """Return the row-normalized confusion matrix of integer class labels.

    Rows index the true class and columns the predicted class. Each row is
    divided by the number of samples whose true label is that class, so
    entry ``[i, j]`` is the fraction of class-``i`` samples predicted as
    class ``j`` and every non-empty row sums to 1. For a balanced test set
    this equals dividing by the per-class sample count.

    Args:
        y_true: 1-D array-like of true labels in ``range(n_class)``.
            Float labels are truncated to int, as ``int()`` would do.
        y_pred: 1-D array-like of predicted labels in ``range(n_class)``,
            with the same length as ``y_true``.
        n_class: Number of classes.

    Returns:
        A ``(n_class, n_class)`` float array. Rows for classes that never
        occur in ``y_true`` are all zeros rather than NaN.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length, or any
            label lies outside ``range(n_class)``.
    """
    true_idx = np.asarray(y_true).astype(int).ravel()
    pred_idx = np.asarray(y_pred).astype(int).ravel()
    if true_idx.shape != pred_idx.shape:
        raise ValueError(
            "y_true and y_pred must have the same length, "
            f"got {true_idx.size} and {pred_idx.size}"
        )
    labels = np.concatenate([true_idx, pred_idx])
    # Negative labels would otherwise wrap around silently when indexing.
    if labels.size and (labels.min() < 0 or labels.max() >= n_class):
        raise ValueError(
            f"labels must lie in [0, {n_class}), "
            f"got range [{labels.min()}, {labels.max()}]"
        )

    counts = np.zeros((n_class, n_class))
    # Unbuffered add: repeated (true, pred) pairs each contribute one count.
    np.add.at(counts, (true_idx, pred_idx), 1)

    row_totals = counts.sum(axis=1, keepdims=True)
    # Normalize by the real per-class sample count; avoid 0/0 on empty rows.
    return np.divide(counts, row_totals,
                     out=np.zeros_like(counts), where=row_totals > 0)


def main():
    """Load saved labels, print the normalized confusion matrix, and plot it."""
    n_class = 6
    y_true = np.load('data/y_true.npy')
    y_pred = np.load('data/y_pred.npy')

    conf_mat = confusion_matrix(y_true, y_pred, n_class)
    print(conf_mat)
    # np.savetxt('results/resnet/hrr/exp4.csv', conf_mat, delimiter=',')

    plt.figure(figsize=[10, 8])
    plot = sns.heatmap(conf_mat, vmin=0, vmax=1, annot=True, fmt=".2g",
                       linewidths=2, cmap='OrRd')
    # Show 1-based class numbers instead of the default 0-based indices.
    tick_labels = list(range(1, n_class + 1))
    plot.set_xticklabels(tick_labels)
    plot.set_yticklabels(tick_labels)
    plot.set_xlabel('Predicted class')
    plot.set_ylabel('True class')
    plt.show()


if __name__ == "__main__":
    main()