Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mmrotate/core/evaluation/eval_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,14 @@ def print_map_summary(mean_ap,
num_classes = len(results)

recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
precision = np.zeros((num_scales,num_classes),dtype=np.float32) # add precision
Comment thread
0Freeebaby marked this conversation as resolved.
Outdated
aps = np.zeros((num_scales, num_classes), dtype=np.float32)
num_gts = np.zeros((num_scales, num_classes), dtype=int)
for i, cls_result in enumerate(results):
if cls_result['recall'].size > 0:
recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
if cls_result['precision'].size > 0 : # add precision
Comment thread
0Freeebaby marked this conversation as resolved.
precision[:, i] = np.array(cls_result['precision'], ndmin=2)[:, -1] # add precision
Comment thread
0Freeebaby marked this conversation as resolved.
Outdated
aps[:, i] = cls_result['ap']
num_gts[:, i] = cls_result['num_gts']

Expand All @@ -295,18 +298,18 @@ def print_map_summary(mean_ap,
if not isinstance(mean_ap, list):
mean_ap = [mean_ap]

header = ['class', 'gts', 'dets', 'recall', 'ap']
header = ['class', 'gts', 'dets', 'recall', 'precision', 'ap'] # add precision
Comment thread
0Freeebaby marked this conversation as resolved.
Outdated
for i in range(num_scales):
if scale_ranges is not None:
print_log(f'Scale range {scale_ranges[i]}', logger=logger)
table_data = [header]
for j in range(num_classes):
row_data = [
label_names[j], num_gts[i, j], results[j]['num_dets'],
f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}'
f'{recalls[i, j]:.3f}',f'{precision[i, j]:.3f}', f'{aps[i, j]:.3f}'
]
table_data.append(row_data)
table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}'])
table_data.append(['mAP', '', '', '', '', f'{mean_ap[i]:.3f}'])
table = AsciiTable(table_data)
table.inner_footing_row_border = True
print_log('\n' + table.table, logger=logger)