
Commit 488b9f2

Merge branch 'evaluation' into 'dev'

Evaluation for comparison of regression results between images for synthesis. See merge request !44.

2 parents: dcd8cf7 + da75ff7

File tree

4 files changed: 134 additions and 10 deletions

niftynet/evaluation/compare_regressions.py (new file)

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
from __future__ import absolute_import, print_function

import os.path

import nibabel as nib
import numpy as np

import niftynet.utilities.csv_table as csv_table
from niftynet.evaluation.pairwise_measures import PairwiseMeasuresRegression

MEASURES = (
    'mse', 'rmse', 'mae', 'r2'
)
# MEASURES_NEW = ('ref volume', 'reg volume', 'tp', 'fp', 'fn', 'outline_error',
#                 'detection_error', 'dice')
OUTPUT_FORMAT = '{:4f}'
OUTPUT_FILE_PREFIX = 'PairwiseMeasureReg'


def run(param, csv_dict):
    # output: one CSV named after the reference and regression directories
    out_name = '{}_{}_{}.csv'.format(
        OUTPUT_FILE_PREFIX,
        os.path.split(param.ref_dir)[1],
        os.path.split(param.seg_dir)[1])
    print("Writing {} to {}".format(out_name, param.save_csv_dir))

    # inputs: (regression, reference) file names taken from the csv table
    csv_loader = csv_table.CSVTable(csv_dict=csv_dict, allow_missing=False)
    reg_names = [csv_loader._csv_table[m][1][0][0] for m in range(
        0, len(csv_loader._csv_table))]
    ref_names = [csv_loader._csv_table[m][2][0][0] for m in range(
        0, len(csv_loader._csv_table))]
    # reg_names = util.list_files(param.reg_dir, param.ext)
    # ref_names = util.list_files(param.ref_dir, param.ext)
    pair_list = list(zip(reg_names, ref_names))
    # TODO check reg_names ref_names matching
    # TODO do we evaluate all combinations?
    # import itertools
    # pair_list = list(itertools.product(reg_names, ref_names))
    print("List of references is {}".format(ref_names))
    print("List of regressions is {}".format(reg_names))

    # prepare a header for the csv
    with open(os.path.join(param.save_csv_dir, out_name), 'w+') as out_stream:
        # a trivial PairwiseMeasuresRegression obj to produce the header string
        m_headers = PairwiseMeasuresRegression(0, 0, measures=MEASURES).header_str()
        out_stream.write("Name (ref), Name (reg)" + m_headers + '\n')

        # do the pairwise evaluations
        for i, pair_ in enumerate(pair_list):
            reg_name = pair_[0]
            ref_name = pair_[1]
            print('>>> {} of {} evaluations, comparing {} and {}.'.format(
                i + 1, len(pair_list), ref_name, reg_name))
            reg_nii = nib.load(os.path.join(param.seg_dir, reg_name))
            ref_nii = nib.load(os.path.join(param.ref_dir, ref_name))
            voxel_sizes = reg_nii.header.get_zooms()[0:3]  # read but not used by the measures
            reg = np.squeeze(reg_nii.get_data())
            ref = np.squeeze(ref_nii.get_data())
            # intensities are expected to be non-negative and shapes must match
            assert np.all(reg >= 0)
            assert np.all(ref >= 0)
            assert reg.shape == ref.shape
            PE = PairwiseMeasuresRegression(reg, ref,
                                            measures=MEASURES)
            fixed_fields = "{}, {}, ".format(ref_name, reg_name)
            out_stream.write(fixed_fields + PE.to_string(
                OUTPUT_FORMAT) + '\n')
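For reference, run() writes "Name (ref), Name (reg)" followed by PairwiseMeasuresRegression.header_str() as the header, then one row per image pair formatted with OUTPUT_FORMAT. A sketch of the resulting CSV; the file names and metric values are hypothetical and only illustrate the '{:4f}' formatting:

Name (ref), Name (reg),MSE,RMSE,MAE,R2
subject01_ref.nii.gz, subject01_reg.nii.gz, 0.012500,0.111803,0.090000,0.981800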

niftynet/evaluation/pairwise_measures.py

Lines changed: 50 additions & 0 deletions
@@ -264,3 +264,53 @@ def to_string(self, fmt='{:.4f}'):
                if isinstance(result, tuple) else fmt.format(result)
            result_str += ','
        return result_str[:-1]  # trim the last comma


class PairwiseMeasuresRegression(object):
    # Pairwise error measures between a regression output and a reference image.

    def __init__(self, reg_img, ref_img, measures=None):

        self.reg = reg_img
        self.ref = ref_img
        self.measures = measures

        self.m_dict = {
            'mse': (self.mse, 'MSE'),
            'rmse': (self.rmse, 'RMSE'),
            'mae': (self.mae, 'MAE'),
            'r2': (self.r2, 'R2')
        }

    def mse(self):
        # mean squared error between regression and reference
        return np.mean(np.square(self.reg - self.ref))

    def rmse(self):
        # root mean squared error
        return np.sqrt(self.mse())

    def mae(self):
        # mean absolute error
        return np.mean(np.abs(self.ref - self.reg))

    def r2(self):
        # squared Pearson correlation; the small constant guards against
        # division by zero when either image has (near) zero variance
        ref_var = np.sum(np.square(self.ref - np.mean(self.ref)))
        reg_var = np.sum(np.square(self.reg - np.mean(self.reg)))
        cov_refreg = np.sum((self.reg - np.mean(self.reg)) * (self.ref - np.mean(
            self.ref)))
        return np.square(cov_refreg / np.sqrt(ref_var * reg_var + 0.00001))

    def header_str(self):
        # e.g. ',MSE,RMSE,MAE,R2' for the measures selected above
        result_str = [self.m_dict[key][1] for key in self.measures]
        result_str = ',' + ','.join(result_str)
        return result_str

    def to_string(self, fmt='{:.4f}'):
        result_str = ""
        for key in self.measures:
            result = self.m_dict[key][0]()
            result_str += ','.join(fmt.format(x) for x in result) \
                if isinstance(result, tuple) else fmt.format(result)
            result_str += ','
        return result_str[:-1]  # trim the last comma
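A quick sanity check of the new measures on toy arrays; this is a sketch with made-up values, not part of the commit, and assumes NiftyNet is importable:

import numpy as np
from niftynet.evaluation.pairwise_measures import PairwiseMeasuresRegression

reg = np.array([1.0, 2.0, 3.0, 4.0])   # hypothetical regression output
ref = np.array([1.1, 1.9, 3.2, 3.8])   # hypothetical reference values

pe = PairwiseMeasuresRegression(reg, ref, measures=('mse', 'rmse', 'mae', 'r2'))
print(pe.header_str())           # ,MSE,RMSE,MAE,R2
print(pe.to_string('{:.4f}'))    # approximately 0.0250,0.1581,0.1500,0.9818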

niftynet/utilities/misc_csv.py

Lines changed: 9 additions & 8 deletions
@@ -112,14 +112,15 @@ def __find_max_overlap_in_list(name, list_names):
     if len(list_names) == 0:
         return '', -1
     for test in list_names:
-        match = SequenceMatcher(None, name, test).find_longest_match(
-            0, len(name), 0, len(test))
-        if match.size >= match_max and match.size/len(test) >= \
-                match_ratio:
-            match_max = match.size
-            match_seq = test[match.b:(match.b + match.size)]
-            match_ratio = match.size/len(test)
-            match_orig = test
+        if len(test) > 0:
+            match = SequenceMatcher(None, name, test).find_longest_match(
+                0, len(name), 0, len(test))
+            if match.size >= match_max and match.size/len(test) >= \
+                    match_ratio:
+                match_max = match.size
+                match_seq = test[match.b:(match.b + match.size)]
+                match_ratio = match.size/len(test)
+                match_orig = test
     if match_max == 0:
         return '', -1
     other_list = [name for name in list_names if match_seq in name and
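The new len(test) > 0 guard matters because match.size/len(test) divides by the candidate name's length, so an empty string in list_names would previously raise ZeroDivisionError. A minimal illustration with a hypothetical name:

from difflib import SequenceMatcher

name, test = 'subject01_T1.nii.gz', ''
match = SequenceMatcher(None, name, test).find_longest_match(0, len(name), 0, len(test))
# match.size == 0 here, and match.size / len(test) would raise ZeroDivisionError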

run_evaluation.py

Lines changed: 6 additions & 2 deletions
@@ -13,6 +13,10 @@
 
     niftynet.evaluation.compute_ROI_statistics.run(args, csv_dict)
 elif args.action.lower() == 'compare':
-    import niftynet.evaluation.compare_segmentations
+    if args.application_type == 'segmentation':
+        import niftynet.evaluation.compare_segmentations
 
-    niftynet.evaluation.compare_segmentations.run(args, csv_dict)
+        niftynet.evaluation.compare_segmentations.run(args, csv_dict)
+    elif args.application_type == 'regression':
+        import niftynet.evaluation.compare_regressions
+        niftynet.evaluation.compare_regressions.run(args, csv_dict)
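Assuming the argument parser (not shown in this diff) exposes args.action and args.application_type as command-line flags of the same names, and that the directory options mirror the param attributes read by compare_regressions.run(), the new branch could be exercised along these lines; a hypothetical invocation, not taken from the commit:

python run_evaluation.py --action compare --application_type regression \
    --ref_dir <reference_dir> --seg_dir <regression_output_dir> --save_csv_dir <output_dir>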
