Skip to content

Commit 52e57b4

Browse files
Carole SudreCarole Sudre
authored andcommitted
Add complementarity tests for edge cases in utils'
'
1 parent 332f25d commit 52e57b4

3 files changed

Lines changed: 233 additions & 117 deletions

File tree

MetricsReloaded/utility/utils.py

Lines changed: 112 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -112,26 +112,26 @@ def border_map(self):
112112
border = self.binary_map - eroded
113113
return border
114114

115-
def border_map2(self):
116-
"""
117-
Creates the border for a 3D image
118-
:return:
119-
"""
120-
west = ndimage.shift(self.binary_map, [-1, 0, 0], order=0)
121-
east = ndimage.shift(self.binary_map, [1, 0, 0], order=0)
122-
north = ndimage.shift(self.binary_map, [0, 1, 0], order=0)
123-
south = ndimage.shift(self.binary_map, [0, -1, 0], order=0)
124-
top = ndimage.shift(self.binary_map, [0, 0, 1], order=0)
125-
bottom = ndimage.shift(self.binary_map, [0, 0, -1], order=0)
126-
cumulative = west + east + north + south + top + bottom
127-
border = ((cumulative < 6) * self.binary_map) == 1
128-
return border
115+
# def border_map2(self):
116+
# """
117+
# Creates the border for a 3D image
118+
# :return:
119+
# """
120+
# west = ndimage.shift(self.binary_map, [-1, 0, 0], order=0)
121+
# east = ndimage.shift(self.binary_map, [1, 0, 0], order=0)
122+
# north = ndimage.shift(self.binary_map, [0, 1, 0], order=0)
123+
# south = ndimage.shift(self.binary_map, [0, -1, 0], order=0)
124+
# top = ndimage.shift(self.binary_map, [0, 0, 1], order=0)
125+
# bottom = ndimage.shift(self.binary_map, [0, 0, -1], order=0)
126+
# cumulative = west + east + north + south + top + bottom
127+
# border = ((cumulative < 6) * self.binary_map) == 1
128+
# return border
129129

130130
def foreground_component(self):
131131
"""
132132
Create the connected component map from the binary map stored in self.binary_map
133133
134-
return: label map
134+
return: label map and number of labels
135135
"""
136136
return ndimage.label(self.binary_map)
137137

