Skip to content

Commit a7f7249

Browse files
committed
run.sh modified
1 parent 807cf79 commit a7f7249

File tree

4 files changed

+105
-13
lines changed

4 files changed

+105
-13
lines changed

code/4-ROC_PR_curve/PlotHIST.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
import scipy.io as sio
1111
from sklearn import *
1212
import matplotlib.pyplot as plt
13+
import os
1314

14-
def Plot_HIST_Fn(label,distance, phase, num_bins = 50):
15+
def Plot_HIST_Fn(label,distance, save_path, num_bins = 50):
1516

1617
dissimilarity = distance[:]
1718
gen_dissimilarity_original = []
@@ -27,6 +28,35 @@ def Plot_HIST_Fn(label,distance, phase, num_bins = 50):
2728
plt.hist(gen_dissimilarity_original, bins, alpha=0.5, facecolor='blue', normed=False, label='gen_dist_original')
2829
plt.hist(imp_dissimilarity_original, bins, alpha=0.5, facecolor='red', normed=False, label='imp_dist_original')
2930
plt.legend(loc='upper right')
30-
plt.title(phase + '_' + 'OriginalFeatures_Histogram.jpg')
31+
plt.title('OriginalFeatures_Histogram.jpg')
3132
plt.show()
32-
fig.savefig(phase + '_' + 'OriginalFeatures_Histogram.jpg')
33+
fig.savefig(save_path)
34+
35+
if __name__ == '__main__':
36+
37+
tf.app.flags.DEFINE_string(
38+
'evaluation_dir', '../../results/SCORES',
39+
'Directory where checkpoints and event logs are written to.')
40+
41+
tf.app.flags.DEFINE_string(
42+
'plot_dir', '../../results/PLOTS',
43+
'Directory where plots are saved to.')
44+
45+
tf.app.flags.DEFINE_integer(
46+
'num_bins', '50',
47+
'Number of bins for plotting histogram.')
48+
49+
# Store all elemnts in FLAG structure!
50+
FLAGS = tf.app.flags.FLAGS
51+
52+
# Loading necessary data.
53+
score = np.load(os.path.join(FLAGS.evaluation_dir,'score_vector.npy'))
54+
label = np.load(os.path.join(FLAGS.evaluation_dir,'target_label_vector.npy'))
55+
save_path = os.path.join(FLAGS.plot_dir,'Histogram.jpg')
56+
57+
# Creating the path
58+
if not os.path.exists(FLAGS.plot_dir):
59+
os.makedirs(FLAGS.plot_dir)
60+
61+
Plot_HIST_Fn(label,score, save_path, FLAGS.num_bins)
62+

code/4-ROC_PR_curve/PlotPR.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
import scipy.io as sio
1111
from sklearn import *
1212
import matplotlib.pyplot as plt
13+
import os
1314

1415
def Plot_PR_Fn(label,distance,phase):
1516

16-
precision, recall, thresholds = metrics.precision_recall_curve(label, -distance, pos_label=1, sample_weight=None)
17-
AP = metrics.average_precision_score(label, -distance, average='macro', sample_weight=None)
17+
precision, recall, thresholds = metrics.precision_recall_curve(label, distance, pos_label=1, sample_weight=None)
18+
AP = metrics.average_precision_score(label, distance, average='macro', sample_weight=None)
1819

1920
# AP(average precision) calculation.
2021
# This score corresponds to the area under the precision-recall curve.
@@ -38,5 +39,28 @@ def Plot_PR_Fn(label,distance,phase):
3839
# plt.text(0.5, 0.5, 'AP = ' + str(AP), fontdict=None)
3940
plt.grid()
4041
plt.show()
41-
fig.savefig(phase + '_' + 'PR.jpg')
42+
fig.savefig(save_path)
4243

44+
if __name__ == '__main__':
45+
46+
tf.app.flags.DEFINE_string(
47+
'evaluation_dir', '../../results/SCORES',
48+
'Directory where checkpoints and event logs are written to.')
49+
50+
tf.app.flags.DEFINE_string(
51+
'plot_dir', '../../results/PLOTS',
52+
'Directory where plots are saved to.')
53+
54+
# Store all elemnts in FLAG structure!
55+
FLAGS = tf.app.flags.FLAGS
56+
57+
# Loading necessary data.
58+
score = np.load(os.path.join(FLAGS.evaluation_dir,'score_vector.npy'))
59+
label = np.load(os.path.join(FLAGS.evaluation_dir,'target_label_vector.npy'))
60+
save_path = os.path.join(FLAGS.plot_dir,'PR.jpg')
61+
62+
# Creating the path
63+
if not os.path.exists(FLAGS.plot_dir):
64+
os.makedirs(FLAGS.plot_dir)
65+
66+
Plot_PR_Fn(label,score,save_path)

