-
Notifications
You must be signed in to change notification settings - Fork 60
ENH: Enable heatmaps when tiling on the fly #491
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
df92a48
a7754d3
6f29abe
ce04dab
be25cf6
8183fb6
1a93d3a
bbaa5fa
5c10885
585c893
7970037
792385e
85bd4ec
5281940
2ea8c94
a720a28
a996323
547889f
7550cae
e989b4f
73f0798
3418f47
73b044f
58c74ec
fec8f8e
87e693f
bf16cb0
b656599
6408a2e
aa3e6f1
4226ba3
4a97089
50a58c0
6dfa5d6
600dd06
33c52b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -15,14 +15,14 @@ | |||||
from health_ml.utils.bag_utils import BagDataset, multibag_collate | ||||||
from health_ml.utils.common_utils import _create_generator | ||||||
|
||||||
from health_cpath.utils.wsi_utils import image_collate | ||||||
from health_cpath.utils.wsi_utils import array_collate | ||||||
from health_cpath.models.transforms import LoadTilesBatchd | ||||||
from health_cpath.datasets.base_dataset import SlidesDataset, TilesDataset | ||||||
from health_cpath.utils.naming import ModelKey | ||||||
|
||||||
from monai.transforms.compose import Compose | ||||||
from monai.transforms.io.dictionary import LoadImaged | ||||||
from monai.apps.pathology.transforms import TileOnGridd | ||||||
from monai.transforms import RandGridPatchd, GridPatchd | ||||||
from monai.data.image_reader import WSIReader | ||||||
|
||||||
_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset) | ||||||
|
@@ -245,72 +245,85 @@ def __init__( | |||||
self, | ||||||
level: Optional[int] = 1, | ||||||
tile_size: Optional[int] = 224, | ||||||
step: Optional[int] = None, | ||||||
random_offset: Optional[bool] = True, | ||||||
pad_full: Optional[bool] = False, | ||||||
background_val: Optional[int] = 255, | ||||||
filter_mode: Optional[str] = "min", | ||||||
filter_mode: Optional[str] = "max", | ||||||
overlap: Optional[float] = 0, | ||||||
intensity_threshold: Optional[float] = 0, | ||||||
pad_mode: Optional[str] = "constant", | ||||||
**kwargs: Any, | ||||||
) -> None: | ||||||
""" | ||||||
:param level: the whole slide image level at which the image is extracted, defaults to 1 | ||||||
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend | ||||||
:param tile_size: size of the square tile, defaults to 224 | ||||||
this param is passed to TileOnGridd monai transform for tiling on the fly. | ||||||
:param step: step size to create overlapping tiles, defaults to None (same as tile_size) | ||||||
Use a step < tile_size to create overlapping tiles, analogousely a step > tile_size will skip some chunks in | ||||||
the wsi. This param is passed to TileOnGridd monai transform for tiling on the fly. | ||||||
:param random_offset: randomize position of the grid, instead of starting from the top-left corner, | ||||||
defaults to True. This param is passed to TileOnGridd monai transform for tiling on the fly. | ||||||
:param pad_full: pad image to the size evenly divisible by tile_size, defaults to False | ||||||
This param is passed to TileOnGridd monai transform for tiling on the fly. | ||||||
:param background_val: the background constant to ignore background tiles (e.g. 255 for white background), | ||||||
defaults to 255. This param is passed to TileOnGridd monai transform for tiling on the fly. | ||||||
:param filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is greater than | ||||||
tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for | ||||||
random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd | ||||||
monai transform for tiling on the fly. | ||||||
:param filter_mode: when `num_patches` is provided, it determines if keep patches with highest values | ||||||
(`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. | ||||||
:param overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the order of values? (width, height) or the other way around? |
||||||
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. | ||||||
:param intensity_threshold: a value to keep only the patches whose sum of intensities are less than the | ||||||
threshold. Defaults to no filtering. | ||||||
:pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
super().__init__(**kwargs) | ||||||
self.level = level | ||||||
self.tile_size = tile_size | ||||||
self.step = step | ||||||
self.random_offset = random_offset | ||||||
self.pad_full = pad_full | ||||||
self.background_val = background_val | ||||||
self.filter_mode = filter_mode | ||||||
# TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and | ||||||
# Tiling transform expects None to select all foreground tile so we hardcode max_bag_size and | ||||||
# max_bag_size_inf to None if set to 0 | ||||||
self.max_bag_size = None if self.max_bag_size == 0 else self.max_bag_size # type: ignore | ||||||
self.max_bag_size_inf = None if self.max_bag_size_inf == 0 else self.max_bag_size_inf # type: ignore | ||||||
self.overlap = overlap | ||||||
self.intensity_threshold = intensity_threshold | ||||||
self.pad_mode = pad_mode | ||||||
|
||||||
def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset: | ||||||
base_transform = Compose( | ||||||
[ | ||||||
LoadImaged( | ||||||
keys=slides_dataset.IMAGE_COLUMN, | ||||||
reader=WSIReader, | ||||||
backend="cuCIM", | ||||||
dtype=np.uint8, | ||||||
level=self.level, | ||||||
image_only=True, | ||||||
), | ||||||
TileOnGridd( | ||||||
keys=slides_dataset.IMAGE_COLUMN, | ||||||
tile_count=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf, | ||||||
tile_size=self.tile_size, | ||||||
step=self.step, | ||||||
random_offset=self.random_offset if stage == ModelKey.TRAIN else False, | ||||||
pad_full=self.pad_full, | ||||||
background_val=self.background_val, | ||||||
filter_mode=self.filter_mode, | ||||||
return_list_of_dicts=True, | ||||||
), | ||||||
] | ||||||
load_image_transform = LoadImaged( | ||||||
keys=slides_dataset.IMAGE_COLUMN, | ||||||
reader=WSIReader, # type: ignore | ||||||
backend="cuCIM", | ||||||
dtype=np.uint8, | ||||||
level=self.level, | ||||||
image_only=True, | ||||||
) | ||||||
if self.transforms_dict and self.transforms_dict[stage]: | ||||||
max_offset = None if (self.random_offset and stage == ModelKey.TRAIN) else 0 | ||||||
|
||||||
if stage != ModelKey.TRAIN: | ||||||
grid_transform = RandGridPatchd( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Something I don't get here: We are using random tiles when we are NOT training? |
||||||
keys=[slides_dataset.IMAGE_COLUMN], | ||||||
patch_size=[self.tile_size, self.tile_size], # type: ignore | ||||||
num_patches=self.max_bag_size, | ||||||
sort_fn=self.filter_mode, | ||||||
pad_mode=self.pad_mode, # type: ignore | ||||||
constant_values=self.background_val, | ||||||
overlap=self.overlap, # type: ignore | ||||||
threshold=self.intensity_threshold, | ||||||
max_offset=max_offset, | ||||||
) | ||||||
else: | ||||||
grid_transform = GridPatchd( | ||||||
keys=[slides_dataset.IMAGE_COLUMN], | ||||||
patch_size=[self.tile_size, self.tile_size], # type: ignore | ||||||
num_patches=self.max_bag_size_inf, | ||||||
sort_fn=self.filter_mode, | ||||||
pad_mode=self.pad_mode, # type: ignore | ||||||
constant_values=self.background_val, | ||||||
overlap=self.overlap, # type: ignore | ||||||
threshold=self.intensity_threshold, | ||||||
offset=max_offset, | ||||||
) | ||||||
|
||||||
base_transform = Compose([load_image_transform, grid_transform]) | ||||||
|
||||||
transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten() | ||||||
if self.transforms_dict and self.transforms_dict[stage]: | ||||||
transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten() # type: ignore | ||||||
else: | ||||||
transforms = base_transform | ||||||
# The tiling transform is randomized. Make them deterministic. This call needs to be | ||||||
|
@@ -325,7 +338,7 @@ def _get_dataloader(self, dataset: SlidesDataset, stage: ModelKey, shuffle: bool | |||||
return DataLoader( | ||||||
transformed_slides_dataset, | ||||||
batch_size=self.batch_size, | ||||||
collate_fn=image_collate, | ||||||
collate_fn=array_collate, | ||||||
shuffle=shuffle, | ||||||
generator=generator, | ||||||
**dataloader_kwargs, | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. | ||
# ------------------------------------------------------------------------------------------ | ||
import torch | ||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple | ||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union | ||
from pytorch_lightning.utilities.warnings import rank_zero_warn | ||
from pathlib import Path | ||
|
||
|
@@ -353,20 +353,82 @@ def get_bag_label(labels: Tensor) -> Tensor: | |
# SlidesDataModule attributes a single label to a bag of tiles already no need to do majority voting | ||
return labels | ||
|
||
@staticmethod | ||
def get_empty_lists(shape: int, n: int) -> List: | ||
ll = [] | ||
for _ in range(n): | ||
ll.append([None] * shape) | ||
return ll | ||
|
||
@staticmethod | ||
def get_patch_coordinate(slide_offset: List, patch_location: List[int], patch_size: List[int] | ||
) -> Tuple[int, int, int, int]: | ||
""" computing absolute patch coordinate """ | ||
# PATCH_LOCATION is expected to have shape [y, x] | ||
top = slide_offset[0] + patch_location[0] | ||
bottom = slide_offset[0] + patch_location[0] + patch_size[0] | ||
left = slide_offset[1] + patch_location[1] | ||
right = slide_offset[1] + patch_location[1] + patch_size[1] | ||
return top, bottom, left, right | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tuples of 4 integers are really error prone. Can we use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (maybe changing it to use top, bottom, left, right in the same washup?) |
||
|
||
@staticmethod | ||
def expand_slide_constant_metadata(id: str, path: str, n_patches: int, top: List[int], | ||
bottom: List[int], left: List[int], right: List[int]) -> Tuple[List, List, List]: | ||
"""Duplicate metadata that is patch invariant to match the shape of other arrays""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you expand the documentation a bit here? also "match the shape of other arrays" is not completely correct, it is matching the number given in |
||
slide_id = [id] * n_patches | ||
image_paths = [path] * n_patches | ||
tile_id = [f"{id}_left_{left[i]}_top_{top[i]}_right_{right[i]}_bottom_{bottom[i]}" for i in range(n_patches)] | ||
return slide_id, image_paths, tile_id | ||
|
||
def get_slide_patch_coordinates(self, slide_offset: List, patches_location: List, patch_size: List | ||
) -> Tuple[List, List, List, List]: | ||
""" computing absolute coordinates for all patches in a slide""" | ||
top, bottom, left, right = self.get_empty_lists(len(patches_location), 4) | ||
for i, location in enumerate(patches_location): | ||
top[i], bottom[i], left[i], right[i] = self.get_patch_coordinate(slide_offset, location, patch_size) | ||
return top, bottom, left, right | ||
|
||
def compute_slide_metadata(self, batch: Dict, index: int, metadata_dict: Dict) -> Dict: | ||
"""compute patch-dependent and patch-invariante metadata for a single slide """ | ||
offset = batch[SlideKey.OFFSET.value][index] | ||
patches_location = batch[SlideKey.TILE_LOCATION.value][index] | ||
patch_size = batch[SlideKey.TILE_SIZE.value][index] | ||
n_patches = len(patches_location) | ||
id = batch[SlideKey.SLIDE_ID][index] | ||
path = batch[SlideKey.IMAGE_PATH][index] | ||
|
||
top, bottom, left, right = self.get_slide_patch_coordinates(offset, patches_location, patch_size) | ||
slide_id, image_paths, tile_id = self.expand_slide_constant_metadata( | ||
id, path, n_patches, top, bottom, left, right | ||
) | ||
|
||
metadata_dict[ResultsKey.TILE_TOP] = top | ||
metadata_dict[ResultsKey.TILE_BOTTOM] = bottom | ||
metadata_dict[ResultsKey.TILE_LEFT] = left | ||
metadata_dict[ResultsKey.TILE_RIGHT] = right | ||
metadata_dict[ResultsKey.SLIDE_ID] = slide_id | ||
metadata_dict[ResultsKey.TILE_ID] = tile_id | ||
metadata_dict[ResultsKey.IMAGE_PATH] = image_paths | ||
return metadata_dict | ||
|
||
def update_results_with_data_specific_info(self, batch: Dict, results: Dict) -> None: | ||
# WARNING: This is a dummy input until we figure out tiles coordinates retrieval in the next iteration. | ||
bag_sizes = [tiles.shape[0] for tiles in batch[SlideKey.IMAGE]] | ||
results.update( | ||
{ | ||
ResultsKey.SLIDE_ID: [ | ||
[slide_id] * bag_sizes[i] for i, slide_id in enumerate(batch[SlideKey.SLIDE_ID]) | ||
], | ||
ResultsKey.TILE_ID: [ | ||
[f"{slide_id}_{tile_id}" for tile_id in range(bag_sizes[i])] | ||
for i, slide_id in enumerate(batch[SlideKey.SLIDE_ID]) | ||
], | ||
ResultsKey.IMAGE_PATH: [ | ||
[img_path] * bag_sizes[i] for i, img_path in enumerate(batch[SlideKey.IMAGE_PATH]) | ||
], | ||
if all(key.value in batch.keys() for key in [SlideKey.OFFSET, SlideKey.TILE_LOCATION, SlideKey.TILE_SIZE]): | ||
n_slides = len(batch[SlideKey.SLIDE_ID]) | ||
metadata_dict: Dict[str, List[Union[int, str]]] = { | ||
ResultsKey.TILE_TOP: [], | ||
ResultsKey.TILE_BOTTOM: [], | ||
ResultsKey.TILE_LEFT: [], | ||
ResultsKey.TILE_RIGHT: [], | ||
ResultsKey.SLIDE_ID: [], | ||
ResultsKey.TILE_ID: [], | ||
ResultsKey.IMAGE_PATH: [], | ||
} | ||
) | ||
results.update(metadata_dict) | ||
# each slide can have a different number of patches | ||
for i in range(n_slides): | ||
updated_metadata_dict = self.compute_slide_metadata(batch, i, metadata_dict) | ||
for key in metadata_dict.keys(): | ||
results[key].append(updated_metadata_dict[key]) | ||
else: | ||
rank_zero_warn(message="Offset, patch location or patch size are not found in the batch" | ||
"make sure to use RandGridPatch.") |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -72,6 +72,8 @@ def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]: | |||||
value = value.squeeze(0).cpu().numpy() | ||||||
if value.ndim == 0: | ||||||
value = np.full(bag_size, fill_value=value) | ||||||
if isinstance(value, List) and isinstance(value[0], torch.Tensor): | ||||||
value = [value[i].item() for i in range(len(value))] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
dict_new[key] = value | ||||||
elif key == ResultsKey.CLASS_PROBS: | ||||||
if isinstance(value, torch.Tensor): | ||||||
|
@@ -134,11 +136,17 @@ def save_outputs_csv(results: ResultsType, outputs_dir: Path) -> None: | |||||
|
||||||
# Collect the list of dictionaries in a list of pandas dataframe and save | ||||||
df_list = [] | ||||||
skipped_slides = 0 | ||||||
for slide_dict in list_slide_dicts: | ||||||
slide_dict = normalize_dict_for_df(slide_dict) # type: ignore | ||||||
df_list.append(pd.DataFrame.from_dict(slide_dict)) | ||||||
try: | ||||||
df_list.append(pd.DataFrame.from_dict(slide_dict)) | ||||||
except ValueError: | ||||||
skipped_slides += 1 | ||||||
logging.warning(f"something wrong in the dimension of slide {slide_dict[ResultsKey.SLIDE_ID][0]}") | ||||||
df = pd.concat(df_list, ignore_index=True) | ||||||
df.to_csv(csv_filename, mode='w+', header=True) | ||||||
logging.warning(f"{skipped_slides} slides have not been included in the ouputs because of issues with the outputs") | ||||||
|
||||||
|
||||||
def save_features(results: ResultsType, outputs_dir: Path) -> None: | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes to
primary_deps.yml
should be made viarequirements_run.txt