diff --git a/README.md b/README.md index 8b52eb8..fd22374 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,13 @@ This project was developed by the [Mahmood Lab](https://faisal.ai/) at Harvard M - **Tissue Segmentation**: Extract tissue from background (supports H&E, IHC, penmark and artifact removal, etc.). - **Patch Extraction**: Extract tissue patches of any size and magnification. -- **Patch Feature Extraction**: Extract patch embeddings using one of 20 foundation models, including [UNI](https://www.nature.com/articles/s41591-024-02857-3), [Virchow](https://www.nature.com/articles/s41591-024-03141-0), [H-Optimus-0](https://github.com/bioptimus/releases/tree/main/models/h-optimus/v0) and more... +- **Patch Feature Extraction**: Extract patch embeddings using one of 21 foundation models, including [UNI](https://www.nature.com/articles/s41591-024-02857-3), [Virchow](https://www.nature.com/articles/s41591-024-03141-0), [H-Optimus-0](https://github.com/bioptimus/releases/tree/main/models/h-optimus/v0) and more... - **Slide Feature Extraction**: Extract slide embeddings using one of 6 slide foundation models, including [Threads](https://arxiv.org/abs/2501.16652) (coming soon!), [Titan](https://arxiv.org/abs/2411.19666), and [GigaPath](https://www.nature.com/articles/s41586-024-07441-w). ### Updates: - 02.25: New image converter from `czi`, `png`, etc to `tiff`. - 02.25: Support for [GrandQC](https://www.nature.com/articles/s41467-024-54769-y) tissue vs. background segmentation. -- 02.25: Support for [Madeleine](https://github.com/mahmoodlab/MADELEINE/tree/main), [Hibou](https://github.com/HistAI/hibou), [Lunit](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino) and [Kaiko](https://huggingface.co/histai/hibou-L) models. +- 02.25: Support for [Madeleine](https://github.com/mahmoodlab/MADELEINE/tree/main), [Hibou](https://github.com/HistAI/hibou), [Lunit](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino), [Kaiko](https://huggingface.co/histai/hibou-L), and [H-Optimus1](https://huggingface.co/bioptimus/H-optimus-1) models. ### 🔨 1. **Installation**: - Create an environment: `conda create -n "trident" python=3.10`, and activate it `conda activate trident`. @@ -105,6 +105,7 @@ Trident supports 20 patch encoders, loaded via a patch-level [`encoder_factory`] - **Phikon-v2**: [owkin/phikon-v2](https://huggingface.co/owkin/phikon-v2/) (`--patch_encoder phikon_v2`) - **Prov-Gigapath**: [prov-gigapath](https://huggingface.co/prov-gigapath/prov-gigapath) (`--patch_encoder gigapath`) - **H-Optimus-0**: [bioptimus/H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) (`--patch_encoder hoptimus0`) +- **H-Optimus-1**: [bioptimus/H-optimus-1](https://huggingface.co/bioptimus/H-optimus-1) (`--patch_encoder hoptimus1`) - **MUSK**: [xiangjx/musk](https://huggingface.co/xiangjx/musk) (`--patch_encoder musk`) - **Kaiko**: Hosted on TorchHub (`--patch_encoder kaiko-vits8, kaiko-vits16, kaiko-vitb8, kaiko-vitb16, kaiko-vitl14`) - **Lunit**: [1aurent/vit_small_patch8_224.lunit_dino](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino) (`--patch_encoder lunit-vits8`) diff --git a/trident/wsi_objects/WSI.py b/trident/wsi_objects/WSI.py index 79186a2..ca42396 100644 --- a/trident/wsi_objects/WSI.py +++ b/trident/wsi_objects/WSI.py @@ -5,7 +5,7 @@ import os import warnings import torch -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Union from torch.utils.data import DataLoader from trident.wsi_objects.WSIPatcher import * @@ -156,6 +156,7 @@ def get_dimensions(self) -> Tuple[int, int]: >>> wsi.get_dimensions() (100000, 80000) """ + self._lazy_initialize() # ensure WSI has been init return self.img.dimensions def read_region( @@ -187,6 +188,7 @@ def read_region( >>> print(region.shape) (512, 512, 3) """ + self._lazy_initialize() # ensure WSI has been init return np.array(self.read_region_pil(location, level, size)) def read_region_pil( @@ -217,6 +219,7 @@ def read_region_pil( >>> region = wsi.read_region_pil((0, 0), level=0, size=(512, 512)) >>> region.show() """ + self._lazy_initialize() # ensure WSI has been init return self.img.read_region(location, level, size).convert('RGB') def get_thumbnail(self, size: tuple[int, int]) -> Image.Image: @@ -238,6 +241,7 @@ def get_thumbnail(self, size: tuple[int, int]) -> Image.Image: >>> thumbnail = wsi.get_thumbnail((256, 256)) >>> thumbnail.show() """ + self._lazy_initialize() # ensure WSI has been init return self.img.get_thumbnail(size).convert('RGB') def create_patcher( @@ -302,6 +306,7 @@ def _fetch_mpp(self, custom_mpp_keys: List[str] | None = None) -> float | None: If the MPP value is unavailable, this method attempts to calculate it from the slide's resolution metadata. """ + self._lazy_initialize() # ensure WSI has been init mpp_x = None mpp_keys = [ openslide.PROPERTY_NAME_MPP_X, @@ -368,6 +373,7 @@ def _fetch_magnification(self, custom_mpp_keys: List[str] | None = None) -> int >>> print(mag) 40 """ + self._lazy_initialize() # ensure WSI has been init try: if self.mpp is None: mpp_x = self._fetch_mpp(custom_mpp_keys) @@ -412,7 +418,7 @@ def segment_tissue( holes_are_tissue: bool = True, job_dir: str | None = None, batch_size: int = 16, - ) -> str: + ) -> Union[str, gpd.GeoDataFrame]: """ The `segment_tissue` function of the class `OpenSlideWSI` segments tissue regions in the WSI using a specified segmentation model. It processes the WSI at a target magnification level, optionally @@ -433,8 +439,9 @@ def segment_tissue( Returns: -------- - str: - The absolute path to where the segmentation as GeoJSON is saved. + Union[str, geopandas.GeoDataFrame]: + If `job_dir` is provided, returns the absolute path to where the segmentation is saved as a GeoJSON. + If `job_dir` is None, returns a geodataframe where each entry is a contour tissue. Example: -------- @@ -442,7 +449,7 @@ def segment_tissue( >>> # Results saved in "output_dir" """ - self._lazy_initialize() + self._lazy_initialize() # ensure WSI has been init max_dimension = 1000 if self.width > self.height: thumbnail_width = max_dimension @@ -496,31 +503,35 @@ def segment_tissue( for hole in holes: cv2.drawContours(predicted_mask, [hole], 0, 255, -1) - # Save thumbnail image - thumbnail_saveto = os.path.join(job_dir, 'thumbnails', f'{self.name}.jpg') - os.makedirs(os.path.dirname(thumbnail_saveto), exist_ok=True) - thumbnail.save(thumbnail_saveto) - - # Save geopandas contours - gdf_saveto = os.path.join(job_dir, 'contours_geojson', f'{self.name}.geojson') - os.makedirs(os.path.dirname(gdf_saveto), exist_ok=True) + # Create geopandas contours gdf_contours = mask_to_gdf( mask=predicted_mask, max_nb_holes=0 if holes_are_tissue else 5, min_contour_area=1000, pixel_size=self.mpp, contour_scale=1/mpp_reduction_factor - ) - gdf_contours.set_crs("EPSG:3857", inplace=True) # used to silent warning // Web Mercator + ).set_crs("EPSG:3857", inplace=True) # used to silent warning // Web Mercator + + # return contours if no save path is provided. + if job_dir is None: + return gdf_contours + + gdf_saveto = os.path.join(job_dir, 'contours_geojson', f'{self.name}.geojson') + os.makedirs(os.path.dirname(gdf_saveto), exist_ok=True) gdf_contours.to_file(gdf_saveto, driver="GeoJSON") self.gdf_contours = gdf_contours self.tissue_seg_path = gdf_saveto - # Draw the contours on the thumbnail image + # Save thumbnail image with contours drawn contours_saveto = os.path.join(job_dir, 'contours', f'{self.name}.jpg') annotated = np.array(thumbnail) overlay_gdf_on_thumbnail(gdf_contours, annotated, contours_saveto, thumbnail_width / self.width) + # Save thumbnail image + thumbnail_saveto = os.path.join(job_dir, 'thumbnails', f'{self.name}.jpg') + os.makedirs(os.path.dirname(thumbnail_saveto), exist_ok=True) + thumbnail.save(thumbnail_saveto) + return gdf_saveto def get_best_level_and_custom_downsample( @@ -556,6 +567,7 @@ def get_best_level_and_custom_downsample( >>> print(level, custom_downsample) 2, 1.1 """ + self._lazy_initialize() # ensure WSI has been init level_downsamples = self.level_downsamples # First, check for an exact match within tolerance @@ -684,6 +696,7 @@ def visualize_coords(self, coords_path: str, save_patch_viz: str) -> str: >>> print(viz_path) output_viz/sample_name.png """ + self._lazy_initialize() # ensure WSI has been init try: coords_attrs, coords = read_coords(coords_path) # Coords are ALWAYS wrt. level 0 of the slide. @@ -762,11 +775,11 @@ def extract_patch_features( self, patch_encoder: torch.nn.Module, coords_path: str, - save_features: str, + save_features: str = None, device: str = 'cuda:0', saveas: str = 'h5', batch_limit: int = 512 - ) -> str: + ) -> Union[str, np.array]: """ The `extract_features` function of the class `OpenSlideWSI` extracts feature embeddings from the WSI using a specified patch encoder. It processes the patches as specified @@ -778,8 +791,8 @@ def extract_patch_features( The model used for feature extraction. coords_path : str Path to the file containing patch coordinates. - save_features : str - Directory path to save the extracted features. + save_features : str, optional + Directory path to save the extracted features. Defaults to None. device : str, optional Device to run feature extraction on (e.g., 'cuda:0'). Defaults to 'cuda:0'. saveas : str, optional @@ -789,8 +802,9 @@ def extract_patch_features( Returns: -------- - str: - The absolute file path to the saved feature file in the specified format. + Union[str, np.array]: + If `save_features` provides a save path, returns the absolute file path to the saved patch feature file in the specified format. + If `save_features` is None, returns a numpy array of dimensions [#patches x patch_emb_dim] with the patch embeddings. Example: -------- @@ -837,6 +851,10 @@ def extract_patch_features( # Concatenate features features = np.concatenate(features, axis=0) + # Return features as numpy if no where to save them + if save_features is None: + return features + # Save the features to disk os.makedirs(save_features, exist_ok=True) if saveas == 'h5': @@ -862,9 +880,9 @@ def extract_slide_features( self, patch_features_path: str, slide_encoder: torch.nn.Module, - save_features: str, + save_features: str = None, device: str = 'cuda', - ) -> str: + ) -> Union[str, np.array]: """ Extract slide-level features by encoding patch-level features using a pretrained slide encoder. @@ -875,11 +893,14 @@ def extract_slide_features( Args: patch_features_path (str): Path to the HDF5 file containing patch-level features and coordinates. slide_encoder (torch.nn.Module): Pretrained slide encoder model for generating slide-level features. - save_features (str): Directory where the extracted slide features will be saved. + save_features (str, optional): Directory where the extracted slide features will be saved. Defaults to None. device (str, optional): Device to run computations on (e.g., 'cuda', 'cpu'). Defaults to 'cuda'. Returns: - str: The absolute path to the slide-level features. + -------- + Union[str, np.array]: + If `save_features` provides a save path, returns the absolute file path to the saved slide feature file as h5. + If `save_features` is None, returns a numpy array of dimensions [slide_emb_dim] with the slide embeddings. Workflow: 1. Load the pretrained slide encoder model and set it to evaluation mode. @@ -938,6 +959,10 @@ def extract_slide_features( features = slide_encoder(batch, device) features = features.float().cpu().numpy().squeeze() + # Return slide embeddings if no save_features directory is specified: + if save_features is None: + return features + # Save slide-level features if save path is provided os.makedirs(save_features, exist_ok=True) save_path = os.path.join(save_features, f'{self.name}.h5') diff --git a/trident/wsi_objects/WSIPatcher.py b/trident/wsi_objects/WSIPatcher.py index 02e7224..45fbf31 100644 --- a/trident/wsi_objects/WSIPatcher.py +++ b/trident/wsi_objects/WSIPatcher.py @@ -31,7 +31,7 @@ def __init__( patch_size (int): patch width/height in pixel on the slide after rescaling src_pixel_size (float, optional): pixel size in um/px of the slide before rescaling. Defaults to None. dst_pixel_size (float, optional): pixel size in um/px of the slide after rescaling. Defaults to None. - src_mag (int, optional): level0 magnification of the slide before rescaling. Defaults to None. + src_mag (int, optional): level0 magnification of the slide before rescaling. Defaults to None. dst_mag (int, optional): target magnification of the slide after rescaling. Defaults to None. overlap (int, optional): Overlap between patches in pixels. Defaults to 0. mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None.