-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathevaluation.py
55 lines (41 loc) · 2.01 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import itertools
import os
import sys
from contextlib import redirect_stdout

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

from utils.classification_report import classification_report
from utils.plot_confusion_matrix_util import plot_confusion_matrix
def evaluate(model, test, test_input, labels_vocab, save_path, name):
    """Evaluate a model on a test set and save text metrics plus a confusion-matrix plot.

    Prints the test loss and accuracy returned by ``model.evaluate``. Writes
    macro/micro precision/recall/F1 scores and macro/micro classification
    reports to ``<save_path>/results.txt``. Saves a normalized confusion-matrix
    plot to ``<save_path>/images/confusion_matrix.png``.

    Args:
        model: Object exposing ``evaluate(x, y)`` (whose result is indexed as
            ``[loss, accuracy, ...]``) and ``predict(x)``. Presumably a Keras
            model; confirm against caller.
        test: Test dataset whose ``y`` attribute holds one-hot labels. The
            argmax over the last axis followed by a single level of flattening
            implies a 3-D layout, presumably (samples, timesteps, classes).
            TODO confirm.
        test_input: Model input(s) for the test set.
        labels_vocab: Vocabulary whose ``stoi`` dict maps label name -> class index.
        save_path: Output directory; a trailing path separator is optional.
        name: Suffix for the confusion-matrix plot title.
    """
    test_eval = model.evaluate(test_input, np.array(test.y))
    print('Test loss:', test_eval[0])
    print('Test accuracy:', test_eval[1])

    # Reduce one-hot / probability vectors to class indices, then flatten the
    # per-sample sequences into a single label list so every position is
    # scored as an individual sample by the metric functions.
    predicted_values = list(itertools.chain.from_iterable(
        np.argmax(model.predict(test_input), axis=-1)))
    true_values = list(itertools.chain.from_iterable(np.argmax(test.y, -1)))

    # keys() and values() of the same unmodified dict iterate in matching
    # order, so keys[i] is the name of class index values[i].
    keys = list(labels_vocab.stoi.keys())
    values = list(labels_vocab.stoi.values())

    # The `with` statement closes the file and restores sys.stdout even if a
    # metric call raises. The previous manual swap leaked the handle and left
    # stdout redirected on error.
    results_path = os.path.join(save_path, 'results.txt')
    with open(results_path, 'w', encoding='utf-8') as f, redirect_stdout(f):
        print("Macro Precision/Recall/F1 score:")
        print(precision_recall_fscore_support(true_values, predicted_values, average='macro'))
        print(60 * "-")
        print("Micro Precision/Recall/F1 score:")
        print(precision_recall_fscore_support(true_values, predicted_values, average='micro'))
        print(60 * "-")
        # Classification reports
        macro_report = classification_report(true_values, predicted_values, labels=values,
                                             target_names=keys, digits=4, average='macro')
        print(macro_report)
        print(60 * "-")
        micro_report = classification_report(true_values, predicted_values, labels=values,
                                             target_names=keys, digits=4, average='micro')
        print(micro_report)

    # Confusion matrix. labels=values fixes the row/column order to match the
    # class names passed to the plot. Without it, sklearn uses only the sorted
    # labels that occur in the data, which can shorten or misalign the axes.
    # NOTE(review): a class that never occurs in true_values now yields an
    # all-zero row; confirm that plot_confusion_matrix's normalize=True copes
    # with that (it may produce NaN).
    cnf_matrix = confusion_matrix(true_values, predicted_values, labels=values)
    # Changes numpy's global print precision. Kept from the original,
    # presumably for any matrix printing done by the plot helper. TODO confirm it's needed.
    np.set_printoptions(precision=2)
    plot_confusion_matrix(cnf_matrix, classes=keys, normalize=True,
                          title='Confusion matrix - ' + name)
    images_dir = os.path.join(save_path, 'images')
    os.makedirs(images_dir, exist_ok=True)  # savefig fails if the directory is missing
    plt.savefig(os.path.join(images_dir, 'confusion_matrix.png'),
                dpi=200, format='png', bbox_inches='tight')
    plt.close()