Skip to content

Commit 426afb7

Browse files
committed
Fixing bug in dataloader when no patch can be returned
1 parent da66242 commit 426afb7

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

pathopatch/patch_extraction/dataset.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def __init__(
211211
detector_device (str): Device for the detector model
212212
detector_model (torch.nn.Module): Detector model for filtering patches
213213
detector_transforms (Callable): Transforms to apply to the detector model
214-
mask_images (List[PIL.Image.Image]): List of mask images
214+
mask_images (dict[str, PIL.Image.Image]): Dictionary of mask images for the patches, key is the name and value is the mask image
215215
216216
Methods:
217217
__init__(slide_processor_config: PreProcessingDatasetConfig, logger: logging.Logger = None) -> None:
@@ -250,6 +250,7 @@ def __init__(
250250
self.polygons: List[Polygon]
251251
self.region_labels: List[str]
252252
self.transforms = transforms
253+
self.mask_images: dict[str, Image.Image]
253254

254255
# filter
255256
self.detector_device: str
@@ -372,7 +373,11 @@ def _set_tissue_detector(self) -> None:
372373
def _prepare_slide(
373374
self,
374375
) -> Tuple[
375-
List[Tuple[int, int, float]], int, List[Polygon], List[str], List[Image.Image]
376+
List[Tuple[int, int, float]],
377+
int,
378+
List[Polygon],
379+
List[str],
380+
dict[str, Image.Image],
376381
]:
377382
"""Prepare the slide for patch extraction
378383
@@ -390,7 +395,7 @@ def _prepare_slide(
390395
* int: Level of the slide
391396
* List[Polygon]: List of polygons, downsampled to the target level
392397
* List[str]: List of region labels
393-
* List[Image.Image]: List of mask images
398+
* dict[str, Image.Image]: Dictionary of mask images for the patches, key is the name and value is the mask image
394399
"""
395400
self.slide_openslide = self.slide_metadata_loader(str(self.config.wsi_path))
396401
self.slide = self.image_loader(str(self.config.wsi_path))
@@ -800,7 +805,7 @@ def __next__(self) -> Tuple[torch.Tensor, List[dict], List[np.ndarray]]:
800805
if len(patches) > 1:
801806
patches = [torch.tensor(f) for f in patches]
802807
patches = torch.stack(patches)
803-
else:
808+
elif len(patches) == 1:
804809
patches = torch.tensor(patches[0][None, ...])
805810
return patches, metadata, masks
806811
else:

0 commit comments

Comments
 (0)