99#%%%%%%%% imports %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1010import numpy as np
1111import sys
12+ import warnings
1213from ivtmetrics .recognition import Recognition
1314
1415#%%%%%%%%%% RECOGNITION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
@@ -50,7 +51,7 @@ class Detection(Recognition):
5051 add compute_global_AP('i/ivt') return AP for all seen examples
5152 add reset_video()
5253 """
53- def __init__ (self , num_class = 100 , num_tool = 6 ):
54+ def __init__ (self , num_class = 100 , num_tool = 6 , threshold = 0.5 ):
5455 super (Recognition , self ).__init__ ()
5556 self .num_class = num_class
5657 self .num_tool = num_tool
@@ -60,6 +61,7 @@ def __init__(self, num_class=100, num_tool=6):
6061 self .accumulator = {}
6162 self .video_count = 0
6263 self .end_call = False
64+ self .threshold = threshold
6365 self .reset ()
6466
6567 def reset (self ):
@@ -115,6 +117,7 @@ def is_match(self, det_gt, det_pd, threshold):
115117 return status
116118
117119 def list2stack (self , x ):
120+ if x == []: x = [[- 1 ,- 1 ,- 1 ,- 1 ,- 1 ,- 1 ]] # empty
118121 #x format for a single frame: list(list): each list = [tripletID, toolID, toolProbs, x, y, w, h] bbox is scaled (0..1)
119122 assert isinstance (x [0 ], list ), "Each frame must be a list of lists, each list a prediction of triplet and object locations"
120123 x = np .stack (x , axis = 0 )
@@ -136,7 +139,6 @@ def dict2stack(self, x):
136139 p = [d ['triplet' ]]
137140 p .extend (d ["instrument" ])
138141 y .append (p )
139- # y = np.stack(y, axis=0)
140142 return self .list2stack (y )
141143
142144 def update (self , targets , predictions , format = "list" ):
@@ -160,38 +162,38 @@ def update_frame(self, targets, predictions, format="list"):
160162 sys .exit ("unkown input format for update function. Must be a list or dict" )
161163 if len (detection_pd ) + len (detection_gt ) == 0 :
162164 return
163- detection_gt_i = detection_gt .copy ()
164- detection_pd_i = detection_pd .copy ()
165+ detection_gt_ivt = detection_gt .copy ()
166+ detection_pd_ivt = detection_pd .copy ()
165167 # for triplet
166- for gt in detection_gt :
168+ for gt in detection_gt_ivt :
167169 self .accumulator [self .video_count ]["npos" ][int (gt [0 ])] += 1
168- for det_pd in detection_pd :
170+ for det_pd in detection_pd_ivt :
169171 self .accumulator [self .video_count ]["ndet" ][int (det_pd [0 ])] += 1
170172 matched = False
171- for k , det_gt in enumerate (detection_gt ):
173+ for k , det_gt in enumerate (detection_gt_ivt ):
172174 y = det_gt [0 :]
173175 f = det_pd [0 :]
174- if self .is_match (y , f , threshold = 0.5 ):
175- detection_gt = np .delete (detection_gt , obj = k , axis = 0 )
176+ if self .is_match (y , f , threshold = self . threshold ):
177+ detection_gt_ivt = np .delete (detection_gt_ivt , obj = k , axis = 0 )
176178 matched = True
177179 break
178180 if matched :
179181 self .accumulator [self .video_count ]["hits" ][int (det_pd [0 ])].append (1.0 )
180182 else :
181183 self .accumulator [self .video_count ]["hits" ][int (det_pd [0 ])].append (0.0 )
182- # for instrument
183- detection_gt = detection_gt_i
184- detection_pd = detection_pd_i
185- for gt in detection_gt :
184+ # for instrument
185+ detection_gt_i = detection_gt . copy ()
186+ detection_pd_i = detection_pd . copy ()
187+ for gt in detection_gt_i :
186188 self .accumulator [self .video_count ]["npos_i" ][int (gt [1 ])] += 1
187- for det_pd in detection_pd :
189+ for det_pd in detection_pd_i :
188190 self .accumulator [self .video_count ]["ndet_i" ][int (det_pd [1 ])] += 1
189191 matched = False
190- for k , det_gt in enumerate (detection_gt ):
191- y = det_gt [1 :6 ]
192- f = det_pd [1 :6 ]
193- if self .is_match (y , f , threshold = 0.5 ):
194- detection_gt = np .delete (detection_gt , obj = k , axis = 0 )
192+ for k , det_gt in enumerate (detection_gt_i ):
193+ y = det_gt [1 :]
194+ f = det_pd [1 :]
195+ if self .is_match (y , f , threshold = self . threshold ):
196+ detection_gt_i = np .delete (detection_gt_i , obj = k , axis = 0 )
195197 matched = True
196198 break
197199 if matched :
@@ -240,7 +242,9 @@ def compute(self, component="ivt", video_id=None):
240242 classwise_ap .append (ap )
241243 classwise_rec .append (np .max (rec ))
242244 classwise_prec .append (np .max (prec ))
243- return (classwise_ap , np .nanmean (classwise_ap )), (classwise_rec , np .nanmean (classwise_rec )), (classwise_prec , np .nanmean (classwise_prec ))
245+ with warnings .catch_warnings ():
246+ warnings .simplefilter ("ignore" , category = RuntimeWarning )
247+ return (classwise_ap , np .nanmean (classwise_ap )), (classwise_rec , np .nanmean (classwise_rec )), (classwise_prec , np .nanmean (classwise_prec ))
244248
245249 def compute_video_AP (self , component = "ivt" ):
246250 classwise_ap = []
@@ -252,12 +256,14 @@ def compute_video_AP(self, component="ivt"):
252256 classwise_ap .append (ap )
253257 classwise_rec .append (rec )
254258 classwise_prec .append (prec )
255- classwise_ap = np .nanmean (np .stack (classwise_ap , axis = 0 ), axis = 0 )
256- classwise_rec = np .nanmean (np .stack (classwise_rec , axis = 0 ), axis = 0 )
257- classwise_prec = np .nanmean (np .stack (classwise_prec , axis = 0 ), axis = 0 )
258- mAP = np .nanmean (classwise_ap )
259- mRec = np .nanmean (classwise_rec )
260- mPrec = np .nanmean (classwise_prec )
259+ with warnings .catch_warnings ():
260+ warnings .simplefilter ("ignore" , category = RuntimeWarning )
261+ classwise_ap = np .nanmean (np .stack (classwise_ap , axis = 0 ), axis = 0 )
262+ classwise_rec = np .nanmean (np .stack (classwise_rec , axis = 0 ), axis = 0 )
263+ classwise_prec = np .nanmean (np .stack (classwise_prec , axis = 0 ), axis = 0 )
264+ mAP = np .nanmean (classwise_ap )
265+ mRec = np .nanmean (classwise_rec )
266+ mPrec = np .nanmean (classwise_prec )
261267 return {"AP" :classwise_ap , "mAP" :mAP , "Rec" :classwise_rec , "mRec" :mRec , "Pre" :classwise_prec , "mPre" :mPrec }
262268
263269 def compute_AP (self , component = "ivt" ):
0 commit comments