diff --git a/README.md b/README.md index 1baadc2..f3cd3f0 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,8 @@ Additional packages may be required to load some pretrained models. Follow error python run_batch_of_slides.py --task all --wsi_dir ./wsis --job_dir ./trident_processed --patch_encoder uni_v1 --mag 20 --patch_size 256 ``` +Add `--gpus 0 1` (or any list of device indices) to split slides across multiple GPUs automatically. + **Feeling cautious?** Run this command to perform all processing steps for a **single** slide: diff --git a/run_batch_of_slides.py b/run_batch_of_slides.py index 7c68e62..1ef519d 100644 --- a/run_batch_of_slides.py +++ b/run_batch_of_slides.py @@ -1,19 +1,21 @@ -""" -Example usage: - -``` -python run_batch_of_slides.py --task all --wsi_dir output/wsis --job_dir output --patch_encoder uni_v1 --mag 20 --patch_size 256 -``` - -""" import os +import math import argparse import torch -from typing import Any +import multiprocessing as mp +import shutil +from typing import Any, List +from queue import Queue +from threading import Thread +import warnings +from tqdm import tqdm +warnings.filterwarnings("ignore", category=FutureWarning) from trident import Processor from trident.patch_encoder_models import encoder_registry as patch_encoder_registry from trident.slide_encoder_models import encoder_registry as slide_encoder_registry +from trident.Concurrency import batch_producer, batch_consumer +from trident.IO import collect_valid_slides def build_parser() -> argparse.ArgumentParser: @@ -29,6 +31,8 @@ def build_parser() -> argparse.ArgumentParser: # Generic arguments parser.add_argument('--gpu', type=int, default=0, help='GPU index to use for processing tasks.') + parser.add_argument('--gpus', type=int, nargs='+', default=None, + help='Optional space-separated list of GPU indices to enable multi-GPU execution.') parser.add_argument('--task', type=str, default='seg', choices=['seg', 'coords', 'feat', 'all'], help='Task to run: seg (segmentation), coords (save tissue coordinates), img (save tissue images), feat (extract features).') @@ -165,6 +169,7 @@ def initialize_processor(args: argparse.Namespace) -> Processor: max_workers=args.max_workers, reader_type=args.reader_type, search_nested=args.search_nested, + selected_wsi_paths=getattr(args, 'selected_wsi_paths', None), ) @@ -180,6 +185,8 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: Parsed command-line arguments containing task configuration. """ + device = getattr(args, 'device', f'cuda:{args.gpu}') + if args.task == 'seg': from trident.segmentation_models.load import segmentation_model_factory @@ -203,7 +210,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: holes_are_tissue= not args.remove_holes, artifact_remover_model=artifact_remover_model, batch_size=args.seg_batch_size if args.seg_batch_size is not None else args.batch_size, - device=f'cuda:{args.gpu}', + device=device, ) elif args.task == 'coords': processor.run_patching_job( @@ -220,7 +227,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: processor.run_patch_feature_extraction_job( coords_dir=args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap', patch_encoder=encoder, - device=f'cuda:{args.gpu}', + device=device, saveas='h5', batch_limit=args.feat_batch_size if args.feat_batch_size is not None else args.batch_size, ) @@ -230,7 +237,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: processor.run_slide_feature_extraction_job( slide_encoder=encoder, coords_dir=args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap', - device=f'cuda:{args.gpu}', + device=device, saveas='h5', batch_limit=args.feat_batch_size if args.feat_batch_size is not None else args.batch_size, ) @@ -238,42 +245,96 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: raise ValueError(f'Invalid task: {args.task}') -def main() -> None: - """ - Main entry point for the Trident batch processing script. - - Handles both sequential and parallel processing modes based on whether - WSI caching is enabled. Supports segmentation, coordinate extraction, - and feature extraction tasks. - """ +def cleanup_files(job_dir: str, cache_dir: str = None) -> None: + if os.path.isdir(job_dir): + for root, _, files in os.walk(job_dir): + for f in files: + if f.endswith('.lock'): + try: os.remove(os.path.join(root, f)) + except OSError: pass + if cache_dir and os.path.isdir(cache_dir): + try: shutil.rmtree(cache_dir) + except OSError: pass + os.makedirs(cache_dir, exist_ok=True) - args = parse_arguments() - args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu' +def get_pending_slides(args: argparse.Namespace) -> List[str]: + all_slides = collect_valid_slides( + wsi_dir=args.wsi_dir, custom_list_path=args.custom_list_of_wsis, + wsi_ext=args.wsi_ext, search_nested=args.search_nested, max_workers=args.max_workers + ) + + tasks = ['seg', 'coords', 'feat'] if args.task == 'all' else [args.task] + coords_dir = args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap' + pending = [] + + def safe_listdir(path: str) -> List[str]: + try: + return os.listdir(path) + except (FileNotFoundError, NotADirectoryError): + return [] + + seg_done = set() + coords_done = set() + feat_done = set() + + if 'seg' in tasks: + contour_dir = os.path.join(args.job_dir, 'contours') + seg_done = {os.path.splitext(f)[0] for f in safe_listdir(contour_dir) if f.lower().endswith('.jpg')} + + if 'coords' in tasks: + patches_dir = os.path.join(args.job_dir, coords_dir, 'patches') + coords_done = { + f[:-len('_patches')] + for f in (os.path.splitext(fname)[0] for fname in safe_listdir(patches_dir) if fname.endswith('_patches.h5')) + } + + if 'feat' in tasks: + feat_sub = f'slide_features_{args.slide_encoder}' if args.slide_encoder else f'features_{args.patch_encoder}' + feat_dir = os.path.join(args.job_dir, coords_dir, feat_sub) + feat_done = {os.path.splitext(f)[0] for f in safe_listdir(feat_dir) if os.path.splitext(f)[1] in {'.h5', '.pt'}} + + for slide in tqdm(all_slides, desc="Checking slide status", unit="slide"): + stem = os.path.splitext(os.path.basename(slide))[0] + is_done = True + for t in tasks: + if t == 'seg' and stem not in seg_done: + is_done = False + elif t == 'coords' and stem not in coords_done: + is_done = False + elif t == 'feat' and stem not in feat_done: + is_done = False + if not is_done: break + + if not is_done: pending.append(slide) + + print(f"[MAIN] Found {len(all_slides)} slides. Processing {len(pending)} pending slides ({len(all_slides)-len(pending)} skipped).") + return pending + + +def worker_entrypoint(args: argparse.Namespace) -> None: + """ + Entry point for each GPU worker process. + Handles both cached (threading) and non-cached (sequential) execution modes + inside the isolated process. + """ if args.wsi_cache: - # === Parallel pipeline with caching === - - from queue import Queue - from threading import Thread - - from trident.Concurrency import batch_producer, batch_consumer, cache_batch - from trident.IO import collect_valid_slides - + # === Parallel pipeline with caching (Threaded inside Process) === + # Setup specific cache dir for this GPU process + gpu_cache_dir = os.path.join(args.wsi_cache, f"gpu_{args.gpu}") + os.makedirs(gpu_cache_dir, exist_ok=True) + + valid_slides = list(args.selected_wsi_paths or []) + if not valid_slides: + print(f"[WORKER {args.gpu}] No slides assigned. Skipping cached pipeline.") + return + + batch_size = max(1, args.cache_batch_size or len(valid_slides)) + queue = Queue(maxsize=1) - valid_slides = collect_valid_slides( - wsi_dir=args.wsi_dir, - custom_list_path=args.custom_list_of_wsis, - wsi_ext=args.wsi_ext, - search_nested=args.search_nested, - max_workers=args.max_workers - ) - print(f"[MAIN] Found {len(valid_slides)} valid slides in {args.wsi_dir}.") - - warm = valid_slides[:args.cache_batch_size] - warmup_dir = os.path.join(args.wsi_cache, "batch_0") - print(f"[MAIN] Warmup caching batch: {warmup_dir}") - cache_batch(warm, warmup_dir) - queue.put(0) + + # No pre-warming: let producer handle all batching + start_idx = 0 def processor_factory(wsi_dir: str) -> Processor: local_args = argparse.Namespace(**vars(args)) @@ -281,25 +342,29 @@ def processor_factory(wsi_dir: str) -> Processor: local_args.wsi_cache = None local_args.custom_list_of_wsis = None local_args.search_nested = False + local_args.selected_wsi_paths = None return initialize_processor(local_args) def run_task_fn(processor: Processor, task_name: str) -> None: - args.task = task_name - run_task(processor, args) + # We must use a local copy of args to update the task without affecting others + local_args = argparse.Namespace(**vars(args)) + local_args.task = task_name + local_args.selected_wsi_paths = None + run_task(processor, local_args) producer = Thread(target=batch_producer, args=( - queue, valid_slides, args.cache_batch_size, args.cache_batch_size, args.wsi_cache + queue, valid_slides, start_idx, batch_size, gpu_cache_dir )) consumer = Thread(target=batch_consumer, args=( - queue, args.task, args.wsi_cache, processor_factory, run_task_fn + queue, args.task, gpu_cache_dir, processor_factory, run_task_fn )) - print("[MAIN] Starting producer and consumer threads.") producer.start() consumer.start() producer.join() consumer.join() + else: # === Sequential mode === processor = initialize_processor(args) @@ -309,5 +374,54 @@ def run_task_fn(processor: Processor, task_name: str) -> None: run_task(processor, args) +def main() -> None: + """ + Main entry point for the Trident batch processing script. + + Handles both sequential and parallel processing modes based on whether + WSI caching is enabled. Supports segmentation, coordinate extraction, + and feature extraction tasks. + """ + + args = parse_arguments() + cleanup_files(args.job_dir, args.wsi_cache) + + if args.gpus: + gpu_ids = list(dict.fromkeys(args.gpus)) + else: + gpu_ids = [args.gpu] + + if not torch.cuda.is_available() and any(g >= 0 for g in gpu_ids): + print('[MAIN] Warning: CUDA not available, using CPU.') + gpu_ids = [-1] + + pending_slides = get_pending_slides(args) + if not pending_slides: + return + + num_shards = len(gpu_ids) + shards = [[] for _ in range(num_shards)] + for i, slide in enumerate(pending_slides): + shards[i % num_shards].append(slide) + + ctx = mp.get_context('spawn') if torch.cuda.is_available() else mp.get_context('fork') + processes = [] + + for i, gpu_id in enumerate(gpu_ids): + if not shards[i]: continue + + worker_args = argparse.Namespace(**vars(args)) + worker_args.gpu = gpu_id + worker_args.device = f'cuda:{gpu_id}' if gpu_id >= 0 else 'cpu' + worker_args.selected_wsi_paths = shards[i] + + p = ctx.Process(target=worker_entrypoint, args=(worker_args,)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/tests/test_multigpu.py b/tests/test_multigpu.py new file mode 100644 index 0000000..ede7118 --- /dev/null +++ b/tests/test_multigpu.py @@ -0,0 +1,153 @@ +import unittest +import sys +import os +import time +import shutil +import tempfile +from unittest.mock import patch +import torch + +import run_batch_of_slides as trident_run + +# ========================================== +# USER CONFIGURATION +# ========================================== +# Path to a directory containing WSI files +WSI_DIR = "/path/to/wsi/files" + +# GPU IDs to use (e.g., [0, 1]) +GPU_IDS = [0, 1] if torch.cuda.is_available() else [-1] + +# Trident task and model +TASK = "all" +PATCH_ENCODER = "conch_v15" + +# Inference and patch settings +BATCH_SIZE = 32 +MAGNIFICATION = 20 +PATCH_SIZE = 512 + +# Cache settings +CACHE_BATCH_SIZE = 2 + +# Misc +SCRIPT_NAME = "trident_run.py" +# ========================================== + + +class TestTridentProfiling(unittest.TestCase): + @classmethod + def setUpClass(cls): + if not os.path.exists(WSI_DIR): + raise FileNotFoundError( + f"Please update WSI_DIR in the test script. Path not found: {WSI_DIR}" + ) + + cls.task = TASK + cls.encoder = PATCH_ENCODER + + print(f"\n=== Starting Profiling on {len(GPU_IDS)} GPU(s) ===") + print(f"Source: {WSI_DIR}") + + def setUp(self): + # Fresh temp directories for each test + self.test_dir = tempfile.mkdtemp() + self.job_dir = os.path.join(self.test_dir, "job_output") + self.cache_dir = os.path.join(self.test_dir, "wsi_cache") + os.makedirs(self.job_dir, exist_ok=True) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def run_trident_scenario(self, run_name, gpus, use_cache=False): + """ + Helper to run the trident main function with specific arguments. + """ + print(f"\n[Running Scenario]: {run_name}") + + args = [ + SCRIPT_NAME, + "--wsi_dir", WSI_DIR, + "--job_dir", self.job_dir, + "--task", self.task, + "--patch_encoder", self.encoder, + "--batch_size", str(BATCH_SIZE), + "--mag", str(MAGNIFICATION), + "--patch_size", str(PATCH_SIZE), + "--gpus", + ] + [str(g) for g in gpus] + + if use_cache: + args.extend([ + "--wsi_cache", self.cache_dir, + "--cache_batch_size", str(CACHE_BATCH_SIZE), + ]) + + with patch.object(sys, "argv", args): + start_time = time.time() + try: + trident_run.main() + except SystemExit as e: + if e.code != 0: + self.fail(f"Trident exited with error code {e.code}") + except Exception as e: + self.fail(f"Trident crashed: {e}") + end_time = time.time() + + # Remove job output directory after each run + if os.path.exists(self.job_dir): + shutil.rmtree(self.job_dir) + os.makedirs(self.job_dir, exist_ok=True) + + duration = end_time - start_time + print(f"[Completed]: {run_name} in {duration:.2f} seconds") + return duration + + def test_benchmark_scenarios(self): + """ + Run 4 scenarios (Single/Multi GPU × Cache/NoCache) and print a comparison. + """ + results = {} + + # 1. Single GPU - No Cache + results["1GPU_NoCache"] = self.run_trident_scenario( + "Single GPU | Direct Read", + gpus=[GPU_IDS[0]], + use_cache=False, + ) + + # 2. Single GPU - With Cache + results["1GPU_Cache"] = self.run_trident_scenario( + "Single GPU | Cached Read", + gpus=[GPU_IDS[0]], + use_cache=True, + ) + + # 3–4. Multi-GPU only if >1 GPU configured + if len(GPU_IDS) > 1: + results["MultiGPU_NoCache"] = self.run_trident_scenario( + f"{len(GPU_IDS)} GPUs | Direct Read", + gpus=GPU_IDS, + use_cache=False, + ) + + results["MultiGPU_Cache"] = self.run_trident_scenario( + f"{len(GPU_IDS)} GPUs | Cached Read", + gpus=GPU_IDS, + use_cache=True, + ) + + # Report + print("\n" + "=" * 40) + print(f"{'SCENARIO':<25} | {'TIME (s)':<10}") + print("-" * 40) + for name, duration in results.items(): + print(f"{name:<25} | {duration:<10.2f}") + print("=" * 40) + + self.assertTrue(all(t > 0 for t in results.values())) + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn", force=True) + unittest.main() \ No newline at end of file diff --git a/trident/IO.py b/trident/IO.py index 0697a14..bf92d09 100644 --- a/trident/IO.py +++ b/trident/IO.py @@ -910,6 +910,9 @@ def get_num_workers(batch_size: int, if os.name == 'nt': return 0 + if max_workers is not None and max_workers <= 0: + return 0 + num_cores = os.cpu_count() or fallback num_workers = int(factor * num_cores) # Use a fraction of available cores max_workers = max_workers or (2 * batch_size) # Optional cap diff --git a/trident/Processor.py b/trident/Processor.py index 94c68b7..a943093 100644 --- a/trident/Processor.py +++ b/trident/Processor.py @@ -28,6 +28,7 @@ def __init__( max_workers: Optional[int] = None, reader_type: Optional[WSIReaderType] = None, search_nested: bool = False, + selected_wsi_paths: Optional[List[str]] = None, ) -> None: """ The `Processor` class handles all preprocessing steps starting from whole-slide images (WSIs). @@ -82,6 +83,9 @@ def __init__( the filename (excluding directory structure) will be used for downstream outputs (e.g., segmentation filenames). If False, only files directly inside `wsi_source` will be considered. Defaults to False. + selected_wsi_paths (List[str], optional): + Optional explicit list of absolute slide paths to process. When provided, `collect_valid_slides` + is skipped and only the supplied slides are used (useful for distributed processing). Defaults to None. Returns: @@ -117,15 +121,19 @@ def __init__( for ext in self.wsi_ext: assert ext.startswith('.'), f'Invalid extension: {ext} (must start with a period)' - # === Collect slide paths and relative paths === - full_paths, rel_paths = collect_valid_slides( - wsi_dir=wsi_source, - custom_list_path=custom_list_of_wsis, - wsi_ext=self.wsi_ext, - search_nested=search_nested, - max_workers=max_workers, - return_relative_paths=True - ) + # Collect slide paths and relative paths + if selected_wsi_paths is not None: + full_paths = selected_wsi_paths + rel_paths = [os.path.relpath(path, wsi_source) for path in selected_wsi_paths] + else: + full_paths, rel_paths = collect_valid_slides( + wsi_dir=wsi_source, + custom_list_path=custom_list_of_wsis, + wsi_ext=self.wsi_ext, + search_nested=search_nested, + max_workers=max_workers, + return_relative_paths=True + ) self.wsi_rel_paths = rel_paths if custom_list_of_wsis else None diff --git a/trident/slide_encoder_models/load.py b/trident/slide_encoder_models/load.py index e93db80..ca5f0f7 100644 --- a/trident/slide_encoder_models/load.py +++ b/trident/slide_encoder_models/load.py @@ -330,14 +330,22 @@ def _build(self, pretrained=True): raise Exception("Please install fairscale and gigapath using `pip install fairscale git+https://github.com/prov-gigapath/prov-gigapath.git`.") # Make sure flash_attn is correct version - try: - import flash_attn; assert flash_attn.__version__ == '2.5.8' - except: - traceback.print_exc() - raise Exception("Please install flash_attn version 2.5.8 using `pip install flash_attn==2.5.8`.") + # try: + # import flash_attn; assert flash_attn.__version__ == '2.5.8' + # except: + # traceback.print_exc() + # raise Exception("Please install flash_attn version 2.5.8 using `pip install flash_attn==2.5.8`.") if pretrained: - model = create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=True) + # Try to get local weights path first + weights_path = get_weights_path('slide', self.enc_name) + if weights_path: + print(f"Loading GigaPath slide encoder from local path: {weights_path}") + model = create_model(weights_path, "gigapath_slide_enc12l768d", 1536, global_pool=True) + else: + # Fallback to downloading from Hugging Face Hub + print("Local weights not found. Downloading from Hugging Face Hub...") + model = create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=True) else: model = create_model("", "gigapath_slide_enc12l768d", 1536, global_pool=True) diff --git a/trident/wsi_objects/WSI.py b/trident/wsi_objects/WSI.py index 6cdeb22..2404c17 100644 --- a/trident/wsi_objects/WSI.py +++ b/trident/wsi_objects/WSI.py @@ -2,6 +2,7 @@ import numpy as np import os import warnings +import multiprocessing as mp import torch from typing import List, Tuple, Optional, Literal, Union from torch.utils.data import DataLoader @@ -17,6 +18,11 @@ ReadMode = Literal['pil', 'numpy'] +try: + _DATALOADER_MP_CTX = mp.get_context('fork') if 'fork' in mp.get_all_start_methods() else None +except (ValueError, AttributeError): + _DATALOADER_MP_CTX = None + class WSI: """ @@ -313,11 +319,15 @@ def _segment_semantic( precision = segmentation_model.precision eval_transforms = segmentation_model.eval_transforms dataset = WSIPatcherDataset(patcher, eval_transforms) + inferred_workers = get_num_workers(batch_size, max_workers=self.max_workers) if num_workers is None else num_workers + dataloader_ctx = _DATALOADER_MP_CTX if inferred_workers and inferred_workers > 0 else None + dataloader = DataLoader( dataset, batch_size=batch_size, collate_fn=collate_fn, - num_workers=get_num_workers(batch_size, max_workers=self.max_workers) if num_workers is None else num_workers, + num_workers=inferred_workers, + multiprocessing_context=dataloader_ctx, pin_memory=True ) @@ -815,10 +825,17 @@ def extract_patch_features( coords_only=False, pil=True, ) - - dataset = WSIPatcherDataset(patcher, patch_transforms) - dataloader = DataLoader(dataset, batch_size=batch_limit, num_workers=get_num_workers(batch_limit, max_workers=self.max_workers), pin_memory=False) + inferred_workers = get_num_workers(batch_limit, max_workers=self.max_workers) + dataloader_ctx = _DATALOADER_MP_CTX if inferred_workers and inferred_workers > 0 else None + + dataloader = DataLoader( + dataset, + batch_size=batch_limit, + num_workers=inferred_workers, + pin_memory=False, + multiprocessing_context=dataloader_ctx, + ) dataloader = tqdm(dataloader) if verbose else dataloader @@ -826,7 +843,7 @@ def extract_patch_features( for imgs, _ in dataloader: imgs = imgs.to(device) with torch.autocast(device_type='cuda', dtype=precision, enabled=(precision != torch.float32)): - batch_features = patch_encoder(imgs) + batch_features = patch_encoder(imgs) features.append(batch_features.cpu().numpy()) # Concatenate features