@@ -272,7 +272,8 @@ def inspect(dataset: SupervisedDataset, *, background: int = 0, max_samples: int
272272 label = dataset .label (idx ).int ()
273273 progress .update (task , advance = 1 , description = f"Inspecting dataset { tuple (label .shape )} " )
274274 ndim = label .ndim - 1
275- indices = (label != background ).nonzero ()
275+ fg_mask = label != background
276+ indices = fg_mask .nonzero ()
276277 if len (indices ) == 0 :
277278 r .append (InspectionAnnotation (
278279 tuple (label .shape [1 :]), (0 , 0 , 0 , 0 ) if ndim == 2 else (0 , 0 , 0 , 0 , 0 , 0 ), (), {}, {}, {})
@@ -291,14 +292,13 @@ def inspect(dataset: SupervisedDataset, *, background: int = 0, max_samples: int
291292 target_samples = min (max_samples , len (indices ))
292293 sampled_idx = torch .randperm (len (indices ))[:target_samples ]
293294 indices = indices [sampled_idx ]
294- class_locations [class_id ] = [ tuple (coord . tolist ()[ 1 :]) for coord in indices ]
295+ class_locations [class_id ] = tuple (tuple ( loc ) for loc in indices [:, 1 :]. tolist ())
295296 r .append (InspectionAnnotation (
296297 tuple (label .shape [1 :]), foreground_bbox , tuple (
297298 class_id for class_id in class_ids if class_id != background
298299 ), class_counts , class_bboxes , class_locations
299300 ))
300301 image = dataset .image (idx )
301- fg_mask = label != background
302302 if image .shape [0 ] > 1 :
303303 fg_mask = fg_mask .expand_as (image )
304304 fg = image [fg_mask ]
0 commit comments