code/4-ROC_PR_curve/PlotROC.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
import scipy.io as sio
1111
from sklearn import *
1212
import matplotlib.pyplot as plt
13+
import os
1314

1415

15-
def Plot_ROC_Fn(label,distance,phase):
1616

17-
fpr, tpr, thresholds = metrics.roc_curve(label, -distance, pos_label=1)
18-
AUC = metrics.roc_auc_score(label, -distance, average='macro', sample_weight=None)
17+
def Plot_ROC_Fn(label,distance,save_path):
18+
19+
fpr, tpr, thresholds = metrics.roc_curve(label, distance, pos_label=1)
20+
AUC = metrics.roc_auc_score(label, distance, average='macro', sample_weight=None)
1921
# AP = metrics.average_precision_score(label, -distance, average='macro', sample_weight=None)
2022

2123
# Calculating EER
@@ -37,7 +39,7 @@ def Plot_ROC_Fn(label,distance,phase):
3739
plt.setp(lines, linewidth=2, color='r')
3840
ax.set_xticks(np.arange(0, 1.1, 0.1))
3941
ax.set_yticks(np.arange(0, 1.1, 0.1))
40-
plt.title(phase + '_' + 'ROC.jpg')
42+
plt.title('ROC.jpg')
4143
plt.xlabel('False Positive Rate')
4244
plt.ylabel('True Positive Rate')
4345

@@ -52,7 +54,32 @@ def Plot_ROC_Fn(label,distance,phase):
5254
# plt.text(0.5, 0.4, 'EER = ' + str(EER), fontdict=None)
5355
plt.grid()
5456
plt.show()
55-
fig.savefig(phase + '_' + 'ROC.jpg')
57+
fig.savefig(save_path)
58+
59+
if __name__ == '__main__':
60+
61+
tf.app.flags.DEFINE_string(
62+
'evaluation_dir', '../../results/SCORES',
63+
'Directory where checkpoints and event logs are written to.')
64+
65+
tf.app.flags.DEFINE_string(
66+
'plot_dir', '../../results/PLOTS',
67+
'Directory where plots are saved to.')
68+
69+
# Store all elemnts in FLAG structure!
70+
FLAGS = tf.app.flags.FLAGS
71+
72+
# Loading scores and labels
73+
score = np.load(os.path.join(FLAGS.evaluation_dir,'score_vector.npy'))
74+
label = np.load(os.path.join(FLAGS.evaluation_dir,'target_label_vector.npy'))
75+
save_path = os.path.join(FLAGS.plot_dir,'ROC.jpg')
76+
77+
# Creating the path
78+
if not os.path.exists(FLAGS.plot_dir):
79+
os.makedirs(FLAGS.plot_dir)
80+
81+
Plot_ROC_Fn(label,score,save_path)
82+
5683

5784

5885

run.sh

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,21 @@ if [ $do_training = 'train' ]; then
1919
python -u ./code/2-enrollment/enrollment.py --development_dataset_path=$development_dataset --enrollment_dataset_path=$enrollment_dataset --checkpoint_dir=results/TRAIN_CNN_3D/ --enrollment_dir=results/Model
2020

2121
# evaluation
22-
python -u ./code/3-evaluation/evaluation.py --development_dataset_path=$development_dataset --evaluation_dataset_path=$evaluation_dataset --checkpoint_dir=results/TRAIN_CNN_3D/ --evaluation_dir=results/ROC --enrollment_dir=results/Model
22+
python -u ./code/3-evaluation/evaluation.py --development_dataset_path=$development_dataset --evaluation_dataset_path=$evaluation_dataset --checkpoint_dir=results/TRAIN_CNN_3D/ --evaluation_dir=results/SCORES --enrollment_dir=results/Model
2323

2424
# ROC curve
25-
python -u ./code/4-ROC_PR_curve/calculate_roc.py --evaluation_dir=results/ROC
25+
python -u ./code/4-ROC_PR_curve/calculate_roc.py --evaluation_dir=results/SCORES
26+
27+
# Plot ROC
28+
python -u ./code/4-ROC_PR_curve/PlotROC.py --evaluation_dir=results/SCORES --plot_dir=results/PLOTS
29+
30+
# Plot ROC
31+
python -u ./code/4-ROC_PR_curve/PlotPR.py --evaluation_dir=results/SCORES --plot_dir=results/PLOTS
32+
33+
# Plot HIST
34+
python -u ./code/4-ROC_PR_curve/PlotHIST.py --evaluation_dir=results/SCORES --plot_dir=results/PLOTS --num_bins=5
35+
36+
2637

2738
else
2839

0 commit comments

Comments
 (0)