Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ This project was developed by the [Mahmood Lab](https://faisal.ai/) at Harvard M

- **Tissue Segmentation**: Extract tissue from background (supports H&E, IHC, penmark and artifact removal, etc.).
- **Patch Extraction**: Extract tissue patches of any size and magnification.
- **Patch Feature Extraction**: Extract patch embeddings using one of 20 foundation models, including [UNI](https://www.nature.com/articles/s41591-024-02857-3), [Virchow](https://www.nature.com/articles/s41591-024-03141-0), [H-Optimus-0](https://github.com/bioptimus/releases/tree/main/models/h-optimus/v0) and more...
- **Patch Feature Extraction**: Extract patch embeddings using one of 21 foundation models, including [UNI](https://www.nature.com/articles/s41591-024-02857-3), [Virchow](https://www.nature.com/articles/s41591-024-03141-0), [H-Optimus-0](https://github.com/bioptimus/releases/tree/main/models/h-optimus/v0) and more...
- **Slide Feature Extraction**: Extract slide embeddings using one of 6 slide foundation models, including [Threads](https://arxiv.org/abs/2501.16652) (coming soon!), [Titan](https://arxiv.org/abs/2411.19666), and [GigaPath](https://www.nature.com/articles/s41586-024-07441-w).

### Updates:
- 02.25: New image converter from `czi`, `png`, etc. to `tiff`.
- 02.25: Support for [GrandQC](https://www.nature.com/articles/s41467-024-54769-y) tissue vs. background segmentation.
- 02.25: Support for [Madeleine](https://github.com/mahmoodlab/MADELEINE/tree/main), [Hibou](https://github.com/HistAI/hibou), [Lunit](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino) and [Kaiko](https://huggingface.co/histai/hibou-L) models.
- 02.25: Support for [Madeleine](https://github.com/mahmoodlab/MADELEINE/tree/main), [Hibou](https://github.com/HistAI/hibou), [Lunit](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino), [Kaiko](https://huggingface.co/histai/hibou-L), and [H-Optimus-1](https://huggingface.co/bioptimus/H-optimus-1) models.

### 🔨 1. **Installation**:
- Create an environment: `conda create -n "trident" python=3.10`, and activate it `conda activate trident`.
Expand Down Expand Up @@ -105,6 +105,7 @@ Trident supports 20 patch encoders, loaded via a patch-level [`encoder_factory`]
- **Phikon-v2**: [owkin/phikon-v2](https://huggingface.co/owkin/phikon-v2/) (`--patch_encoder phikon_v2`)
- **Prov-Gigapath**: [prov-gigapath](https://huggingface.co/prov-gigapath/prov-gigapath) (`--patch_encoder gigapath`)
- **H-Optimus-0**: [bioptimus/H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) (`--patch_encoder hoptimus0`)
- **H-Optimus-1**: [bioptimus/H-optimus-1](https://huggingface.co/bioptimus/H-optimus-1) (`--patch_encoder hoptimus1`)
- **MUSK**: [xiangjx/musk](https://huggingface.co/xiangjx/musk) (`--patch_encoder musk`)
- **Kaiko**: Hosted on TorchHub (`--patch_encoder kaiko-vits8, kaiko-vits16, kaiko-vitb8, kaiko-vitb16, kaiko-vitl14`)
- **Lunit**: [1aurent/vit_small_patch8_224.lunit_dino](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino) (`--patch_encoder lunit-vits8`)
Expand Down
77 changes: 51 additions & 26 deletions trident/wsi_objects/WSI.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import warnings
import torch
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Union
from torch.utils.data import DataLoader

from trident.wsi_objects.WSIPatcher import *
Expand Down Expand Up @@ -156,6 +156,7 @@ def get_dimensions(self) -> Tuple[int, int]:
>>> wsi.get_dimensions()
(100000, 80000)
"""
self._lazy_initialize() # ensure WSI has been init
return self.img.dimensions

def read_region(
Expand Down Expand Up @@ -187,6 +188,7 @@ def read_region(
>>> print(region.shape)
(512, 512, 3)
"""
self._lazy_initialize() # ensure WSI has been init
return np.array(self.read_region_pil(location, level, size))

def read_region_pil(
Expand Down Expand Up @@ -217,6 +219,7 @@ def read_region_pil(
>>> region = wsi.read_region_pil((0, 0), level=0, size=(512, 512))
>>> region.show()
"""
self._lazy_initialize() # ensure WSI has been init
return self.img.read_region(location, level, size).convert('RGB')

def get_thumbnail(self, size: tuple[int, int]) -> Image.Image:
Expand All @@ -238,6 +241,7 @@ def get_thumbnail(self, size: tuple[int, int]) -> Image.Image:
>>> thumbnail = wsi.get_thumbnail((256, 256))
>>> thumbnail.show()
"""
self._lazy_initialize() # ensure WSI has been init
return self.img.get_thumbnail(size).convert('RGB')

def create_patcher(
Expand Down Expand Up @@ -302,6 +306,7 @@ def _fetch_mpp(self, custom_mpp_keys: List[str] | None = None) -> float | None:
If the MPP value is unavailable, this method attempts to calculate it
from the slide's resolution metadata.
"""
self._lazy_initialize() # ensure WSI has been init
mpp_x = None
mpp_keys = [
openslide.PROPERTY_NAME_MPP_X,
Expand Down Expand Up @@ -368,6 +373,7 @@ def _fetch_magnification(self, custom_mpp_keys: List[str] | None = None) -> int
>>> print(mag)
40
"""
self._lazy_initialize() # ensure WSI has been init
try:
if self.mpp is None:
mpp_x = self._fetch_mpp(custom_mpp_keys)
Expand Down Expand Up @@ -412,7 +418,7 @@ def segment_tissue(
holes_are_tissue: bool = True,
job_dir: str | None = None,
batch_size: int = 16,
) -> str:
) -> Union[str, gpd.GeoDataFrame]:
"""
The `segment_tissue` function of the class `OpenSlideWSI` segments tissue regions in the WSI using
a specified segmentation model. It processes the WSI at a target magnification level, optionally
Expand All @@ -433,16 +439,17 @@ def segment_tissue(

Returns:
--------
str:
The absolute path to where the segmentation as GeoJSON is saved.
Union[str, geopandas.GeoDataFrame]:
If `job_dir` is provided, returns the absolute path to where the segmentation is saved as a GeoJSON.
If `job_dir` is None, returns a GeoDataFrame in which each entry is a tissue contour.

Example:
--------
>>> wsi.segment_tissue(segmentation_model, target_mag=10, job_dir="output_dir")
>>> # Results saved in "output_dir"
"""

self._lazy_initialize()
self._lazy_initialize() # ensure WSI has been init
max_dimension = 1000
if self.width > self.height:
thumbnail_width = max_dimension
Expand Down Expand Up @@ -496,31 +503,35 @@ def segment_tissue(
for hole in holes:
cv2.drawContours(predicted_mask, [hole], 0, 255, -1)

# Save thumbnail image
thumbnail_saveto = os.path.join(job_dir, 'thumbnails', f'{self.name}.jpg')
os.makedirs(os.path.dirname(thumbnail_saveto), exist_ok=True)
thumbnail.save(thumbnail_saveto)

# Save geopandas contours
gdf_saveto = os.path.join(job_dir, 'contours_geojson', f'{self.name}.geojson')
os.makedirs(os.path.dirname(gdf_saveto), exist_ok=True)
# Create geopandas contours
gdf_contours = mask_to_gdf(
mask=predicted_mask,
max_nb_holes=0 if holes_are_tissue else 5,
min_contour_area=1000,
pixel_size=self.mpp,
contour_scale=1/mpp_reduction_factor
)
gdf_contours.set_crs("EPSG:3857", inplace=True) # used to silent warning // Web Mercator
).set_crs("EPSG:3857", inplace=True) # used to silent warning // Web Mercator

# return contours if no save path is provided.
if job_dir is None:
return gdf_contours

gdf_saveto = os.path.join(job_dir, 'contours_geojson', f'{self.name}.geojson')
os.makedirs(os.path.dirname(gdf_saveto), exist_ok=True)
gdf_contours.to_file(gdf_saveto, driver="GeoJSON")
self.gdf_contours = gdf_contours
self.tissue_seg_path = gdf_saveto

# Draw the contours on the thumbnail image
# Save thumbnail image with contours drawn
contours_saveto = os.path.join(job_dir, 'contours', f'{self.name}.jpg')
annotated = np.array(thumbnail)
overlay_gdf_on_thumbnail(gdf_contours, annotated, contours_saveto, thumbnail_width / self.width)

# Save thumbnail image
thumbnail_saveto = os.path.join(job_dir, 'thumbnails', f'{self.name}.jpg')
os.makedirs(os.path.dirname(thumbnail_saveto), exist_ok=True)
thumbnail.save(thumbnail_saveto)

return gdf_saveto

def get_best_level_and_custom_downsample(
Expand Down Expand Up @@ -556,6 +567,7 @@ def get_best_level_and_custom_downsample(
>>> print(level, custom_downsample)
2, 1.1
"""
self._lazy_initialize() # ensure WSI has been init
level_downsamples = self.level_downsamples

# First, check for an exact match within tolerance
Expand Down Expand Up @@ -684,6 +696,7 @@ def visualize_coords(self, coords_path: str, save_patch_viz: str) -> str:
>>> print(viz_path)
output_viz/sample_name.png
"""
self._lazy_initialize() # ensure WSI has been init

try:
coords_attrs, coords = read_coords(coords_path) # Coords are ALWAYS wrt. level 0 of the slide.
Expand Down Expand Up @@ -762,11 +775,11 @@ def extract_patch_features(
self,
patch_encoder: torch.nn.Module,
coords_path: str,
save_features: str,
save_features: str = None,
device: str = 'cuda:0',
saveas: str = 'h5',
batch_limit: int = 512
) -> str:
) -> Union[str, np.array]:
"""
The `extract_patch_features` function of the class `OpenSlideWSI` extracts feature embeddings
from the WSI using a specified patch encoder. It processes the patches as specified
Expand All @@ -778,8 +791,8 @@ def extract_patch_features(
The model used for feature extraction.
coords_path : str
Path to the file containing patch coordinates.
save_features : str
Directory path to save the extracted features.
save_features : str, optional
Directory path to save the extracted features. Defaults to None.
device : str, optional
Device to run feature extraction on (e.g., 'cuda:0'). Defaults to 'cuda:0'.
saveas : str, optional
Expand All @@ -789,8 +802,9 @@ def extract_patch_features(

Returns:
--------
str:
The absolute file path to the saved feature file in the specified format.
Union[str, np.array]:
If `save_features` provides a save path, returns the absolute file path to the saved patch feature file in the specified format.
If `save_features` is None, returns a numpy array of dimensions [#patches x patch_emb_dim] with the patch embeddings.

Example:
--------
Expand Down Expand Up @@ -837,6 +851,10 @@ def extract_patch_features(
# Concatenate features
features = np.concatenate(features, axis=0)

# Return features as a numpy array if no save location is provided
if save_features is None:
return features

# Save the features to disk
os.makedirs(save_features, exist_ok=True)
if saveas == 'h5':
Expand All @@ -862,9 +880,9 @@ def extract_slide_features(
self,
patch_features_path: str,
slide_encoder: torch.nn.Module,
save_features: str,
save_features: str = None,
device: str = 'cuda',
) -> str:
) -> Union[str, np.array]:
"""
Extract slide-level features by encoding patch-level features using a pretrained slide encoder.

Expand All @@ -875,11 +893,14 @@ def extract_slide_features(
Args:
patch_features_path (str): Path to the HDF5 file containing patch-level features and coordinates.
slide_encoder (torch.nn.Module): Pretrained slide encoder model for generating slide-level features.
save_features (str): Directory where the extracted slide features will be saved.
save_features (str, optional): Directory where the extracted slide features will be saved. Defaults to None.
device (str, optional): Device to run computations on (e.g., 'cuda', 'cpu'). Defaults to 'cuda'.

Returns:
str: The absolute path to the slide-level features.
--------
Union[str, np.array]:
If `save_features` provides a save path, returns the absolute file path to the saved slide feature file as h5.
If `save_features` is None, returns a numpy array of dimensions [slide_emb_dim] with the slide embeddings.

Workflow:
1. Load the pretrained slide encoder model and set it to evaluation mode.
Expand Down Expand Up @@ -938,6 +959,10 @@ def extract_slide_features(
features = slide_encoder(batch, device)
features = features.float().cpu().numpy().squeeze()

# Return slide embeddings if no save_features directory is specified:
if save_features is None:
return features

# Save slide-level features if save path is provided
os.makedirs(save_features, exist_ok=True)
save_path = os.path.join(save_features, f'{self.name}.h5')
Expand Down
2 changes: 1 addition & 1 deletion trident/wsi_objects/WSIPatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
patch_size (int): patch width/height in pixel on the slide after rescaling
src_pixel_size (float, optional): pixel size in um/px of the slide before rescaling. Defaults to None.
dst_pixel_size (float, optional): pixel size in um/px of the slide after rescaling. Defaults to None.
src_mag (int, optional): level0 magnification of the slide before rescaling. Defaults to None.
src_mag (int, optional): level0 magnification of the slide before rescaling. Defaults to None.
dst_mag (int, optional): target magnification of the slide after rescaling. Defaults to None.
overlap (int, optional): Overlap between patches in pixels. Defaults to 0.
mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None.
Expand Down