Hello!
When all patches in a batch are filtered out, `LivePatchWSIDataloader.__next__` in `pathopatch/patch_extraction/dataset.py` leaves `patches` as an empty list (the `len(patches) == 0` case), even though the function is annotated to return a `torch.Tensor`. This leads to type errors downstream (and sometimes an `IndexError` on `patches[0]`).
PathoPatcher/pathopatch/patch_extraction/dataset.py
Lines 782 to 827 in 9ef1f12
```python
def __next__(self) -> Tuple[torch.Tensor, List[dict], List[np.ndarray]]:
    """Create one batch of patches

    Raises:
        StopIteration: If the end of the dataset is reached.

    Returns:
        Tuple[torch.Tensor, List[dict], List[np.ndarray]]:

        * torch.Tensor: Batch of patches, shape (batch_size, 3, patch_size, patch_size)
        * List[dict]: List of metadata for each patch
        * List[np.ndarray]: List of masks for each patch
    """
    patches = []
    metadata = []
    masks = []
    if self.i < len(self.element_list):
        batch_item_count = 0
        while batch_item_count < self.batch_size and self.i < len(
            self.element_list
        ):
            patch, meta, mask = self.dataset[self.element_list[self.i]]
            self.i += 1
            if patch is None and meta["discard_patch"]:
                self.discard_count += 1
                continue
            elif self.dataset.config.filter_patches:
                output = self.dataset.detector_model(
                    self.dataset.detector_transforms(patch)[None, ...]
                )
                output_prob = torch.softmax(output, dim=-1)
                prediction = torch.argmax(output_prob, dim=-1)
                if int(prediction) != 0:
                    self.discard_count += 1
                    continue
            patches.append(patch)
            metadata.append(meta)
            masks.append(mask)
            batch_item_count += 1
        if len(patches) > 1:
            patches = [torch.tensor(f) for f in patches]
            patches = torch.stack(patches)
        elif len(patches) == 1:
            patches = torch.tensor(patches[0][None, ...])
        return patches, metadata, masks
    else:
        raise StopIteration
```
Suggested fix
```diff
 def __next__(self):
         if len(patches) > 1:
             patches = [torch.tensor(f) for f in patches]
             patches = torch.stack(patches)
         elif len(patches) == 1:
             patches = torch.tensor(patches[0][None, ...])
+        elif len(patches) == 0:
+            raise StopIteration
         return patches, metadata, masks
```

I would be happy to open a PR for this. Do let me know if this is your preferred fix.