-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathprecision_recall.py
More file actions
executable file
·97 lines (65 loc) · 2.64 KB
/
precision_recall.py
File metadata and controls
executable file
·97 lines (65 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python
import matplotlib
matplotlib.use('Agg')
import pandas, json, pdb, numpy as np, argparse, sys, os, cv2, glob
import matplotlib.pyplot as plt
def get_best(annos, box, min_overlap=0.0):
    """Greedily match a detection box to the best still-unmatched GT rectangle.

    Coordinates are treated as inclusive pixel indices (hence the ``+ 1`` in
    every width/height).

    Args:
        annos: Annotation record with a ``'rects'`` list; each rect is a dict
            with ``'x1'``, ``'y1'``, ``'x2'``, ``'y2'`` keys and may carry a
            ``'taken'`` flag set by a previous match.
        box: Detection as a sequence whose first four items are
            ``(x1, y1, x2, y2)``.
        min_overlap: Minimum intersection-over-union a rect must reach to be
            accepted. The default of 0.0 accepts any positive overlap, which
            was the historical behaviour.

    Returns:
        The matched rect dict, or None if no untaken rect overlaps enough.
        A matched rect is marked ``'taken' = True`` in place so that it
        cannot be matched by another detection.
    """
    bx1, by1, bx2, by2 = box[:4]
    box_area = (bx2 - bx1 + 1) * (by2 - by1 + 1)

    best_ov = -float('inf')
    best_rect = None
    for r in annos['rects']:
        # Each GT rect may justify at most one detection.
        if r.get('taken'):
            continue
        iw = min(r['x2'], bx2) - max(r['x1'], bx1) + 1
        ih = min(r['y2'], by2) - max(r['y1'], by1) + 1
        if iw <= 0 or ih <= 0:
            continue  # no intersection
        inter = iw * ih
        union = (r['x2'] - r['x1'] + 1) * (r['y2'] - r['y1'] + 1) + box_area - inter
        # float() keeps this a true ratio even under Python 2 integer division.
        ov = float(inter) / union
        if ov >= min_overlap and ov > best_ov:
            best_ov = ov
            best_rect = r

    if best_rect is not None:
        best_rect['taken'] = True
    return best_rect
def precision_recall(gt, predictions):
    """Compute cumulative recall and precision curves for a set of detections.

    Detections are processed in descending ``score`` order. Each one is
    greedily matched to a ground-truth rect via ``get_best``: a match counts
    as a true positive, anything else as a false positive.

    Args:
        gt: List of annotation records, each with ``'image_path'`` and a
            ``'rects'`` list. The caller's object is not modified.
        predictions: DataFrame with ``image_id``, ``score``, ``x1``, ``y1``,
            ``x2``, ``y2`` columns. Every ``image_id`` must appear as an
            ``image_path`` in ``gt``; otherwise a KeyError is raised.

    Returns:
        ``(rec, prec)`` numpy arrays of length ``len(predictions)``. Entry k
        holds recall and precision after the top k+1 detections. If ``gt``
        contains no rects, recall is undefined (numpy yields inf/nan).
    """
    import copy

    # get_best marks matched GT rects as 'taken' in place. Work on a private
    # copy so the caller can reuse the same annotations for another model.
    gt = copy.deepcopy(gt)
    npos = sum(len(x['rects']) for x in gt)
    gt_map = {x['image_path']: x for x in gt}

    # Higher-confidence detections claim GT rects first. The stable sort keeps
    # the input order for tied scores, so the curve is reproducible.
    predictions = predictions.sort_values(by='score', ascending=False, kind='mergesort')

    tp = np.zeros(len(predictions))
    for i, row in enumerate(predictions.itertuples(index=False)):
        best = get_best(gt_map[row.image_id], (row.x1, row.y1, row.x2, row.y2))
        if best is not None:
            tp[i] = 1
    fp = 1 - tp  # every detection is exactly one of TP / FP

    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / npos
    prec = tp / (fp + tp)  # fp + tp == k + 1 > 0 at index k
    return rec, prec
if __name__ == '__main__':
    # Default annotations live in ../data relative to this script.
    dflt_val = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/val_data.json')
    parser = argparse.ArgumentParser()
    parser.add_argument('--predictions', nargs='+', required=True, help='Detection csv files')
    parser.add_argument('--names', nargs='+', help='Titles')
    parser.add_argument('--val_data', default=dflt_val, help='Validation data')
    args = parser.parse_args()

    # Fall back to file basenames when titles are omitted or don't line up 1:1.
    if args.names is None or len(args.predictions) != len(args.names):
        args.names = [os.path.basename(p) for p in args.predictions]

    fig, ax = plt.subplots(nrows=1, ncols=1)
    # Configure the colour cycle on the axes we actually draw into. Calling
    # plt.gca() before plt.subplots() would set it on a different, leaked figure.
    ax.set_prop_cycle('color', ['red', 'green', 'blue', 'yellow'])

    for predictions, name in zip(args.predictions, args.names):
        # Reload ground truth for each model: matching may flag GT rects as
        # 'taken' in place, so results must not leak between models.
        with open(args.val_data) as f:
            gt = json.load(f)
        df = pandas.read_csv(predictions)
        rec, prec = precision_recall(gt, df)
        ax.plot(rec, prec, label=name)

        if len(rec) == 0:
            print('%s: no predictions' % name)
            continue
        # Report the operating point with the highest F1. Points where
        # precision and recall are both zero get F1 = 0 instead of nan.
        denom = prec + rec
        f1 = np.divide(2 * prec * rec, denom, out=np.zeros_like(denom), where=denom > 0)
        i = np.argmax(f1)
        print('%s: precision = %f, recall = %f, F1 = %f' % (name, prec[i], rec[i], f1[i]))

    ax.legend(loc='lower left')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    fig.savefig('./precision_recall.png')
    plt.close(fig)