@@ -417,103 +417,103 @@ def one_hot_encode(img, n_classes):
417417
"""
418418
return np.eye(n_classes)[img]
419419

420-
def to_string_count(measures_count, counting_dict, fmt="{:.4f}"):
421-
"""
422-
Transform to a comma separated string the content of results from the dictionary with all the counting based metrics
423-
424-
:param measures_count: list of counting metrics
425-
:param counting_dict: dictionary with the results of the counting metrics
426-
:param fmt: format in which the outputs should be written (default 4 decimal points)
427-
:return: complete comma-separated string of results in the order of keys specifid by measures_dist
428-
"""
429-
result_str = ""
430-
# list_space = ['com_ref', 'com_pred', 'list_labels']
431-
for key in measures_count:
432-
if len(counting_dict[key]) == 2:
433-
result = counting_dict[key][0]()
434-
else:
435-
result = counting_dict[key][0](counting_dict[key][2])
436-
result_str += (
437-
",".join(fmt.format(x) for x in result)
438-
if isinstance(result, tuple)
439-
else fmt.format(result)
440-
)
441-
result_str += ","
442-
return result_str[:-1] # trim the last comma
443-
444-
445-
def to_string_dist(measures_dist, distance_dict, fmt="{:.4f}"):
446-
"""
447-
Transform to a comma separated string the content of results from the dictionary with all the distance based metrics
448-
449-
:param measures_dist: list of distance metrics
450-
:param distance_dict: dictionary with the results of the distance metrics
451-
:param fmt: format in which the outputs should be written (default 4 decimal points)
452-
:return: complete comma-separated string of results in the order of keys specifid by measures_dist
453-
"""
454-
result_str = ""
455-
# list_space = ['com_ref', 'com_pred', 'list_labels']
456-
for key in measures_dist:
457-
if len(distance_dict[key]) == 2:
458-
result = distance_dict[key][0]()
459-
else:
460-
result = distance_dict[key][0](distance_dict[key][2])
461-
result_str += (
462-
",".join(fmt.format(x) for x in result)
463-
if isinstance(result, tuple)
464-
else fmt.format(result)
465-
)
466-
result_str += ","
467-
return result_str[:-1] # trim the last comma
468-
469-
470-
def to_string_mt(measures_mthresh, multi_thresholds_dict, fmt="{:.4f}"):
471-
"""
472-
Transform to a comma separated string the content of results from the dictionary with all the multi-threshold metric
473-
474-
:param measures_mthresh: list of multi threshold metrics
475-
:param multi_thresholds_dict: dictionary with the results of the multi-threshold metrics
476-
:param fmt: format in which the outputs should be written (default 4 decimal points)
477-
:return: complete comma-separated string of results in the order of keys specifid by measures_mthresh
478-
"""
479-
result_str = ""
480-
# list_space = ['com_ref', 'com_pred', 'list_labels']
481-
for key in measures_mthresh:
482-
if len(multi_thresholds_dict[key]) == 2:
483-
result = multi_thresholds_dict[key][0]()
484-
else:
485-
result = multi_thresholds_dict[key][0](
486-
multi_thresholds_dict[key][2]
487-
)
488-
result_str += (
489-
",".join(fmt.format(x) for x in result)
490-
if isinstance(result, tuple)
491-
else fmt.format(result)
492-
)
493-
result_str += ","
494-
return result_str[:-1] # trim the last comma
420+
# def to_string_count(measures_count, counting_dict, fmt="{:.4f}"):
421+
# """
422+
# Transform to a comma separated string the content of results from the dictionary with all the counting based metrics
423+
424+
# :param measures_count: list of counting metrics
425+
# :param counting_dict: dictionary with the results of the counting metrics
426+
# :param fmt: format in which the outputs should be written (default 4 decimal points)
427+
# :return: complete comma-separated string of results in the order of keys specifid by measures_dist
428+
# """
429+
# result_str = ""
430+
# # list_space = ['com_ref', 'com_pred', 'list_labels']
431+
# for key in measures_count:
432+
# if len(counting_dict[key]) == 2:
433+
# result = counting_dict[key][0]()
434+
# else:
435+
# result = counting_dict[key][0](counting_dict[key][2])
436+
# result_str += (
437+
# ",".join(fmt.format(x) for x in result)
438+
# if isinstance(result, tuple)
439+
# else fmt.format(result)
440+
# )
441+
# result_str += ","
442+
# return result_str[:-1] # trim the last comma
443+
444+
445+
# def to_string_dist(measures_dist, distance_dict, fmt="{:.4f}"):
446+
# """
447+
# Transform to a comma separated string the content of results from the dictionary with all the distance based metrics
448+
449+
# :param measures_dist: list of distance metrics
450+
# :param distance_dict: dictionary with the results of the distance metrics
451+
# :param fmt: format in which the outputs should be written (default 4 decimal points)
452+
# :return: complete comma-separated string of results in the order of keys specifid by measures_dist
453+
# """
454+
# result_str = ""
455+
# # list_space = ['com_ref', 'com_pred', 'list_labels']
456+
# for key in measures_dist:
457+
# if len(distance_dict[key]) == 2:
458+
# result = distance_dict[key][0]()
459+
# else:
460+
# result = distance_dict[key][0](distance_dict[key][2])
461+
# result_str += (
462+
# ",".join(fmt.format(x) for x in result)
463+
# if isinstance(result, tuple)
464+
# else fmt.format(result)
465+
# )
466+
# result_str += ","
467+
# return result_str[:-1] # trim the last comma
468+
469+
470+
# def to_string_mt(measures_mthresh, multi_thresholds_dict, fmt="{:.4f}"):
471+
# """
472+
# Transform to a comma separated string the content of results from the dictionary with all the multi-threshold metric
473+
474+
# :param measures_mthresh: list of multi threshold metrics
475+
# :param multi_thresholds_dict: dictionary with the results of the multi-threshold metrics
476+
# :param fmt: format in which the outputs should be written (default 4 decimal points)
477+
# :return: complete comma-separated string of results in the order of keys specifid by measures_mthresh
478+
# """
479+
# result_str = ""
480+
# # list_space = ['com_ref', 'com_pred', 'list_labels']
481+
# for key in measures_mthresh:
482+
# if len(multi_thresholds_dict[key]) == 2:
483+
# result = multi_thresholds_dict[key][0]()
484+
# else:
485+
# result = multi_thresholds_dict[key][0](
486+
# multi_thresholds_dict[key][2]
487+
# )
488+
# result_str += (
489+
# ",".join(fmt.format(x) for x in result)
490+
# if isinstance(result, tuple)
491+
# else fmt.format(result)
492+
# )
493+
# result_str += ","
494+
# return result_str[:-1] # trim the last comma
495495

