
LivePatchWSIDataloader.__next__ may return a Python list (empty batch) and can raise on empty tiles; propose robust handling #7

@Isaaccheong95

Description


Hello!

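When every candidate patch in a batch is filtered out, LivePatchWSIDataloader.__next__ in pathopatch/patch_extraction/dataset.py never converts patches to a tensor. It returns patches as an empty Python list (the len(patches) == 0 case), even though the method is annotated and documented to return a torch.Tensor. Callers that treat the first return value as a tensor then hit type errors, and downstream code that indexes patches[0] can raise an IndexError.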
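To illustrate the failure mode, here is a minimal sketch that copies only the branch structure at the end of __next__. It is a stand-in under stated assumptions: np.stack replaces torch.stack so it runs without torch, and collate_tail is a hypothetical name, not part of pathopatch.

```python
import numpy as np


def collate_tail(patches):
    """Hypothetical stand-in for the tail of __next__ (np.stack replaces torch.stack)."""
    if len(patches) > 1:
        patches = np.stack([np.asarray(f) for f in patches])
    elif len(patches) == 1:
        patches = np.asarray(patches[0])[None, ...]
    # No len(patches) == 0 branch: the empty Python list is returned unchanged.
    return patches


full = collate_tail([np.zeros((3, 4, 4)), np.zeros((3, 4, 4))])
print(full.shape)  # (2, 3, 4, 4)

empty = collate_tail([])
print(type(empty))  # <class 'list'>
try:
    _ = empty.shape  # a caller that expects an array/tensor batch
except AttributeError as err:
    print("caller fails:", err)
```

The non-empty cases produce an array with a leading batch dimension, while the empty case silently changes the return type.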

def __next__(self) -> Tuple[torch.Tensor, List[dict], List[np.ndarray]]:
    """Create one batch of patches
    Raises:
        StopIteration: If the end of the dataset is reached.
    Returns:
        Tuple[torch.Tensor, List[dict], List[np.ndarray]]:
            * torch.Tensor: Batch of patches, shape (batch_size, 3, patch_size, patch_size)
            * List[dict]: List of metadata for each patch
            * List[np.ndarray]: List of masks for each patch
    """
    patches = []
    metadata = []
    masks = []
    if self.i < len(self.element_list):
        batch_item_count = 0
        while batch_item_count < self.batch_size and self.i < len(
            self.element_list
        ):
            patch, meta, mask = self.dataset[self.element_list[self.i]]
            self.i += 1
            if patch is None and meta["discard_patch"]:
                self.discard_count += 1
                continue
            elif self.dataset.config.filter_patches:
                output = self.dataset.detector_model(
                    self.dataset.detector_transforms(patch)[None, ...]
                )
                output_prob = torch.softmax(output, dim=-1)
                prediction = torch.argmax(output_prob, dim=-1)
                if int(prediction) != 0:
                    self.discard_count += 1
                    continue
            patches.append(patch)
            metadata.append(meta)
            masks.append(mask)
            batch_item_count += 1
        if len(patches) > 1:
            patches = [torch.tensor(f) for f in patches]
            patches = torch.stack(patches)
        elif len(patches) == 1:
            patches = torch.tensor(patches[0][None, ...])
        return patches, metadata, masks
    else:
        raise StopIteration
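Note that the while loop can only finish with an empty patches list once self.i has reached len(self.element_list): discarded items advance self.i without filling the batch, so an empty batch means every remaining element was discarded. The hypothetical ToyLoader below models that loop in pure Python (a keep predicate stands in for the discard flag and the detector model) with an empty-batch check added, to show that iteration then ends cleanly instead of yielding an empty list.

```python
from typing import Callable, Iterable, List


class ToyLoader:
    """Hypothetical pure-Python model of the batching loop in __next__.

    `keep` stands in for the discard flag and the detector model: items for
    which it returns False are skipped, like the `continue` branches.
    """

    def __init__(self, items: Iterable[int], batch_size: int,
                 keep: Callable[[int], bool]) -> None:
        self.items = list(items)
        self.batch_size = batch_size
        self.keep = keep
        self.i = 0
        self.discard_count = 0

    def __iter__(self) -> "ToyLoader":
        return self

    def __next__(self) -> List[int]:
        if self.i >= len(self.items):
            raise StopIteration
        batch: List[int] = []
        while len(batch) < self.batch_size and self.i < len(self.items):
            item = self.items[self.i]
            self.i += 1  # advances even when the item is discarded
            if not self.keep(item):
                self.discard_count += 1
                continue
            batch.append(item)
        if not batch:
            # An empty batch is only possible once every remaining item was
            # discarded, i.e. the list is exhausted: end iteration here.
            assert self.i == len(self.items)
            raise StopIteration
        return batch


loader = ToyLoader(range(10), batch_size=3, keep=lambda x: x < 3)
print(list(loader))          # [[0, 1, 2]]
print(loader.discard_count)  # 7
```

Without the `if not batch` check, the same run would produce [[0, 1, 2], []], which is the empty trailing batch described above.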

Suggested fix

 def __next__(self):
     ...
         if len(patches) > 1:
             patches = [torch.tensor(f) for f in patches]
             patches = torch.stack(patches)
         elif len(patches) == 1:
             patches = torch.tensor(patches[0][None, ...])
+        elif len(patches) == 0:
+            raise StopIteration
         return patches, metadata, masks
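A possible simplification of the same fix, sketched under assumptions: the helper name stack_patches is hypothetical, and the patches are assumed to be NumPy arrays of equal shape. Because torch.stack also handles a single patch, the separate len(patches) == 1 branch could be folded into one code path, and the empty case raises StopIteration as in the suggested fix. Raising StopIteration from a helper is fine here because __next__ is an ordinary method; the PEP 479 conversion to RuntimeError only applies inside generators.

```python
from typing import Sequence

import numpy as np
import torch


def stack_patches(patches: Sequence[np.ndarray]) -> torch.Tensor:
    """Hypothetical helper: turn the collected patches into one batch tensor.

    torch.stack handles one or many patches, so no len == 1 special case is
    needed. An empty batch ends iteration instead of leaking a Python list.
    """
    if len(patches) == 0:
        raise StopIteration
    # torch.as_tensor avoids an extra copy where possible; torch.stack then
    # builds the (batch_size, 3, patch_size, patch_size) tensor.
    return torch.stack([torch.as_tensor(p) for p in patches])


one = stack_patches([np.zeros((3, 8, 8), dtype=np.float32)])
print(one.shape)  # torch.Size([1, 3, 8, 8])
```

Inside __next__, this would replace the if/elif block so that the method always returns a tensor or stops.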

I would be happy to open a PR for this. Do let me know if this is your preferred fix.
