From 98de6dbbbe3b9a0fe1fc68c094c1a4f45efca47d Mon Sep 17 00:00:00 2001 From: winglet0996 Date: Mon, 24 Nov 2025 03:19:43 +0800 Subject: [PATCH 1/7] multigpu support --- README.md | 2 + run_batch_of_slides.py | 173 +++++++++++++++++++++++++++---------- tests/test_multigpu.py | 153 ++++++++++++++++++++++++++++++++ trident/IO.py | 3 + trident/Processor.py | 26 ++++-- trident/wsi_objects/WSI.py | 27 ++++-- 6 files changed, 324 insertions(+), 60 deletions(-) create mode 100644 tests/test_multigpu.py diff --git a/README.md b/README.md index 1baadc2..f3cd3f0 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,8 @@ Additional packages may be required to load some pretrained models. Follow error python run_batch_of_slides.py --task all --wsi_dir ./wsis --job_dir ./trident_processed --patch_encoder uni_v1 --mag 20 --patch_size 256 ``` +Add `--gpus 0 1` (or any list of device indices) to split slides across multiple GPUs automatically. + **Feeling cautious?** Run this command to perform all processing steps for a **single** slide: diff --git a/run_batch_of_slides.py b/run_batch_of_slides.py index 7c68e62..1976c5e 100644 --- a/run_batch_of_slides.py +++ b/run_batch_of_slides.py @@ -1,19 +1,17 @@ -""" -Example usage: - -``` -python run_batch_of_slides.py --task all --wsi_dir output/wsis --job_dir output --patch_encoder uni_v1 --mag 20 --patch_size 256 -``` - -""" import os import argparse import torch -from typing import Any +import multiprocessing as mp +import shutil +from typing import Any, List +from queue import Queue +from threading import Thread from trident import Processor from trident.patch_encoder_models import encoder_registry as patch_encoder_registry from trident.slide_encoder_models import encoder_registry as slide_encoder_registry +from trident.Concurrency import batch_producer, batch_consumer, cache_batch +from trident.IO import collect_valid_slides def build_parser() -> argparse.ArgumentParser: @@ -29,6 +27,8 @@ def build_parser() -> argparse.ArgumentParser: # Generic arguments parser.add_argument('--gpu', type=int, default=0, help='GPU index to use for processing tasks.') + parser.add_argument('--gpus', type=int, nargs='+', default=None, + help='Optional space-separated list of GPU indices to enable multi-GPU execution.') parser.add_argument('--task', type=str, default='seg', choices=['seg', 'coords', 'feat', 'all'], help='Task to run: seg (segmentation), coords (save tissue coordinates), img (save tissue images), feat (extract features).') @@ -165,6 +165,7 @@ def initialize_processor(args: argparse.Namespace) -> Processor: max_workers=args.max_workers, reader_type=args.reader_type, search_nested=args.search_nested, + selected_wsi_paths=getattr(args, 'selected_wsi_paths', None), ) @@ -180,6 +181,8 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: Parsed command-line arguments containing task configuration. """ + device = getattr(args, 'device', f'cuda:{args.gpu}') + if args.task == 'seg': from trident.segmentation_models.load import segmentation_model_factory @@ -203,7 +206,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: holes_are_tissue= not args.remove_holes, artifact_remover_model=artifact_remover_model, batch_size=args.seg_batch_size if args.seg_batch_size is not None else args.batch_size, - device=f'cuda:{args.gpu}', + device=device, ) elif args.task == 'coords': processor.run_patching_job( @@ -220,7 +223,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: processor.run_patch_feature_extraction_job( coords_dir=args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap', patch_encoder=encoder, - device=f'cuda:{args.gpu}', + device=device, saveas='h5', batch_limit=args.feat_batch_size if args.feat_batch_size is not None else args.batch_size, ) @@ -230,7 +233,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: processor.run_slide_feature_extraction_job( slide_encoder=encoder, coords_dir=args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap', - device=f'cuda:{args.gpu}', + device=device, saveas='h5', batch_limit=args.feat_batch_size if args.feat_batch_size is not None else args.batch_size, ) @@ -238,40 +241,67 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None: raise ValueError(f'Invalid task: {args.task}') -def main() -> None: - """ - Main entry point for the Trident batch processing script. - - Handles both sequential and parallel processing modes based on whether - WSI caching is enabled. Supports segmentation, coordinate extraction, - and feature extraction tasks. - """ +def cleanup_files(job_dir: str, cache_dir: str = None) -> None: + if os.path.isdir(job_dir): + for root, _, files in os.walk(job_dir): + for f in files: + if f.endswith('.lock'): + try: os.remove(os.path.join(root, f)) + except OSError: pass + if cache_dir and os.path.isdir(cache_dir): + try: shutil.rmtree(cache_dir) + except OSError: pass + os.makedirs(cache_dir, exist_ok=True) - args = parse_arguments() - args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu' +def get_pending_slides(args: argparse.Namespace) -> List[str]: + all_slides = collect_valid_slides( + wsi_dir=args.wsi_dir, custom_list_path=args.custom_list_of_wsis, + wsi_ext=args.wsi_ext, search_nested=args.search_nested, max_workers=args.max_workers + ) + + tasks = ['seg', 'coords', 'feat'] if args.task == 'all' else [args.task] + coords_dir = args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap' + pending = [] + + for slide in all_slides: + stem = os.path.splitext(os.path.basename(slide))[0] + is_done = True + for t in tasks: + if t == 'seg' and not os.path.exists(os.path.join(args.job_dir, 'contours', f'{stem}.jpg')): + is_done = False + elif t == 'coords' and not os.path.exists(os.path.join(args.job_dir, coords_dir, 'patches', f'{stem}_patches.h5')): + is_done = False + elif t == 'feat': + feat_sub = f'slide_features_{args.slide_encoder}' if args.slide_encoder else f'features_{args.patch_encoder}' + feat_dir = os.path.join(args.job_dir, coords_dir, feat_sub) + if not (os.path.isdir(feat_dir) and any(os.path.exists(os.path.join(feat_dir, f'{stem}.{ext}')) for ext in ['h5', 'pt'])): + is_done = False + if not is_done: break + + if not is_done: pending.append(slide) + + print(f"[MAIN] Found {len(all_slides)} slides. Processing {len(pending)} pending slides ({len(all_slides)-len(pending)} skipped).") + return pending + + +def worker_entrypoint(args: argparse.Namespace) -> None: + """ + Entry point for each GPU worker process. + Handles both cached (threading) and non-cached (sequential) execution modes + inside the isolated process. + """ if args.wsi_cache: - # === Parallel pipeline with caching === - - from queue import Queue - from threading import Thread - - from trident.Concurrency import batch_producer, batch_consumer, cache_batch - from trident.IO import collect_valid_slides - + # === Parallel pipeline with caching (Threaded inside Process) === + # Setup specific cache dir for this GPU process + gpu_cache_dir = os.path.join(args.wsi_cache, f"gpu_{args.gpu}") + os.makedirs(gpu_cache_dir, exist_ok=True) + queue = Queue(maxsize=1) - valid_slides = collect_valid_slides( - wsi_dir=args.wsi_dir, - custom_list_path=args.custom_list_of_wsis, - wsi_ext=args.wsi_ext, - search_nested=args.search_nested, - max_workers=args.max_workers - ) - print(f"[MAIN] Found {len(valid_slides)} valid slides in {args.wsi_dir}.") + valid_slides = args.selected_wsi_paths warm = valid_slides[:args.cache_batch_size] - warmup_dir = os.path.join(args.wsi_cache, "batch_0") - print(f"[MAIN] Warmup caching batch: {warmup_dir}") + warmup_dir = os.path.join(gpu_cache_dir, "batch_0") cache_batch(warm, warmup_dir) queue.put(0) @@ -284,22 +314,24 @@ def processor_factory(wsi_dir: str) -> Processor: return initialize_processor(local_args) def run_task_fn(processor: Processor, task_name: str) -> None: - args.task = task_name - run_task(processor, args) + # We must use a local copy of args to update the task without affecting others + local_args = argparse.Namespace(**vars(args)) + local_args.task = task_name + run_task(processor, local_args) producer = Thread(target=batch_producer, args=( - queue, valid_slides, args.cache_batch_size, args.cache_batch_size, args.wsi_cache + queue, valid_slides, args.cache_batch_size, args.cache_batch_size, gpu_cache_dir )) consumer = Thread(target=batch_consumer, args=( - queue, args.task, args.wsi_cache, processor_factory, run_task_fn + queue, args.task, gpu_cache_dir, processor_factory, run_task_fn )) - print("[MAIN] Starting producer and consumer threads.") producer.start() consumer.start() producer.join() consumer.join() + else: # === Sequential mode === processor = initialize_processor(args) @@ -309,5 +341,54 @@ def run_task_fn(processor: Processor, task_name: str) -> None: run_task(processor, args) +def main() -> None: + """ + Main entry point for the Trident batch processing script. + + Handles both sequential and parallel processing modes based on whether + WSI caching is enabled. Supports segmentation, coordinate extraction, + and feature extraction tasks. + """ + + args = parse_arguments() + cleanup_files(args.job_dir, args.wsi_cache) + + if args.gpus: + gpu_ids = list(dict.fromkeys(args.gpus)) + else: + gpu_ids = [args.gpu] + + if not torch.cuda.is_available() and any(g >= 0 for g in gpu_ids): + print('[MAIN] Warning: CUDA not available, using CPU.') + gpu_ids = [-1] + + pending_slides = get_pending_slides(args) + if not pending_slides: + return + + num_shards = len(gpu_ids) + shards = [[] for _ in range(num_shards)] + for i, slide in enumerate(pending_slides): + shards[i % num_shards].append(slide) + + ctx = mp.get_context('spawn') if torch.cuda.is_available() else mp.get_context('fork') + processes = [] + + for i, gpu_id in enumerate(gpu_ids): + if not shards[i]: continue + + worker_args = argparse.Namespace(**vars(args)) + worker_args.gpu = gpu_id + worker_args.device = f'cuda:{gpu_id}' if gpu_id >= 0 else 'cpu' + worker_args.selected_wsi_paths = shards[i] + + p = ctx.Process(target=worker_entrypoint, args=(worker_args,)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/tests/test_multigpu.py b/tests/test_multigpu.py new file mode 100644 index 0000000..ede7118 --- /dev/null +++ b/tests/test_multigpu.py @@ -0,0 +1,153 @@ +import unittest +import sys +import os +import time +import shutil +import tempfile +from unittest.mock import patch +import torch + +import run_batch_of_slides as trident_run + +# ========================================== +# USER CONFIGURATION +# ========================================== +# Path to a directory containing WSI files +WSI_DIR = "/path/to/wsi/files" + +# GPU IDs to use (e.g., [0, 1]) +GPU_IDS = [0, 1] if torch.cuda.is_available() else [-1] + +# Trident task and model +TASK = "all" +PATCH_ENCODER = "conch_v15" + +# Inference and patch settings +BATCH_SIZE = 32 +MAGNIFICATION = 20 +PATCH_SIZE = 512 + +# Cache settings +CACHE_BATCH_SIZE = 2 + +# Misc +SCRIPT_NAME = "trident_run.py" +# ========================================== + + +class TestTridentProfiling(unittest.TestCase): + @classmethod + def setUpClass(cls): + if not os.path.exists(WSI_DIR): + raise FileNotFoundError( + f"Please update WSI_DIR in the test script. Path not found: {WSI_DIR}" + ) + + cls.task = TASK + cls.encoder = PATCH_ENCODER + + print(f"\n=== Starting Profiling on {len(GPU_IDS)} GPU(s) ===") + print(f"Source: {WSI_DIR}") + + def setUp(self): + # Fresh temp directories for each test + self.test_dir = tempfile.mkdtemp() + self.job_dir = os.path.join(self.test_dir, "job_output") + self.cache_dir = os.path.join(self.test_dir, "wsi_cache") + os.makedirs(self.job_dir, exist_ok=True) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def run_trident_scenario(self, run_name, gpus, use_cache=False): + """ + Helper to run the trident main function with specific arguments. + """ + print(f"\n[Running Scenario]: {run_name}") + + args = [ + SCRIPT_NAME, + "--wsi_dir", WSI_DIR, + "--job_dir", self.job_dir, + "--task", self.task, + "--patch_encoder", self.encoder, + "--batch_size", str(BATCH_SIZE), + "--mag", str(MAGNIFICATION), + "--patch_size", str(PATCH_SIZE), + "--gpus", + ] + [str(g) for g in gpus] + + if use_cache: + args.extend([ + "--wsi_cache", self.cache_dir, + "--cache_batch_size", str(CACHE_BATCH_SIZE), + ]) + + with patch.object(sys, "argv", args): + start_time = time.time() + try: + trident_run.main() + except SystemExit as e: + if e.code != 0: + self.fail(f"Trident exited with error code {e.code}") + except Exception as e: + self.fail(f"Trident crashed: {e}") + end_time = time.time() + + # Remove job output directory after each run + if os.path.exists(self.job_dir): + shutil.rmtree(self.job_dir) + os.makedirs(self.job_dir, exist_ok=True) + + duration = end_time - start_time + print(f"[Completed]: {run_name} in {duration:.2f} seconds") + return duration + + def test_benchmark_scenarios(self): + """ + Run 4 scenarios (Single/Multi GPU × Cache/NoCache) and print a comparison. + """ + results = {} + + # 1. Single GPU - No Cache + results["1GPU_NoCache"] = self.run_trident_scenario( + "Single GPU | Direct Read", + gpus=[GPU_IDS[0]], + use_cache=False, + ) + + # 2. Single GPU - With Cache + results["1GPU_Cache"] = self.run_trident_scenario( + "Single GPU | Cached Read", + gpus=[GPU_IDS[0]], + use_cache=True, + ) + + # 3–4. Multi-GPU only if >1 GPU configured + if len(GPU_IDS) > 1: + results["MultiGPU_NoCache"] = self.run_trident_scenario( + f"{len(GPU_IDS)} GPUs | Direct Read", + gpus=GPU_IDS, + use_cache=False, + ) + + results["MultiGPU_Cache"] = self.run_trident_scenario( + f"{len(GPU_IDS)} GPUs | Cached Read", + gpus=GPU_IDS, + use_cache=True, + ) + + # Report + print("\n" + "=" * 40) + print(f"{'SCENARIO':<25} | {'TIME (s)':<10}") + print("-" * 40) + for name, duration in results.items(): + print(f"{name:<25} | {duration:<10.2f}") + print("=" * 40) + + self.assertTrue(all(t > 0 for t in results.values())) + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn", force=True) + unittest.main() \ No newline at end of file diff --git a/trident/IO.py b/trident/IO.py index 0697a14..bf92d09 100644 --- a/trident/IO.py +++ b/trident/IO.py @@ -910,6 +910,9 @@ def get_num_workers(batch_size: int, if os.name == 'nt': return 0 + if max_workers is not None and max_workers <= 0: + return 0 + num_cores = os.cpu_count() or fallback num_workers = int(factor * num_cores) # Use a fraction of available cores max_workers = max_workers or (2 * batch_size) # Optional cap diff --git a/trident/Processor.py b/trident/Processor.py index 94c68b7..a943093 100644 --- a/trident/Processor.py +++ b/trident/Processor.py @@ -28,6 +28,7 @@ def __init__( max_workers: Optional[int] = None, reader_type: Optional[WSIReaderType] = None, search_nested: bool = False, + selected_wsi_paths: Optional[List[str]] = None, ) -> None: """ The `Processor` class handles all preprocessing steps starting from whole-slide images (WSIs). @@ -82,6 +83,9 @@ def __init__( the filename (excluding directory structure) will be used for downstream outputs (e.g., segmentation filenames). If False, only files directly inside `wsi_source` will be considered. Defaults to False. + selected_wsi_paths (List[str], optional): + Optional explicit list of absolute slide paths to process. When provided, `collect_valid_slides` + is skipped and only the supplied slides are used (useful for distributed processing). Defaults to None. Returns: @@ -117,15 +121,19 @@ def __init__( for ext in self.wsi_ext: assert ext.startswith('.'), f'Invalid extension: {ext} (must start with a period)' - # === Collect slide paths and relative paths === - full_paths, rel_paths = collect_valid_slides( - wsi_dir=wsi_source, - custom_list_path=custom_list_of_wsis, - wsi_ext=self.wsi_ext, - search_nested=search_nested, - max_workers=max_workers, - return_relative_paths=True - ) + # Collect slide paths and relative paths + if selected_wsi_paths is not None: + full_paths = selected_wsi_paths + rel_paths = [os.path.relpath(path, wsi_source) for path in selected_wsi_paths] + else: + full_paths, rel_paths = collect_valid_slides( + wsi_dir=wsi_source, + custom_list_path=custom_list_of_wsis, + wsi_ext=self.wsi_ext, + search_nested=search_nested, + max_workers=max_workers, + return_relative_paths=True + ) self.wsi_rel_paths = rel_paths if custom_list_of_wsis else None diff --git a/trident/wsi_objects/WSI.py b/trident/wsi_objects/WSI.py index 6cdeb22..2404c17 100644 --- a/trident/wsi_objects/WSI.py +++ b/trident/wsi_objects/WSI.py @@ -2,6 +2,7 @@ import numpy as np import os import warnings +import multiprocessing as mp import torch from typing import List, Tuple, Optional, Literal, Union from torch.utils.data import DataLoader @@ -17,6 +18,11 @@ ReadMode = Literal['pil', 'numpy'] +try: + _DATALOADER_MP_CTX = mp.get_context('fork') if 'fork' in mp.get_all_start_methods() else None +except (ValueError, AttributeError): + _DATALOADER_MP_CTX = None + class WSI: """ @@ -313,11 +319,15 @@ def _segment_semantic( precision = segmentation_model.precision eval_transforms = segmentation_model.eval_transforms dataset = WSIPatcherDataset(patcher, eval_transforms) + inferred_workers = get_num_workers(batch_size, max_workers=self.max_workers) if num_workers is None else num_workers + dataloader_ctx = _DATALOADER_MP_CTX if inferred_workers and inferred_workers > 0 else None + dataloader = DataLoader( dataset, batch_size=batch_size, collate_fn=collate_fn, - num_workers=get_num_workers(batch_size, max_workers=self.max_workers) if num_workers is None else num_workers, + num_workers=inferred_workers, + multiprocessing_context=dataloader_ctx, pin_memory=True ) @@ -815,10 +825,17 @@ def extract_patch_features( coords_only=False, pil=True, ) - - dataset = WSIPatcherDataset(patcher, patch_transforms) - dataloader = DataLoader(dataset, batch_size=batch_limit, num_workers=get_num_workers(batch_limit, max_workers=self.max_workers), pin_memory=False) + inferred_workers = get_num_workers(batch_limit, max_workers=self.max_workers) + dataloader_ctx = _DATALOADER_MP_CTX if inferred_workers and inferred_workers > 0 else None + + dataloader = DataLoader( + dataset, + batch_size=batch_limit, + num_workers=inferred_workers, + pin_memory=False, + multiprocessing_context=dataloader_ctx, + ) dataloader = tqdm(dataloader) if verbose else dataloader @@ -826,7 +843,7 @@ def extract_patch_features( for imgs, _ in dataloader: imgs = imgs.to(device) with torch.autocast(device_type='cuda', dtype=precision, enabled=(precision != torch.float32)): - batch_features = patch_encoder(imgs) + batch_features = patch_encoder(imgs) features.append(batch_features.cpu().numpy()) # Concatenate features From f279fc87522bc081f1434e654ff35e4781fae1d8 Mon Sep 17 00:00:00 2001 From: winglet0996 Date: Mon, 24 Nov 2025 05:24:01 +0800 Subject: [PATCH 2/7] log and caching logic minor fix --- run_batch_of_slides.py | 16 ++++++++-------- trident/Concurrency.py | 6 ++++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/run_batch_of_slides.py b/run_batch_of_slides.py index 1976c5e..355b553 100644 --- a/run_batch_of_slides.py +++ b/run_batch_of_slides.py @@ -6,6 +6,9 @@ from typing import Any, List from queue import Queue from threading import Thread +from tqdm import tqdm +import warnings +warnings.filterwarnings("ignore", category=FutureWarning) from trident import Processor from trident.patch_encoder_models import encoder_registry as patch_encoder_registry @@ -264,7 +267,8 @@ def get_pending_slides(args: argparse.Namespace) -> List[str]: coords_dir = args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap' pending = [] - for slide in all_slides: + print(f"[MAIN] Found {len(all_slides)} slides. Filtering completed slides...") + for slide in tqdm(all_slides, desc="Checking slides", unit="slide"): stem = os.path.splitext(os.path.basename(slide))[0] is_done = True for t in tasks: @@ -281,7 +285,8 @@ def get_pending_slides(args: argparse.Namespace) -> List[str]: if not is_done: pending.append(slide) - print(f"[MAIN] Found {len(all_slides)} slides. Processing {len(pending)} pending slides ({len(all_slides)-len(pending)} skipped).") + skipped = len(all_slides) - len(pending) + print(f"[MAIN] Skipped {skipped} completed slides. Processing {len(pending)} pending slides.") return pending @@ -300,11 +305,6 @@ def worker_entrypoint(args: argparse.Namespace) -> None: queue = Queue(maxsize=1) valid_slides = args.selected_wsi_paths - warm = valid_slides[:args.cache_batch_size] - warmup_dir = os.path.join(gpu_cache_dir, "batch_0") - cache_batch(warm, warmup_dir) - queue.put(0) - def processor_factory(wsi_dir: str) -> Processor: local_args = argparse.Namespace(**vars(args)) local_args.wsi_dir = wsi_dir @@ -320,7 +320,7 @@ def run_task_fn(processor: Processor, task_name: str) -> None: run_task(processor, local_args) producer = Thread(target=batch_producer, args=( - queue, valid_slides, args.cache_batch_size, args.cache_batch_size, gpu_cache_dir + queue, valid_slides, 0, args.cache_batch_size, gpu_cache_dir )) consumer = Thread(target=batch_consumer, args=( diff --git a/trident/Concurrency.py b/trident/Concurrency.py index 0746a72..a9cb709 100644 --- a/trident/Concurrency.py +++ b/trident/Concurrency.py @@ -69,7 +69,8 @@ def batch_producer( ssd_batch_dir = os.path.join(cache_dir, f"batch_{batch_id}") print(f"[PRODUCER] Caching batch {batch_id}: {ssd_batch_dir}") cache_batch(batch_paths, ssd_batch_dir) - queue.put(batch_id) + print(f"[PRODUCER] Batch {batch_id} cached and ready") + queue.put(batch_id) # Put will block if queue is full (maxsize=1), enabling pipeline queue.put(None) # Sentinel to signal completion @@ -125,4 +126,5 @@ def batch_consumer( print(f"[CONSUMER] Clearing cache for batch {batch_id}") shutil.rmtree(ssd_batch_dir, ignore_errors=True) - queue.task_done() + print(f"[CONSUMER] Batch {batch_id} completed and cache cleared") + queue.task_done() # Signal completion to producer From abe0db44da1bd51912528087c91bfa1e592f2b56 Mon Sep 17 00:00:00 2001 From: winglet0996 Date: Mon, 24 Nov 2025 05:50:05 +0800 Subject: [PATCH 3/7] caching mode fix --- run_batch_of_slides.py | 31 +++++++++++++++++++++++++++---- trident/Concurrency.py | 24 ++++++++++++++++-------- 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/run_batch_of_slides.py b/run_batch_of_slides.py index 355b553..10afe5b 100644 --- a/run_batch_of_slides.py +++ b/run_batch_of_slides.py @@ -3,9 +3,9 @@ import torch import multiprocessing as mp import shutil -from typing import Any, List +from typing import Any, Dict, List from queue import Queue -from threading import Thread +from threading import Thread, Event from tqdm import tqdm import warnings warnings.filterwarnings("ignore", category=FutureWarning) @@ -304,6 +304,28 @@ def worker_entrypoint(args: argparse.Namespace) -> None: queue = Queue(maxsize=1) valid_slides = args.selected_wsi_paths + cache_ready_flags: Dict[int, Event] = {} + + def get_flag(batch_id: int) -> Event: + flag = cache_ready_flags.get(batch_id) + if flag is None: + flag = Event() + cache_ready_flags[batch_id] = flag + return flag + + def mark_cache_ready(batch_id: int) -> None: + get_flag(batch_id).set() + + def wait_for_cache(batch_id: int) -> None: + get_flag(batch_id).wait() + + warm = valid_slides[:args.cache_batch_size] + warmup_dir = os.path.join(gpu_cache_dir, "batch_0") + print(f"[GPU {args.gpu}] Pre-caching batch 0 ({len(warm)} slides)...") + cache_batch(warm, warmup_dir) + print(f"[GPU {args.gpu}] Batch 0 cached to {warmup_dir}") + mark_cache_ready(0) + queue.put(0) def processor_factory(wsi_dir: str) -> Processor: local_args = argparse.Namespace(**vars(args)) @@ -311,6 +333,7 @@ def processor_factory(wsi_dir: str) -> Processor: local_args.wsi_cache = None local_args.custom_list_of_wsis = None local_args.search_nested = False + local_args.selected_wsi_paths = None return initialize_processor(local_args) def run_task_fn(processor: Processor, task_name: str) -> None: @@ -320,11 +343,11 @@ def run_task_fn(processor: Processor, task_name: str) -> None: run_task(processor, local_args) producer = Thread(target=batch_producer, args=( - queue, valid_slides, 0, args.cache_batch_size, gpu_cache_dir + queue, valid_slides, args.cache_batch_size, args.cache_batch_size, gpu_cache_dir, mark_cache_ready )) consumer = Thread(target=batch_consumer, args=( - queue, args.task, gpu_cache_dir, processor_factory, run_task_fn + queue, args.task, gpu_cache_dir, processor_factory, run_task_fn, wait_for_cache )) producer.start() diff --git a/trident/Concurrency.py b/trident/Concurrency.py index a9cb709..45c3bab 100644 --- a/trident/Concurrency.py +++ b/trident/Concurrency.py @@ -2,8 +2,9 @@ import gc import torch import shutil -from typing import List, Callable, Any +from typing import List, Callable, Any, Optional from queue import Queue +from tqdm import tqdm @@ -26,7 +27,7 @@ def cache_batch(wsis: List[str], dest_dir: str) -> List[str]: os.makedirs(dest_dir, exist_ok=True) copied = [] - for wsi in wsis: + for wsi in tqdm(wsis, desc=f"Caching to {os.path.basename(dest_dir)}", unit="slide", leave=False): dest_path = os.path.join(dest_dir, os.path.basename(wsi)) shutil.copy(wsi, dest_path) copied.append(dest_path) @@ -46,6 +47,7 @@ def batch_producer( start_idx: int, batch_size: int, cache_dir: str, + on_cached: Optional[Callable[[int], None]] = None, ) -> None: """ Produces and caches batches of slides. Sends batch IDs to a queue for downstream processing. @@ -67,10 +69,12 @@ def batch_producer( batch_paths = valid_slides[i:i + batch_size] batch_id = i // batch_size ssd_batch_dir = os.path.join(cache_dir, f"batch_{batch_id}") - print(f"[PRODUCER] Caching batch {batch_id}: {ssd_batch_dir}") + print(f"[PRODUCER] Caching batch {batch_id} ({len(batch_paths)} slides) to {ssd_batch_dir}") cache_batch(batch_paths, ssd_batch_dir) - print(f"[PRODUCER] Batch {batch_id} cached and ready") - queue.put(batch_id) # Put will block if queue is full (maxsize=1), enabling pipeline + print(f"[PRODUCER] Batch {batch_id} ready for processing") + if on_cached is not None: + on_cached(batch_id) + queue.put(batch_id) queue.put(None) # Sentinel to signal completion @@ -81,6 +85,7 @@ def batch_consumer( cache_dir: str, processor_factory: Callable[[str], Any], run_task_fn: Callable[[Any, str], None], + wait_for_cache_ready: Optional[Callable[[int], None]] = None, ) -> None: """ Consumes cached batches from the queue, processes them, and optionally clears cache. @@ -106,7 +111,11 @@ def batch_consumer( break ssd_batch_dir = os.path.join(cache_dir, f"batch_{batch_id}") - print(f"[CONSUMER] Processing batch {batch_id}: {ssd_batch_dir}") + + if wait_for_cache_ready is not None: + print(f"[CONSUMER] Waiting for batch {batch_id} cache to complete...") + wait_for_cache_ready(batch_id) + print(f"[CONSUMER] Batch {batch_id} cache ready, starting processing") processor = processor_factory(ssd_batch_dir) @@ -126,5 +135,4 @@ def batch_consumer( print(f"[CONSUMER] Clearing cache for batch {batch_id}") shutil.rmtree(ssd_batch_dir, ignore_errors=True) - print(f"[CONSUMER] Batch {batch_id} completed and cache cleared") - queue.task_done() # Signal completion to producer + queue.task_done() From bd6a981de788781bf40b1ba1e1e6d6a2aaf0bfec Mon Sep 17 00:00:00 2001 From: winglet0996 <78600759+winglet0996@users.noreply.github.com> Date: Mon, 24 Nov 2025 06:51:45 +0800 Subject: [PATCH 4/7] gigapath fix --- trident/slide_encoder_models/load.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/trident/slide_encoder_models/load.py b/trident/slide_encoder_models/load.py index e93db80..2e48dab 100644 --- a/trident/slide_encoder_models/load.py +++ b/trident/slide_encoder_models/load.py @@ -1,4 +1,4 @@ -import sys +2.5.import sys import os import torch import traceback @@ -330,14 +330,22 @@ def _build(self, pretrained=True): raise Exception("Please install fairscale and gigapath using `pip install fairscale git+https://github.com/prov-gigapath/prov-gigapath.git`.") # Make sure flash_attn is correct version - try: - import flash_attn; assert flash_attn.__version__ == '2.5.8' - except: - traceback.print_exc() - raise Exception("Please install flash_attn version 2.5.8 using `pip install flash_attn==2.5.8`.") + # try: + # import flash_attn; assert flash_attn.__version__ == '2.5.8' + # except: + # traceback.print_exc() + # raise Exception("Please install flash_attn version 2.5.8 using `pip install flash_attn==2.5.8`.") if pretrained: - model = create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=True) + # Try to get local weights path first + weights_path = get_weights_path('slide', self.enc_name) + if weights_path: + print(f"Loading GigaPath slide encoder from local path: {weights_path}") + model = create_model(weights_path, "gigapath_slide_enc12l768d", 1536, global_pool=True) + else: + # Fallback to downloading from Hugging Face Hub + print("Local weights not found. Downloading from Hugging Face Hub...") + model = create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=True) else: model = create_model("", "gigapath_slide_enc12l768d", 1536, global_pool=True) From c9da6e2369038c09338f1f82f52af3415b051682 Mon Sep 17 00:00:00 2001 From: winglet0996 Date: Mon, 24 Nov 2025 07:14:54 +0800 Subject: [PATCH 5/7] set spawn --- run_batch_of_slides.py | 1 + trident/wsi_objects/WSI.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/run_batch_of_slides.py b/run_batch_of_slides.py index 10afe5b..a743fa3 100644 --- a/run_batch_of_slides.py +++ b/run_batch_of_slides.py @@ -414,4 +414,5 @@ def main() -> None: if __name__ == "__main__": + mp.set_start_method('spawn', force=True) main() \ No newline at end of file diff --git a/trident/wsi_objects/WSI.py b/trident/wsi_objects/WSI.py index 2404c17..91afe60 100644 --- a/trident/wsi_objects/WSI.py +++ b/trident/wsi_objects/WSI.py @@ -18,10 +18,7 @@ ReadMode = Literal['pil', 'numpy'] -try: - _DATALOADER_MP_CTX = mp.get_context('fork') if 'fork' in mp.get_all_start_methods() else None -except (ValueError, AttributeError): - _DATALOADER_MP_CTX = None +_DATALOADER_MP_CTX = mp.get_context('spawn') class WSI: From d3a47279a36952baa95ffe7897a167f8f51dd160 Mon Sep 17 00:00:00 2001 From: winglet0996 Date: Mon, 24 Nov 2025 17:36:38 +0800 Subject: [PATCH 6/7] set fork --- run_batch_of_slides.py | 91 +++++++++++++++++++++----------------- trident/Concurrency.py | 18 ++------ trident/wsi_objects/WSI.py | 5 ++- 3 files changed, 58 insertions(+), 56 deletions(-) diff --git a/run_batch_of_slides.py b/run_batch_of_slides.py index a743fa3..1ef519d 100644 --- a/run_batch_of_slides.py +++ b/run_batch_of_slides.py @@ -1,19 +1,20 @@ import os +import math import argparse import torch import multiprocessing as mp import shutil -from typing import Any, Dict, List +from typing import Any, List from queue import Queue -from threading import Thread, Event -from tqdm import tqdm +from threading import Thread import warnings +from tqdm import tqdm warnings.filterwarnings("ignore", category=FutureWarning) from trident import Processor from trident.patch_encoder_models import encoder_registry as patch_encoder_registry from trident.slide_encoder_models import encoder_registry as slide_encoder_registry -from trident.Concurrency import batch_producer, batch_consumer, cache_batch +from trident.Concurrency import batch_producer, batch_consumer from trident.IO import collect_valid_slides @@ -267,26 +268,47 @@ def get_pending_slides(args: argparse.Namespace) -> List[str]: coords_dir = args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap' pending = [] - print(f"[MAIN] Found {len(all_slides)} slides. Filtering completed slides...") - for slide in tqdm(all_slides, desc="Checking slides", unit="slide"): + def safe_listdir(path: str) -> List[str]: + try: + return os.listdir(path) + except (FileNotFoundError, NotADirectoryError): + return [] + + seg_done = set() + coords_done = set() + feat_done = set() + + if 'seg' in tasks: + contour_dir = os.path.join(args.job_dir, 'contours') + seg_done = {os.path.splitext(f)[0] for f in safe_listdir(contour_dir) if f.lower().endswith('.jpg')} + + if 'coords' in tasks: + patches_dir = os.path.join(args.job_dir, coords_dir, 'patches') + coords_done = { + f[:-len('_patches')] + for f in (os.path.splitext(fname)[0] for fname in safe_listdir(patches_dir) if fname.endswith('_patches.h5')) + } + + if 'feat' in tasks: + feat_sub = f'slide_features_{args.slide_encoder}' if args.slide_encoder else f'features_{args.patch_encoder}' + feat_dir = os.path.join(args.job_dir, coords_dir, feat_sub) + feat_done = {os.path.splitext(f)[0] for f in safe_listdir(feat_dir) if os.path.splitext(f)[1] in {'.h5', '.pt'}} + + for slide in tqdm(all_slides, desc="Checking slide status", unit="slide"): stem = os.path.splitext(os.path.basename(slide))[0] is_done = True for t in tasks: - if t == 'seg' and not os.path.exists(os.path.join(args.job_dir, 'contours', f'{stem}.jpg')): + if t == 'seg' and stem not in seg_done: is_done = False - elif t == 'coords' and not os.path.exists(os.path.join(args.job_dir, coords_dir, 'patches', f'{stem}_patches.h5')): + elif t == 'coords' and stem not in coords_done: + is_done = False + elif t == 'feat' and stem not in feat_done: is_done = False - elif t == 'feat': - feat_sub = f'slide_features_{args.slide_encoder}' if args.slide_encoder else f'features_{args.patch_encoder}' - feat_dir = os.path.join(args.job_dir, coords_dir, feat_sub) - if not (os.path.isdir(feat_dir) and any(os.path.exists(os.path.join(feat_dir, f'{stem}.{ext}')) for ext in ['h5', 'pt'])): - is_done = False if not is_done: break if not is_done: pending.append(slide) - skipped = len(all_slides) - len(pending) - print(f"[MAIN] Skipped {skipped} completed slides. Processing {len(pending)} pending slides.") + print(f"[MAIN] Found {len(all_slides)} slides. Processing {len(pending)} pending slides ({len(all_slides)-len(pending)} skipped).") return pending @@ -301,31 +323,18 @@ def worker_entrypoint(args: argparse.Namespace) -> None: # Setup specific cache dir for this GPU process gpu_cache_dir = os.path.join(args.wsi_cache, f"gpu_{args.gpu}") os.makedirs(gpu_cache_dir, exist_ok=True) + + valid_slides = list(args.selected_wsi_paths or []) + if not valid_slides: + print(f"[WORKER {args.gpu}] No slides assigned. Skipping cached pipeline.") + return + + batch_size = max(1, args.cache_batch_size or len(valid_slides)) queue = Queue(maxsize=1) - valid_slides = args.selected_wsi_paths - cache_ready_flags: Dict[int, Event] = {} - - def get_flag(batch_id: int) -> Event: - flag = cache_ready_flags.get(batch_id) - if flag is None: - flag = Event() - cache_ready_flags[batch_id] = flag - return flag - - def mark_cache_ready(batch_id: int) -> None: - get_flag(batch_id).set() - - def wait_for_cache(batch_id: int) -> None: - get_flag(batch_id).wait() - - warm = valid_slides[:args.cache_batch_size] - warmup_dir = os.path.join(gpu_cache_dir, "batch_0") - print(f"[GPU {args.gpu}] Pre-caching batch 0 ({len(warm)} slides)...") - cache_batch(warm, warmup_dir) - print(f"[GPU {args.gpu}] Batch 0 cached to {warmup_dir}") - mark_cache_ready(0) - queue.put(0) + + # No pre-warming: let producer handle all batching + start_idx = 0 def processor_factory(wsi_dir: str) -> Processor: local_args = argparse.Namespace(**vars(args)) @@ -340,14 +349,15 @@ def run_task_fn(processor: Processor, task_name: str) -> None: # We must use a local copy of args to update the task without affecting others local_args = argparse.Namespace(**vars(args)) local_args.task = task_name + local_args.selected_wsi_paths = None run_task(processor, local_args) producer = Thread(target=batch_producer, args=( - queue, valid_slides, args.cache_batch_size, args.cache_batch_size, gpu_cache_dir, mark_cache_ready + queue, valid_slides, start_idx, batch_size, gpu_cache_dir )) consumer = Thread(target=batch_consumer, args=( - queue, args.task, gpu_cache_dir, processor_factory, run_task_fn, wait_for_cache + queue, args.task, gpu_cache_dir, processor_factory, run_task_fn )) producer.start() @@ -414,5 +424,4 @@ def main() -> None: if __name__ == "__main__": - mp.set_start_method('spawn', force=True) main() \ No newline at end of file diff --git a/trident/Concurrency.py b/trident/Concurrency.py index 45c3bab..0746a72 100644 --- a/trident/Concurrency.py +++ b/trident/Concurrency.py @@ -2,9 +2,8 @@ import gc import torch import shutil -from typing import List, Callable, Any, Optional +from typing import List, Callable, Any from queue import Queue -from tqdm import tqdm @@ -27,7 +26,7 @@ def cache_batch(wsis: List[str], dest_dir: str) -> List[str]: os.makedirs(dest_dir, exist_ok=True) copied = [] - for wsi in tqdm(wsis, desc=f"Caching to {os.path.basename(dest_dir)}", unit="slide", leave=False): + for wsi in wsis: dest_path = os.path.join(dest_dir, os.path.basename(wsi)) shutil.copy(wsi, dest_path) copied.append(dest_path) @@ -47,7 +46,6 @@ def batch_producer( start_idx: int, batch_size: int, cache_dir: str, - on_cached: Optional[Callable[[int], None]] = None, ) -> None: """ Produces and caches batches of slides. Sends batch IDs to a queue for downstream processing. @@ -69,11 +67,8 @@ def batch_producer( batch_paths = valid_slides[i:i + batch_size] batch_id = i // batch_size ssd_batch_dir = os.path.join(cache_dir, f"batch_{batch_id}") - print(f"[PRODUCER] Caching batch {batch_id} ({len(batch_paths)} slides) to {ssd_batch_dir}") + print(f"[PRODUCER] Caching batch {batch_id}: {ssd_batch_dir}") cache_batch(batch_paths, ssd_batch_dir) - print(f"[PRODUCER] Batch {batch_id} ready for processing") - if on_cached is not None: - on_cached(batch_id) queue.put(batch_id) queue.put(None) # Sentinel to signal completion @@ -85,7 +80,6 @@ def batch_consumer( cache_dir: str, processor_factory: Callable[[str], Any], run_task_fn: Callable[[Any, str], None], - wait_for_cache_ready: Optional[Callable[[int], None]] = None, ) -> None: """ Consumes cached batches from the queue, processes them, and optionally clears cache. @@ -111,11 +105,7 @@ def batch_consumer( break ssd_batch_dir = os.path.join(cache_dir, f"batch_{batch_id}") - - if wait_for_cache_ready is not None: - print(f"[CONSUMER] Waiting for batch {batch_id} cache to complete...") - wait_for_cache_ready(batch_id) - print(f"[CONSUMER] Batch {batch_id} cache ready, starting processing") + print(f"[CONSUMER] Processing batch {batch_id}: {ssd_batch_dir}") processor = processor_factory(ssd_batch_dir) diff --git a/trident/wsi_objects/WSI.py b/trident/wsi_objects/WSI.py index 91afe60..2404c17 100644 --- a/trident/wsi_objects/WSI.py +++ b/trident/wsi_objects/WSI.py @@ -18,7 +18,10 @@ ReadMode = Literal['pil', 'numpy'] -_DATALOADER_MP_CTX = mp.get_context('spawn') +try: + _DATALOADER_MP_CTX = mp.get_context('fork') if 'fork' in mp.get_all_start_methods() else None +except (ValueError, AttributeError): + _DATALOADER_MP_CTX = None class WSI: From 34d4a90c759e1677ca99d374c7d07a0f3dbc4e7c Mon Sep 17 00:00:00 2001 From: winglet0996 <78600759+winglet0996@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:15:14 +0800 Subject: [PATCH 7/7] typo fix --- trident/slide_encoder_models/load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trident/slide_encoder_models/load.py b/trident/slide_encoder_models/load.py index 2e48dab..ca5f0f7 100644 --- a/trident/slide_encoder_models/load.py +++ b/trident/slide_encoder_models/load.py @@ -1,4 +1,4 @@ -2.5.import sys +import sys import os import torch import traceback