@@ -211,7 +211,7 @@ def __init__(
211211 detector_device (str): Device for the detector model
212212 detector_model (torch.nn.Module): Detector model for filtering patches
213213 detector_transforms (Callable): Transforms to apply to the detector model
214- mask_images (List[ PIL.Image.Image]): List of mask images
214+ mask_images (dict[str, PIL.Image.Image]): Dictionary of mask images for the patches, key is the name and value is the mask image
215215
216216 Methods:
217217 __init__(slide_processor_config: PreProcessingDatasetConfig, logger: logging.Logger = None) -> None:
@@ -250,6 +250,7 @@ def __init__(
250250 self .polygons : List [Polygon ]
251251 self .region_labels : List [str ]
252252 self .transforms = transforms
253+ self .mask_images : dict [str , Image .Image ]
253254
254255 # filter
255256 self .detector_device : str
@@ -372,7 +373,11 @@ def _set_tissue_detector(self) -> None:
372373 def _prepare_slide (
373374 self ,
374375 ) -> Tuple [
375- List [Tuple [int , int , float ]], int , List [Polygon ], List [str ], List [Image .Image ]
376+ List [Tuple [int , int , float ]],
377+ int ,
378+ List [Polygon ],
379+ List [str ],
380+ dict [str , Image .Image ],
376381 ]:
377382 """Prepare the slide for patch extraction
378383
@@ -390,7 +395,7 @@ def _prepare_slide(
390395 * int: Level of the slide
391396 * List[Polygon]: List of polygons, downsampled to the target level
392397 * List[str]: List of region labels
393- * List[ Image.Image]: List of mask images
398+ * dict[str, Image.Image]: Dictionary of mask images for the patches, key is the name and value is the mask image
394399 """
395400 self .slide_openslide = self .slide_metadata_loader (str (self .config .wsi_path ))
396401 self .slide = self .image_loader (str (self .config .wsi_path ))
@@ -800,7 +805,7 @@ def __next__(self) -> Tuple[torch.Tensor, List[dict], List[np.ndarray]]:
800805 if len (patches ) > 1 :
801806 patches = [torch .tensor (f ) for f in patches ]
802807 patches = torch .stack (patches )
803- else :
808+ elif len ( patches ) == 1 :
804809 patches = torch .tensor (patches [0 ][None , ...])
805810 return patches , metadata , masks
806811 else :
0 commit comments