Skip to content

Commit e0801f4

Browse files
authored
inspect(): batch tolist() and reuse fg_mask (#220) (#221)
1 parent 114b87f commit e0801f4

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

mipcandy/data/inspection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,8 @@ def inspect(dataset: SupervisedDataset, *, background: int = 0, max_samples: int
272272
label = dataset.label(idx).int()
273273
progress.update(task, advance=1, description=f"Inspecting dataset {tuple(label.shape)}")
274274
ndim = label.ndim - 1
275-
indices = (label != background).nonzero()
275+
fg_mask = label != background
276+
indices = fg_mask.nonzero()
276277
if len(indices) == 0:
277278
r.append(InspectionAnnotation(
278279
tuple(label.shape[1:]), (0, 0, 0, 0) if ndim == 2 else (0, 0, 0, 0, 0, 0), (), {}, {}, {})
@@ -291,14 +292,13 @@ def inspect(dataset: SupervisedDataset, *, background: int = 0, max_samples: int
291292
target_samples = min(max_samples, len(indices))
292293
sampled_idx = torch.randperm(len(indices))[:target_samples]
293294
indices = indices[sampled_idx]
294-
class_locations[class_id] = [tuple(coord.tolist()[1:]) for coord in indices]
295+
class_locations[class_id] = tuple(tuple(loc) for loc in indices[:, 1:].tolist())
295296
r.append(InspectionAnnotation(
296297
tuple(label.shape[1:]), foreground_bbox, tuple(
297298
class_id for class_id in class_ids if class_id != background
298299
), class_counts, class_bboxes, class_locations
299300
))
300301
image = dataset.image(idx)
301-
fg_mask = label != background
302302
if image.shape[0] > 1:
303303
fg_mask = fg_mask.expand_as(image)
304304
fg = image[fg_mask]

0 commit comments

Comments
 (0)