2 changes: 2 additions & 0 deletions README.md
@@ -42,6 +42,8 @@ Additional packages may be required to load some pretrained models. Follow error
python run_batch_of_slides.py --task all --wsi_dir ./wsis --job_dir ./trident_processed --patch_encoder uni_v1 --mag 20 --patch_size 256
```

Add `--gpus 0 1` (or any list of device indices) to split slides across multiple GPUs automatically.
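
For example, this is the quickstart command from above, sharded across GPUs 0 and 1 (assuming both devices are visible to CUDA):

```
python run_batch_of_slides.py --task all --wsi_dir ./wsis --job_dir ./trident_processed --patch_encoder uni_v1 --mag 20 --patch_size 256 --gpus 0 1
```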

**Feeling cautious?**

Run this command to perform all processing steps for a **single** slide:
214 changes: 164 additions & 50 deletions run_batch_of_slides.py
@@ -1,19 +1,21 @@
"""
Example usage:

```
python run_batch_of_slides.py --task all --wsi_dir output/wsis --job_dir output --patch_encoder uni_v1 --mag 20 --patch_size 256
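# Multi-GPU variant of the same command (shards pending slides across devices 0 and 1):
python run_batch_of_slides.py --task all --wsi_dir output/wsis --job_dir output --patch_encoder uni_v1 --mag 20 --patch_size 256 --gpus 0 1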
```

"""
import os
import math
import argparse
import torch
import multiprocessing as mp
import shutil
from typing import Any, List
from queue import Queue
from threading import Thread
import warnings
from tqdm import tqdm
warnings.filterwarnings("ignore", category=FutureWarning)

from trident import Processor
from trident.patch_encoder_models import encoder_registry as patch_encoder_registry
from trident.slide_encoder_models import encoder_registry as slide_encoder_registry
from trident.Concurrency import batch_producer, batch_consumer
from trident.IO import collect_valid_slides


def build_parser() -> argparse.ArgumentParser:
@@ -29,6 +31,8 @@ def build_parser() -> argparse.ArgumentParser:

# Generic arguments
parser.add_argument('--gpu', type=int, default=0, help='GPU index to use for processing tasks.')
parser.add_argument('--gpus', type=int, nargs='+', default=None,
help='Optional space-separated list of GPU indices to enable multi-GPU execution.')
parser.add_argument('--task', type=str, default='seg',
choices=['seg', 'coords', 'feat', 'all'],
help='Task to run: seg (tissue segmentation), coords (save tissue patch coordinates), feat (extract features), or all (run every step in sequence).')
@@ -165,6 +169,7 @@ def initialize_processor(args: argparse.Namespace) -> Processor:
max_workers=args.max_workers,
reader_type=args.reader_type,
search_nested=args.search_nested,
selected_wsi_paths=getattr(args, 'selected_wsi_paths', None),
)


Expand All @@ -180,6 +185,8 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None:
Parsed command-line arguments containing task configuration.
"""

device = getattr(args, 'device', f'cuda:{args.gpu}')  # multi-GPU launcher sets args.device; otherwise fall back to the single --gpu flag

if args.task == 'seg':
from trident.segmentation_models.load import segmentation_model_factory

@@ -203,7 +210,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None:
holes_are_tissue=not args.remove_holes,
artifact_remover_model=artifact_remover_model,
batch_size=args.seg_batch_size if args.seg_batch_size is not None else args.batch_size,
device=device,
)
elif args.task == 'coords':
processor.run_patching_job(
@@ -220,7 +227,7 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None:
processor.run_patch_feature_extraction_job(
coords_dir=args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap',
patch_encoder=encoder,
device=device,
saveas='h5',
batch_limit=args.feat_batch_size if args.feat_batch_size is not None else args.batch_size,
)
@@ -230,76 +237,134 @@ def run_task(processor: Processor, args: argparse.Namespace) -> None:
processor.run_slide_feature_extraction_job(
slide_encoder=encoder,
coords_dir=args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap',
device=device,
saveas='h5',
batch_limit=args.feat_batch_size if args.feat_batch_size is not None else args.batch_size,
)
else:
raise ValueError(f'Invalid task: {args.task}')


def cleanup_files(job_dir: str, cache_dir: str = None) -> None:
    """Remove stale .lock files from previous runs and reset the WSI cache directory."""
    if os.path.isdir(job_dir):
        for root, _, files in os.walk(job_dir):
            for f in files:
                if f.endswith('.lock'):
                    try:
                        os.remove(os.path.join(root, f))
                    except OSError:
                        pass
    if cache_dir and os.path.isdir(cache_dir):
        try:
            shutil.rmtree(cache_dir)
        except OSError:
            pass
        os.makedirs(cache_dir, exist_ok=True)


def get_pending_slides(args: argparse.Namespace) -> List[str]:
all_slides = collect_valid_slides(
wsi_dir=args.wsi_dir, custom_list_path=args.custom_list_of_wsis,
wsi_ext=args.wsi_ext, search_nested=args.search_nested, max_workers=args.max_workers
)

tasks = ['seg', 'coords', 'feat'] if args.task == 'all' else [args.task]
coords_dir = args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap'
pending = []

def safe_listdir(path: str) -> List[str]:
try:
return os.listdir(path)
except (FileNotFoundError, NotADirectoryError):
return []

seg_done = set()
coords_done = set()
feat_done = set()

if 'seg' in tasks:
contour_dir = os.path.join(args.job_dir, 'contours')
seg_done = {os.path.splitext(f)[0] for f in safe_listdir(contour_dir) if f.lower().endswith('.jpg')}

if 'coords' in tasks:
patches_dir = os.path.join(args.job_dir, coords_dir, 'patches')
coords_done = {
f[:-len('_patches')]
for f in (os.path.splitext(fname)[0] for fname in safe_listdir(patches_dir) if fname.endswith('_patches.h5'))
}
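# Each entry above is a slide stem, recovered by stripping the '_patches' suffix from '<stem>_patches.h5'.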

if 'feat' in tasks:
feat_sub = f'slide_features_{args.slide_encoder}' if args.slide_encoder else f'features_{args.patch_encoder}'
feat_dir = os.path.join(args.job_dir, coords_dir, feat_sub)
feat_done = {os.path.splitext(f)[0] for f in safe_listdir(feat_dir) if os.path.splitext(f)[1] in {'.h5', '.pt'}}

for slide in tqdm(all_slides, desc="Checking slide status", unit="slide"):
    stem = os.path.splitext(os.path.basename(slide))[0]
    is_done = True
    for t in tasks:
        if t == 'seg' and stem not in seg_done:
            is_done = False
        elif t == 'coords' and stem not in coords_done:
            is_done = False
        elif t == 'feat' and stem not in feat_done:
            is_done = False
        if not is_done:
            break
    if not is_done:
        pending.append(slide)

print(f"[MAIN] Found {len(all_slides)} slides. Processing {len(pending)} pending slides ({len(all_slides)-len(pending)} skipped).")
return pending


def worker_entrypoint(args: argparse.Namespace) -> None:
"""
Entry point for each GPU worker process.
Handles both cached (threaded producer/consumer) and non-cached (sequential)
execution modes inside an isolated worker process.
"""
if args.wsi_cache:

# === Parallel pipeline with caching (Threaded inside Process) ===
# Setup specific cache dir for this GPU process
gpu_cache_dir = os.path.join(args.wsi_cache, f"gpu_{args.gpu}")
os.makedirs(gpu_cache_dir, exist_ok=True)

valid_slides = list(args.selected_wsi_paths or [])
if not valid_slides:
print(f"[WORKER {args.gpu}] No slides assigned. Skipping cached pipeline.")
return

batch_size = max(1, args.cache_batch_size or len(valid_slides))
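# If --cache_batch_size is unset, the worker caches its whole shard as a single batch.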

queue = Queue(maxsize=1)

# No pre-warming: let producer handle all batching
start_idx = 0

def processor_factory(wsi_dir: str) -> Processor:
local_args = argparse.Namespace(**vars(args))
local_args.wsi_dir = wsi_dir
local_args.wsi_cache = None
local_args.custom_list_of_wsis = None
local_args.search_nested = False
local_args.selected_wsi_paths = None
return initialize_processor(local_args)
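# processor_factory points each batch's Processor at the cached batch directory rather than the original wsi_dir.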

def run_task_fn(processor: Processor, task_name: str) -> None:
# We must use a local copy of args to update the task without affecting others
local_args = argparse.Namespace(**vars(args))
local_args.task = task_name
local_args.selected_wsi_paths = None
run_task(processor, local_args)

producer = Thread(target=batch_producer, args=(
queue, valid_slides, start_idx, batch_size, gpu_cache_dir
))

consumer = Thread(target=batch_consumer, args=(
queue, args.task, gpu_cache_dir, processor_factory, run_task_fn
))

print("[MAIN] Starting producer and consumer threads.")
producer.start()
consumer.start()
producer.join()
consumer.join()

else:
# === Sequential mode ===
processor = initialize_processor(args)
Expand All @@ -309,5 +374,54 @@ def run_task_fn(processor: Processor, task_name: str) -> None:
run_task(processor, args)


def main() -> None:
    """
    Main entry point for the Trident batch processing script.

    Cleans up stale lock files and cache state, collects the slides that
    still need work, shards them across the requested GPUs, and launches
    one worker process per shard. Workers run segmentation, coordinate
    extraction, and feature extraction tasks.
    """

args = parse_arguments()
cleanup_files(args.job_dir, args.wsi_cache)

if args.gpus:
gpu_ids = list(dict.fromkeys(args.gpus))
else:
gpu_ids = [args.gpu]

if not torch.cuda.is_available() and any(g >= 0 for g in gpu_ids):
print('[MAIN] Warning: CUDA not available, using CPU.')
gpu_ids = [-1]
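# -1 is a CPU sentinel: each worker below maps it to device='cpu'.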

pending_slides = get_pending_slides(args)
if not pending_slides:
return

num_shards = len(gpu_ids)
shards = [[] for _ in range(num_shards)]
for i, slide in enumerate(pending_slides):
shards[i % num_shards].append(slide)
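# Round-robin assignment keeps shard sizes within one slide of each other, e.g. 5 slides on 2 GPUs -> sizes [3, 2].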

ctx = mp.get_context('spawn') if torch.cuda.is_available() else mp.get_context('fork')
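# 'spawn' is required once CUDA is involved: forked children cannot safely reuse the parent's CUDA context.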
processes = []

for i, gpu_id in enumerate(gpu_ids):
if not shards[i]:
    continue

worker_args = argparse.Namespace(**vars(args))
worker_args.gpu = gpu_id
worker_args.device = f'cuda:{gpu_id}' if gpu_id >= 0 else 'cpu'
worker_args.selected_wsi_paths = shards[i]

p = ctx.Process(target=worker_entrypoint, args=(worker_args,))
p.start()
processes.append(p)

for p in processes:
p.join()


if __name__ == "__main__":
    main()