496496

497-
def to_dict_meas_(measures, measures_dict, fmt="{:.4f}"):
498-
"""
499-
Given the selected metrics provides a dictionary
500-
with relevant metrics
497+
# def to_dict_meas_(measures, measures_dict, fmt="{:.4f}"):
498+
# """
499+
# Given the selected metrics provides a dictionary
500+
# with relevant metrics
501501

502-
:param measures: list of measures
503-
:param measures_dict: dictionary of result for metrics
504-
:param fmt: format to use (default 4 decimal places)
502+
# :param measures: list of measures
503+
# :param measures_dict: dictionary of result for metrics
504+
# :param fmt: format to use (default 4 decimal places)
505505

506-
:return: result_dict
507-
"""
508-
result_dict = {}
509-
# list_space = ['com_ref', 'com_pred', 'list_labels']
510-
for key in measures:
511-
if len(measures_dict[key]) == 2:
512-
result = measures_dict[key][0]()
513-
else:
514-
result = measures_dict[key][0](measures_dict[key][2])
515-
result_dict[key] = fmt.format(result)
516-
return result_dict # trim the last comma
506+
# :return: result_dict
507+
# """
508+
# result_dict = {}
509+
# # list_space = ['com_ref', 'com_pred', 'list_labels']
510+
# for key in measures:
511+
# if len(measures_dict[key]) == 2:
512+
# result = measures_dict[key][0]()
513+
# else:
514+
# result = measures_dict[key][0](measures_dict[key][2])
515+
# result_dict[key] = fmt.format(result)
516+
# return result_dict # trim the last comma
517517

518518
def combine_df(df1,df2):
519519
"""
@@ -554,24 +554,22 @@ def merge_list_df(list_df, on=['label','case']):
554554
for f in on:
555555
if f not in k.columns:
556556
flag_on = False
557+
print(f, ' not present')
558+
break
557559
if flag_on:
558560
list_fin.append(k)
559561
if len(list_fin) == 0:
560562
return None
561563
elif len(list_fin) == 1:
562564
return list_fin[0]
563565
else:
564-
print("list fin is ",list_fin)
566+
#print("list fin is ",list_fin)
565567
df_fin = list_fin[0]
568+
print(len(list_fin))
566569
for k in list_fin[1:]:
567570
df_fin = pd.merge(df_fin, k, on=on)
568571
return df_fin
569572

570-
571-
572-
573-
574-
575573
def trapezoidal_integration(x, fx):
576574
"""Trapezoidal integration
577575

test/test_processes/test_overall_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
data_miss['ref_class'] = [ref1, ref2]
4444
data_miss['list_values'] = [1]
4545
data_miss['pred_prob'] = [None,None]
46-
data_miss['ref_missing'] = [ref3]
46+
data_miss['ref_missing_pred'] = [ref3]
4747

4848
data_agg = {}
4949
data_agg['pred_class'] = [pred12]
@@ -73,7 +73,7 @@ def test_op_aggregation():
7373
def test_op_refmissing():
7474
pe = PE(data_miss,'SemS',measures_overlap=['fbeta'],measures_boundary=['boundary_iou'])
7575
print(pe.grouped_lab, pe.resseg)
76-
assert_allclose(pe.grouped_lab.shape,[3,8])
76+
assert_allclose(pe.grouped_lab.shape,[2,8]) # to modify
7777

7878

7979

0 commit comments

Comments
 (0)