Skip to content

Commit 2160489

Browse files
Carole SudreCarole Sudre
authored andcommitted
Improving testing of assignment localisation
1 parent f5dc43a commit 2160489

2 files changed

Lines changed: 104 additions & 10 deletions

File tree

MetricsReloaded/utility/assignment_localization.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def __init__(
110110
self.flag_fp_in = flag_fp_in
111111
self.pixdim = pixdim
112112
all_input = []
113+
self.ref_loc_mod = None
114+
self.pred_loc_mod = None
113115
if len(self.pixdim) == 0:
114116
if len(pred_loc) > 0:
115117
if pred_loc[0].size > 0:
@@ -136,11 +138,14 @@ def __init__(
136138

137139
flag_usable, flag_predmod, flag_refmod = self.check_input_localization()
138140
# self.pred_class = pred_class
139-
141+
print('Flag', flag_usable, flag_predmod, flag_refmod)
140142
# self.ref_class = ref_class
141143
self.flag_usable = flag_usable
142144
self.flag_predmod = flag_predmod
143145
self.flag_refmod = flag_refmod
146+
147+
148+
144149
if self.flag_usable:
145150
if localization == "box_iou":
146151
self.matrix = self.pairwise_boxiou()
@@ -163,12 +168,23 @@ def __init__(
163168
elif localization == "com_dist":
164169
self.matrix = self.pairwise_pointcomdist()
165170
else:
171+
print(' not valid localisation ')
166172
self.flag_usable = False
173+
self.df_matching = None
174+
self.valid = None
175+
warnings.warn("No adequate localization strategy chosen - not going ahead")
176+
else:
177+
print(' not valid localisation ')
178+
self.flag_usable = False
179+
self.df_matching = None
180+
self.valid = None
167181
warnings.warn("No adequate localization strategy chosen - not going ahead")
168182

169183
if self.localization in ['point_in_mask','point_in_box']:
170184
if self.assignment == 'greedy_matching':
171185
self.flag_usable = False
186+
self.df_matching = None
187+
self.valid = None
172188
warnings.warn("The localization strategy does not provide grading. Impossible to base assignment on localization performance!")
173189
if self.flag_usable:
174190
self.df_matching, self.valid = self.resolve_ambiguities_matching()
@@ -240,6 +256,7 @@ def check_input_localization(self):
240256
return flag_usable, flag_predmod, flag_refmod
241257
if input_ref == 'mask':
242258
flag_refmod = True
259+
self.box_fromrefmask()
243260
warnings.warn('We will need to modify ref to make it interpretable as box corners')
244261
elif self.localization == 'com_dist':
245262
if input_ref == 'mask':
@@ -359,8 +376,8 @@ def pairwise_pointinbox(self):
359376
pred_points = self.pred_loc
360377
if self.flag_refmod:
361378
ref_boxes = self.ref_loc_mod
362-
if self.flag_predmod:
363-
pred_points = self.pred_loc_mod
379+
# if self.flag_predmod:
380+
# pred_points = self.pred_loc_mod
364381
matrix_pinb = np.zeros([pred_points.shape[0],ref_boxes.shape[0]])
365382
for (p, p_point) in enumerate(pred_points):
366383
for (r, r_box) in enumerate(ref_boxes):
@@ -376,10 +393,10 @@ def pairwise_pointinmask(self):
376393
"""
377394
ref_masks = self.ref_loc
378395
pred_points = self.pred_loc
379-
if self.flag_refmod:
380-
ref_masks = self.ref_loc_mod
381-
if self.flag_predmod:
382-
pred_points = self.pred_loc_mod
396+
# if self.flag_refmod:
397+
# ref_masks = self.ref_loc_mod
398+
# if self.flag_predmod:
399+
# pred_points = self.pred_loc_mod
383400
matrix_pinm = np.zeros([pred_points.shape[0],ref_masks.shape[0]])
384401
for (p,p_point) in enumerate(pred_points):
385402
for (r,r_mask) in enumerate(ref_masks):
@@ -600,13 +617,13 @@ def resolve_ambiguities_matching(self):
600617
df_ambiguous_ref = df_matching[
601618
(df_matching["count_ref"] > 1) & (df_matching["ref"] > -1)
602619
]
603-
df_ambiguous_seg = df_matching[
620+
df_ambiguous_pred = df_matching[
604621
(df_matching["count_pred"] > 1) & (df_matching["pred"] > -1)
605622
]
606623
if (
607624
df_ambiguous_ref is None
608625
or df_ambiguous_ref.shape[0] == 0
609-
and df_ambiguous_seg.shape[0] == 0
626+
and df_ambiguous_pred.shape[0] == 0
610627
):
611628
print("No ambiguity in matching")
612629
df_matching_all = pd.concat([df_matching, df_fp, df_fn])
@@ -620,7 +637,7 @@ def resolve_ambiguities_matching(self):
620637
list_matching = []
621638
for (r, c) in zip(row, col):
622639
df_tmp = df_matching[
623-
df_matching["seg"] == list_valid[r] & (df_matching["ref"] == c)
640+
df_matching["pred"] == list_valid[r] & (df_matching["ref"] == c)
624641
]
625642
list_matching.append(df_tmp)
626643
df_ordered2 = pd.concat(list_matching)

test/test_utility/test_assignment_localization.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
ref_351 = [pq_ref1, pq_ref2, pq_ref3]
5252
pred_351 = [pq_pred1, pq_pred2, pq_pred3,pq_pred4]
5353

54+
pred_com_351 = [np.asarray([4.5,2]), np.asarray([14,5]), np.asarray([9,14.5]), np.asarray([13.5,14.5])]
55+
5456

5557
## Data for figure 59 and testing of localisation
5658
f59_ref1 = np.zeros([15, 15])
@@ -63,6 +65,48 @@
6365
f59_pred2 = np.zeros([15, 15])
6466
f59_pred2[4:8, 5:9] = 1
6567

68+
def test_pixdim_info():
69+
asm = AssignmentMapping(pred_loc = pred_351, ref_loc=ref_351, pred_prob=[0.4, 0.6, 0.3, 0.9])
70+
asm2 = AssignmentMapping(pred_loc=[], ref_loc=ref_351, pred_prob=[])
71+
assert_allclose(asm.pixdim,[1,1])
72+
assert_allclose(asm2.pixdim,[1,1])
73+
74+
def test_not_suitable_localisation():
75+
asm = AssignmentMapping(pred_loc=pred_com_351, ref_loc=ref_351, pred_prob=[0.4, 0.6, 0.3, 0.9], localization='box_iou')
76+
assert not asm.flag_usable
77+
assert asm.df_matching is None
78+
79+
def test_emptyref_flags():
80+
asm = AssignmentMapping(pred_loc=pred_351, ref_loc=[], pred_prob=[0.4, 0.6, 0.3, 0.9])
81+
assert asm.flag_usable
82+
83+
def test_not_good_localisation_option():
84+
asm = AssignmentMapping(pred_loc = pred_351, ref_loc=ref_351, pred_prob=[0.4, 0.6, 0.3, 0.9], localization='mask_iop')
85+
assert not asm.flag_usable
86+
87+
def test_boundaryiou_loc():
88+
asm = AssignmentMapping(pred_loc = pred_351, ref_loc=ref_351, pred_prob=[0.4, 0.6, 0.3, 0.9], localization='boundary_iou')
89+
assert asm.matrix[0,1] == 0
90+
91+
def test_input_boxcom_maskcom():
92+
asm1 = AssignmentMapping(pred_loc=pred_boxes_6a, ref_loc=ref_boxes_6a, pred_prob=pred_proba_6a, localization='box_com')
93+
asm2 = AssignmentMapping(pred_loc = pred_351, ref_loc=ref_351, pred_prob=[0.4, 0.6, 0.3, 0.9], localization='box_com')
94+
assert asm1.flag_refmod
95+
assert asm1.flag_predmod
96+
assert asm2.flag_refmod
97+
assert asm2.flag_predmod
98+
99+
def test_input_pim_maskmask():
100+
asm2 = AssignmentMapping(pred_loc = pred_351, ref_loc=ref_351, pred_prob=[0.4, 0.6, 0.3, 0.9], localization='point_in_mask')
101+
assert not asm2.flag_usable
102+
103+
def test_input_pib_pm():
104+
asm = AssignmentMapping(pred_loc=pred_com_351, ref_loc=ref_351, pred_prob=[0.4, 0.6, 0.3, 0.9], localization='point_in_box')
105+
matrix_pib = asm.pairwise_pointinbox()
106+
assert asm.flag_refmod
107+
assert matrix_pib[0,0] == 1
108+
109+
66110
def test_assignment_6c():
67111
asm1 = AssignmentMapping(pred_loc=pred_boxes_6a, ref_loc=ref_boxes_6a, pred_prob=pred_proba_6a, thresh=0.1,localization='box_iou')
68112
df_matching, df_fn, df_fp, list_valid = asm1.initial_mapping()
@@ -222,6 +266,11 @@ def test_pairwise_pointinbox():
222266
expected_matrix = np.asarray([[1,0],[0,0]])
223267
assert_allclose(am.matrix, expected_matrix)
224268

269+
def test_pairwise_pointinbox_pm():
270+
asm = AssignmentMapping(pred_com_351, ref_351,[1,1,1,1],'point_in_box')
271+
assert asm.matrix[0,0] == 1
272+
273+
225274
def test_pairwise_pointcomdist():
226275
ref1 = [3,4]
227276
ref2 = [10,10]
@@ -257,6 +306,11 @@ def test_com_from_refbox_6a():
257306
expected_ref_com = [[5,3.5],[7.5,10],[2,17],[15,16]]
258307
assert_array_almost_equal(np.asarray(expected_ref_com), test_com)
259308

309+
def test_pairwise_boxior_mm():
310+
asm = AssignmentMapping(pred_351, ref_351,[1,1,1,1],'box_ior')
311+
assert asm.matrix[0,0] == 0.8
312+
313+
260314
def test_com_from_predbox_6a():
261315
"""
262316
Using figure 6a as illustration
@@ -308,6 +362,23 @@ def test_box_frompredmask():
308362
expected_box = [[3,1,6,6],[13,4,15,5],[7,13,11,16],[13,13,15,16]]
309363
assert_array_almost_equal(np.asarray(expected_box),test_box)
310364

365+
def test_matching_ambiguity():
366+
ref_box = np.asarray([2,3,10,11])
367+
pred_box1 = np.asarray([1,1,8,8])
368+
pred_box2 = np.asarray([2,8,11,11])
369+
ref = [ref_box]
370+
pred = [pred_box1, pred_box2]
371+
asm = AssignmentMapping(pred, ref,[0.5,1],thresh=0.1,localization='box_iou')
372+
asm_h = AssignmentMapping(pred, ref,[0.5,1],thresh=0.1,localization='box_iou',assignment='hungarian')
373+
asm_p = AssignmentMapping(pred, ref,[0.5,1],thresh=0.1,localization='box_iou',assignment='greedy_performance')
374+
df_match, list_val = asm.resolve_ambiguities_matching()
375+
df_match_h, list_val_h = asm_h.resolve_ambiguities_matching()
376+
df_match_p, list_val_p = asm_p.resolve_ambiguities_matching()
377+
print(df_match, df_match_h, df_match_p)
378+
assert df_match.shape[0] == 2
379+
assert df_match_h.shape[0] == 2
380+
assert df_match_p.shape[0] == 2
381+
311382

312383
def test_localization():
313384
ref = [f59_ref1, f59_ref2]
@@ -349,3 +420,9 @@ def test_localization():
349420
np.asarray(m12[m12["pred"] == 0]["ref"])[0] == 0
350421
and np.asarray(m21[m21["pred"] == 0]["ref"])[0] == -1
351422
)
423+
424+
425+
426+
427+
428+

0 commit comments

Comments
 (0)