+from collections import defaultdict
import os
import time
from typing import List, Optional, Tuple, Union
@@ -94,6 +95,8 @@ class ImageSegmentationTensorflowDataset:
    :type split: str, optional
    :param lut_ontology: LUT to transform label classes, defaults to None
    :type lut_ontology: dict, optional
+    :param normalization: Parameters for normalizing input images, defaults to None
+    :type normalization: dict, optional
    """

    def __init__(
@@ -103,8 +106,14 @@ def __init__(
        batch_size: int = 1,
        split: str = "all",
        lut_ontology: Optional[dict] = None,
+        normalization: Optional[dict] = None,
    ):
        self.image_size = image_size
+        self.normalization = None
+        if normalization is not None:
+            mean = tf.constant(normalization["mean"], dtype=tf.float32)
+            std = tf.constant(normalization["std"], dtype=tf.float32)
+            self.normalization = {"mean": mean, "std": std}

        # Filter split and make filenames global
        if split != "all":
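For reference, the new normalization argument is a plain dict with per-channel mean and std values on the [0, 1] scale (images are divided by 255 before standardization, see the read_image hunk below). A minimal sketch of what a caller might pass; the ImageNet statistics here are only an illustrative choice, not something this commit prescribes:

    import tensorflow as tf

    # Illustrative per-channel statistics (ImageNet values, purely as an example)
    normalization = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

    # The constructor converts them once into float32 tensors, as in the added lines above
    mean = tf.constant(normalization["mean"], dtype=tf.float32)
    std = tf.constant(normalization["std"], dtype=tf.float32)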
@@ -155,9 +164,17 @@ def read_image(self, fname: str, label=False) -> tf.Tensor:
        # Resize (use NN to avoid interpolation when dealing with labels)
        method = "nearest" if label else "bilinear"
        image = tf_image.resize(images=image, size=self.image_size, method=method)
+
+        # If label, round values to avoid interpolation artifacts
        if label:
            image = tf.round(image)

+        # If normalization parameters are provided, normalize image
+        else:
+            if self.normalization is not None:
+                image = tf.cast(image, tf.float32) / 255.0
+                image = (image - self.normalization["mean"]) / self.normalization["std"]
+
        return image

    def load_data(
@@ -217,6 +234,11 @@ def t_in(image):
            tensor = tf.convert_to_tensor(image)
            tensor = tf_image.resize(images=tensor, size=self.model_cfg["image_size"])
            tensor = tf.expand_dims(tensor, axis=0)
+            if "normalization" in self.model_cfg:
+                mean = tf.constant(self.model_cfg["normalization"]["mean"])
+                std = tf.constant(self.model_cfg["normalization"]["std"])
+                tensor = tf.cast(tensor, tf.float32) / 255.0
+                tensor = (tensor - mean) / std
            return tensor

        self.t_in = t_in
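The inference-side input transform t_in now applies the same rescale-and-standardize step as read_image whenever the model config carries a "normalization" entry. A sketch of the model_cfg keys referenced in this diff; the concrete values and the class name are placeholders, not values taken from the repository:

    model_cfg = {
        "image_size": [512, 512],            # placeholder resolution
        "batch_size": 1,
        "normalization": {                   # optional; t_in skips the step when absent
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
        },
        "ignored_classes": ["unlabeled"],    # placeholder class name, used further down in eval()
    }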
@@ -275,18 +297,23 @@ def eval(
            batch_size=self.model_cfg.get("batch_size", 1),
            split=split,
            lut_ontology=lut_ontology,
+            normalization=self.model_cfg.get("normalization", None),
        )

+        # Retrieve ignored label indices
+        ignored_label_indices = []
+        for ignored_class in self.model_cfg.get("ignored_classes", []):
+            ignored_label_indices.append(dataset.ontology[ignored_class]["idx"])
+
        # Init metrics
        results = {}
-        iou = um.IoU(self.n_classes)
-        cm = um.ConfusionMatrix(self.n_classes)
+        metrics_factory = um.MetricsFactory(self.n_classes)

        # Evaluation loop
        pbar = tqdm(dataset.dataset)
        for image, label in pbar:
            if self.model_type == "native":
-                pred = self.model(image)
+                pred = self.model(image, training=False)
            elif self.model_type == "compiled":
                pred = self.model.signatures["serving_default"](image)
            else:
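The mask added in the next hunk marks every pixel whose label is not one of the ignored indices, i.e. a chain of logical ANDs over per-class inequalities. A toy, self-contained example of the same idea (the label values and the ignored index are made up):

    import tensorflow as tf

    label = tf.constant([[0, 1], [2, 255]])    # toy label map; 255 stands in for an ignored class
    ignored_label_indices = [255]

    valid_mask = tf.ones_like(label, dtype=tf.bool)
    for idx in ignored_label_indices:
        valid_mask = tf.logical_and(valid_mask, label != idx)
    # valid_mask -> [[True, True], [True, False]]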
@@ -295,37 +322,48 @@ def eval(
            if isinstance(pred, dict):
                pred = list(pred.values())[0]

+            # Get valid points mask depending on ignored label indices
+            if ignored_label_indices:
+                valid_mask = tf.ones_like(label, dtype=tf.bool)
+                for idx in ignored_label_indices:
+                    valid_mask = tf.logical_and(valid_mask, label != idx)
+            else:
+                valid_mask = None
+
            label = tf.squeeze(label, axis=3)
            pred = tf.argmax(pred, axis=3)
-            cm.update(pred.numpy(), label.numpy())
-
-            pred = tf.one_hot(pred, self.n_classes)
-            pred = tf.transpose(pred, perm=[0, 3, 1, 2])
+            if valid_mask is not None:
+                valid_mask = tf.squeeze(valid_mask, axis=3)
+            metrics_factory.update(
+                pred.numpy(),
+                label.numpy(),
+                valid_mask.numpy() if valid_mask is not None else None,
+            )

-            label = tf.one_hot(label, self.n_classes)
-            label = tf.transpose(label, perm=[0, 3, 1, 2])
+        # Build results dataframe
+        results = defaultdict(dict)

-            iou.update(pred.numpy(), label.numpy())
+        # Add per class and global metrics
+        for metric in metrics_factory.get_metric_names():
+            per_class = metrics_factory.get_metric_per_name(metric, per_class=True)

-        # Get metrics results
-        iou_per_class, iou = iou.compute()
-        acc_per_class, acc = cm.get_accuracy()
-        iou_per_class = [float(n) for n in iou_per_class]
-        acc_per_class = [float(n) for n in acc_per_class]
+            for class_name, class_data in self.ontology.items():
+                results[class_name][metric] = float(per_class[class_data["idx"]])

-        # Build results dataframe
-        results = {}
-        for class_name, class_data in self.ontology.items():
-            results[class_name] = {
-                "iou": iou_per_class[class_data["idx"]],
-                "acc": acc_per_class[class_data["idx"]],
-            }
-        results["global"] = {"iou": iou, "acc": acc}
+            if metric not in ["tp", "fp", "fn", "tn"]:
+                for avg_method in ["macro", "micro"]:
+                    results[avg_method][metric] = metrics_factory.get_averaged_metric(
+                        metric, avg_method
+                    )

-        results = pd.DataFrame(results)
-        results.index.name = "metric"
+        # Add confusion matrix
+        for class_name_a, class_data_a in self.ontology.items():
+            for class_name_b, class_data_b in self.ontology.items():
+                results[class_name_a][class_name_b] = metrics_factory.confusion_matrix[
+                    class_data_a["idx"], class_data_b["idx"]
+                ]

-        return results
+        return pd.DataFrame(results)

    def get_computational_cost(self, runs: int = 30, warm_up_runs: int = 5) -> dict:
        """Get different metrics related to the computational cost of the model