diff --git a/.gitignore b/.gitignore index 93d2b8d..6443292 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ slurm-*.out +*.zarr # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/BANIS.py b/BANIS.py index 8b6ed6a..d5bb6de 100644 --- a/BANIS.py +++ b/BANIS.py @@ -1,13 +1,13 @@ import argparse import gc import os +import shutil + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" from collections import defaultdict from datetime import datetime from typing import Any, Dict -import random -import numpy as np import pytorch_lightning as pl import torch import torchvision @@ -23,7 +23,7 @@ from tqdm import tqdm from data import load_data -from inference import scale_sigmoid, predict_aff, compute_connected_component_segmentation +from inference import AffinityPredictor, Thresholding from metrics import compute_metrics @@ -35,7 +35,7 @@ class BANIS(LightningModule): def __init__(self, **kwargs: Any): super().__init__() self.save_hyperparameters() - print(f"hparams: \n{self.hparams}") + # print(f"hparams: \n{self.hparams}") self.model = create_mednext_v1( num_input_channels=self.hparams.num_input_channels, @@ -163,8 +163,17 @@ def full_cube_inference(self, mode: str, global_step=None): img_data = zarr.open(os.path.join(seed_path, "data.zarr"), mode="r")["img"] - aff_pred = predict_aff(img_data, model=self, zarr_path=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", do_overlap=True, prediction_channels=3, divide=255, - small_size=self.hparams.small_size, compute_backend="local") + affinity_predictor = AffinityPredictor( + chunk_cube_size=3000, # can be adjusted + compute_backend="local", + model=self, + small_size=self.hparams.small_size, + do_overlap=True, + prediction_channels=3, + divide=255, + ) + affinity_predictor.img_to_aff(img_data, zarr_path=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr") + aff_pred = zarr.open(f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", mode="r") self._evaluate_thresholds(aff_pred, 
os.path.join(seed_path, "skeleton.pkl"), mode, global_step) @@ -179,9 +188,9 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str, torch.cuda.empty_cache() print(f"threshold {thr}") - pred_seg = compute_connected_component_segmentation( - aff_pred[:3] > thr # hard affinities - ) + postprocessor = Thresholding(3000, "local", thr) + postprocessor.aff_to_seg(aff_pred, f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr") + pred_seg = zarr.open(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr", mode="r") metrics = compute_metrics(pred_seg, skel_path) @@ -201,9 +210,11 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str, self.best_thr_so_far[mode] = thr with open(f"{self.hparams.save_dir}/best_thr_{mode}.txt", "w") as f: f.write(str(self.best_thr_so_far[mode])) - seg_pred = zarr.array(pred_seg, dtype=np.uint32, - store=f"{self.hparams.save_dir}/pred_seg_{mode}.zarr", - chunks=(512, 512, 512), overwrite=True) + if os.path.exists(f"{self.hparams.save_dir}/pred_seg_{mode}.zarr"): + shutil.rmtree(f"{self.hparams.save_dir}/pred_seg_{mode}.zarr") + os.replace(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr", f"{self.hparams.save_dir}/pred_seg_{mode}.zarr") + else: + shutil.rmtree(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr") best_voi = min(best_voi, metrics["voi_sum"]) self.safe_add_scalar(f"{mode}_best_nerl", best_nerl, global_step) diff --git a/README.md b/README.md index 81d39f0..8af933e 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,18 @@ python slurm_job_scheduler.py Adding an `auto_resubmit` argument to `config.yaml` allows Slurm to automatically resubmit jobs that reach the Slurm time limit (see `aff_train.sh`). +## Prediction + +To predict segmentation from an image: + +```bash +python inference --img_path /path/to/image.zarr --model_path /path/to/model.ckpt --chunk_cube_size 3000 +``` + +The `chunk_cube_size` parameter sets the maximum cube size that can be loaded in memory. 
+If you have enough memory available, set it to a bigger value, if you are tight with memory, set a lower value (in exchange for increased computation time). +See [inference.py](inference.py) for other parameters. + ## Evaluation To evaluate a predicted segmentation (`.zarr` or `.npy`): diff --git a/aff_train.sh b/aff_train.sh index 074bcfb..90430d1 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -1,14 +1,14 @@ #!/bin/bash -l -#SBATCH --nodes=2 -#SBATCH --gres=gpu:4 -#SBATCH --ntasks-per-node=4 +#SBATCH --nodes=1 +#SBATCH --gres=gpu:1 +#SBATCH --ntasks-per-node=1 #SBATCH --time=7-00 -#SBATCH --cpus-per-task=16 -#SBATCH --mem=1000G +#SBATCH --cpus-per-task=32 +#SBATCH --mem=500G #SBATCH --signal=B:USR1@300 #SBATCH --open-mode=append -#SBATCH --partition=p.large +#SBATCH --partition=p.share mamba activate nisb diff --git a/config.yaml b/config.yaml index a0677bc..512e992 100644 --- a/config.yaml +++ b/config.yaml @@ -5,20 +5,20 @@ params: - 1e-2 seed: - 0 - #- 1 - #- 2 - #- 3 - #- 4 + - 1 + - 2 + - 3 + - 4 long_range: - 10 batch_size: - - 1 + - 8 scheduler: - true model_id: - - "L" + - "S" kernel_size: - - 5 + - 3 synthetic: - 1.0 drop_slice_prob: @@ -32,33 +32,34 @@ params: affine: - 0.5 n_steps: - - 1_000_000 + - 50000 small_size: - - 256 + - 128 data_setting: - #- "base" - #- "liconn" - #- "multichannel" - #- "neg_guidance" - #- "no_touch_thick" - #- "pos_guidance" - #- "slice_perturbed" - #- "touching_thin" + - "base" + - "liconn" + - "multichannel" + - "neg_guidance" + - "no_touch_thick" + - "pos_guidance" + - "slice_perturbed" + - "touching_thin" - "train_100" base_data_path: - "/cajal/nvmescratch/projects/NISB/" save_path: - #- "/cajal/scratch/projects/misc/riegerfr/aff_nis/" - - "/cajal/scratch/projects/misc/zuzur/xl_banis" + - "/cajal/scratch/projects/misc/riegerfr/aff_nis/" exp_name: - - "xl_test" + - "exp" real_data_path: #https://colab.research.google.com/github/funkelab/lsd/blob/master/lsd/tutorial/notebooks/lsd_data_download.ipynb - 
"/cajal/scratch/projects/misc/mdraw/data/funke/zebrafinch/training/" auto_resubmit: - - True + - False distributed: - - True - compile: - False + compile: + - True validate_extern: - - True \ No newline at end of file + - True + augment: + - True diff --git a/environment.yaml b/environment.yaml index 520faa1..a303a50 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,7 @@ dependencies: - bzip2=1.0.8 - ca-certificates=2024.8.30 - cython==3.0.11 + - dask=2025.7.0 - ld_impl_linux-64=2.43 - libexpat=2.6.3 - libffi=3.4.2 @@ -40,9 +41,11 @@ dependencies: - batchgenerators==0.25 - certifi==2024.8.30 - charset-normalizer==3.4.0 + - cloud_volume==12.4.1 - connected-components-3d==3.19.0 - contourpy==1.3.0 - cycler==0.12.1 + - dask_jobqueue==0.9.0 - dicom2nifti==2.5.0 - fasteners==0.19 - filelock==3.16.1 @@ -68,7 +71,9 @@ dependencies: - monai==1.3.2 - mpmath==1.3.0 - multidict==6.1.0 + - mwatershed==0.5.3 - networkx==3.3 + - neuroglancer==2.40.1 - nibabel==5.3.0 - numba==0.60.0 - numcodecs==0.13.1 diff --git a/inference.py b/inference.py index 4ffecbf..06ee589 100644 --- a/inference.py +++ b/inference.py @@ -1,7 +1,10 @@ import shutil from collections import defaultdict +from copy import deepcopy from typing import Union, List, Tuple +import argparse +import cc3d import numba import numpy as np import torch @@ -10,7 +13,6 @@ import dask from dask import compute, persist, delayed from dask.distributed import Client, LocalCluster -from dask.diagnostics import ProgressBar import dask.array as da from distributed import progress from filelock import FileLock @@ -19,397 +21,630 @@ from torch import autocast from torch.nn.functional import sigmoid from tqdm import tqdm +import mwatershed + + +class Utils: + @staticmethod + def get_coordinates(shape: Tuple[int, int, int], small_size: int, overlap: int = 0, + last_has_smaller_overlap: bool = True) -> List[Tuple[int, int, int]]: + """ + Get coordinates for smaller patches to process a big cube in memory. 
+ Args: + shape: The shape of the input (x, y, z). + small_size: The size of the patches. + overlap: The overlap between patches. The default 0 means no overlap (next patch starts on the next pixel from the previous patch). For half-cube overlap set overlap=small_size//2, for 1-pixel overlap set overlap=1. + last_has_smaller_overlap: If the last patch with the specified size and overlap would exceed the big cube, move the patch so that it ends with the big cube, creating a bigger overlap in this patch. + Returns: + List of (x, y, z) coordinates (starting voxel of a patch) for processing of smaller patches. + """ + if overlap < 0 or overlap >= small_size: + raise ValueError(f"Overlap must be between 0 and {small_size}.") + offsets = [Utils.get_offsets(s, small_size, small_size - overlap, last_has_smaller_overlap) for s in shape] + xyzs = [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] + return xyzs + + @staticmethod + def get_offsets(big_size, small_size, step, last_has_smaller_overlap): + offsets = list(range(0, big_size - small_size + 1, step)) + if small_size > big_size: + offsets.append(0) + elif offsets[-1] != big_size - small_size and last_has_smaller_overlap: + offsets.append(big_size - small_size) + elif offsets[-1] != big_size - small_size and not last_has_smaller_overlap: + offsets.append(len(offsets) * step) + return offsets + + @staticmethod + def chunk_xyzs(xyzs, chunk_cube_size=1024): + """ + Chunks the patch coordinates into chunks containing coordinates from the same part of the big cube. 
+ Args: + xyzs: list of all coordinates + chunk_cube_size: side length of each chunk + Returns: + chunked coordinates + """ + chunks = defaultdict(list) + for x, y, z in xyzs: + chunks[(x // chunk_cube_size, y // chunk_cube_size, z // chunk_cube_size)].append((x, y, z)) + return list(chunks.values()) + + @staticmethod + def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: + """Scale sigmoid to avoid numerical issues in high confidence fp16.""" + return sigmoid(0.2 * x) + + @staticmethod + def get_xyz_end(chunk, chunk_cube_size, aff_shape): + """ + Returns the end indices of a chunk, that correspond either to the chunk size, or align with the size of the affinities. + """ + x, y, z = chunk + x_end, y_end, z_end = (min(x + chunk_cube_size, aff_shape[1]), + min(y + chunk_cube_size, aff_shape[2]), + min(z + chunk_cube_size, aff_shape[3])) + return (x_end, y_end, z_end) + + +class AffinityPredictor: + def __init__(self, + chunk_cube_size: int = 1024, + compute_backend: str = "local", + model: torch.nn.Module = None, + model_path: str = None, + small_size: int = 128, + do_overlap: bool = True, + prediction_channels: int = 6, + divide: int = 1, + ): + self.chunk_cube_size = chunk_cube_size + self.compute_backend = compute_backend + + self.model = model # only for local prediction + self.model_path = model_path # loads model in the worker in case of distributed inference (model not pickleable) + self.small_size = small_size + self.do_overlap = do_overlap + self.prediction_channels = prediction_channels + self.divide = divide + + def img_to_aff(self, img, zarr_path): + """ + Complete prediction of affinities from the input image, with the model previously specified in AffinityPredictor. 
+ """ + print(f"Performing patched inference with do_overlap={self.do_overlap} for img of shape {img.shape} and dtype {img.dtype}") + print(f"Parameters: cube size {self.chunk_cube_size}, compute backend {self.compute_backend}.") + + all_patch_coordinates = Utils.get_coordinates(img.shape[:3], self.small_size, overlap=self.small_size // 2 if self.do_overlap else 0, last_has_smaller_overlap=True) + chunked_patch_coordinates = Utils.chunk_xyzs(all_patch_coordinates, self.chunk_cube_size) + + z = zarr.open_group(zarr_path + "_tmp", mode='w') + zarr_chunk_size = min(self.chunk_cube_size, 512) + z.create_dataset('sum_pred', shape=(self.prediction_channels, *img.shape[:3]), chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') + z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') + + if self.compute_backend == "local": + for chunk in tqdm(chunked_patch_coordinates, desc="chunks"): + self.predict_aff_patches_chunked(chunk, img, zarr_path + "_tmp") + torch.cuda.empty_cache() + else: + if self.compute_backend == "local_cluster": + from dask_cuda import LocalCUDACluster + cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU + elif self.compute_backend == "slurm": + from dask_jobqueue import SLURMCluster + cluster = SLURMCluster( + cores=8, + memory="400GB", + processes=1, + worker_extra_args=["--resources processes=1", "--nthreads=1"], + job_extra_directives=["--gres=gpu:1"], + walltime="1-00:00:00" + ) + cluster.adapt(minimum_jobs=1, maximum_jobs=32) + + else: + raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") + + client = Client(cluster) + print(f"Waiting for workers...") + client.wait_for_workers(n_workers=1) + print("Dask Client Dashboard:", client.dashboard_link) + tasks = [dask.delayed(self.predict_aff_patches_chunked)(chunk, img, zarr_path + "_tmp") for chunk in chunked_patch_coordinates] + futures = persist(tasks) + 
progress(futures) # progress bar + compute(futures) + + tmp_sum_pred = da.from_zarr(f"{zarr_path}_tmp/sum_pred") + tmp_sum_weight = da.from_zarr(f"{zarr_path}_tmp/sum_weight") + aff = tmp_sum_pred / tmp_sum_weight + aff.to_zarr(zarr_path, overwrite=True) + + shutil.rmtree(zarr_path + "_tmp") + + return + + def predict_aff_patches_chunked(self, patch_coordinates, img, zarr_path): + """ + Patch-wise predicts affinities in-memory, using coordinates of all patches inside a chunk. + Args: + patch_coordinates: List of patch coordinates. The extension of the coordinates must fit in memory (use adequate chunk size). + Returns: + Affinity prediction of the input chunk. + """ + max_x = max(x for x, y, z in patch_coordinates) + max_y = max(y for x, y, z in patch_coordinates) + max_z = max(z for x, y, z in patch_coordinates) + min_x = min(x for x, y, z in patch_coordinates) + min_y = min(y for x, y, z in patch_coordinates) + min_z = min(z for x, y, z in patch_coordinates) + + img_tmp = img[ + min_x: max_x + self.small_size, + min_y: max_y + self.small_size, + min_z: max_z + self.small_size, + ] + pred_tmp = np.zeros((self.prediction_channels, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) + weight_tmp = np.zeros((1, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) + single_pred_weight = self.get_single_pred_weight(self.do_overlap, self.small_size) + + if not self.model: + from BANIS import BANIS + print(self.model_path, flush=True) + model = BANIS.load_from_checkpoint(self.model_path) + else: + model = self.model + + for x_global, y_global, z_global in tqdm(patch_coordinates, desc=f'cube ({min_x}, {max_x + self.small_size}), ({min_y}, {max_y + self.small_size}), ({min_z}, {max_z + self.small_size})'): + x = x_global - min_x + y = y_global - min_y + z = z_global - min_z + img_patch = torch.tensor(np.moveaxis(img_tmp[x: x + self.small_size, y: y + self.small_size, z: z + self.small_size], -1, 0)[None]).to(model.device) / 
self.divide + pred = Utils.scale_sigmoid(model(img_patch))[0, :self.prediction_channels] + + weight_tmp[:, x: x + self.small_size, y: y + self.small_size, z: z + self.small_size] += single_pred_weight if self.do_overlap else 1 + pred_tmp[:, x: x + self.small_size, y: y + self.small_size, z: z + self.small_size] += pred.detach().cpu().numpy() * (single_pred_weight[None] if self.do_overlap else 1) + + z = zarr.open_group(zarr_path, mode='a') + weight_mask = z['sum_weight'] + full_pred = z['sum_pred'] + + with FileLock(f"{zarr_path}/sum_weight.lock"): + weight_mask[ + :, + min_x: max_x + self.small_size, + min_y: max_y + self.small_size, + min_z: max_z + self.small_size, + ] += weight_tmp + + with FileLock(f"{zarr_path}/sum_pred.lock"): + full_pred[ + :, + min_x: max_x + self.small_size, + min_y: max_y + self.small_size, + min_z: max_z + self.small_size, + ] += pred_tmp + + def get_single_pred_weight(self, do_overlap: bool, small_size: int) -> Union[np.ndarray, None]: + """ + Get the weight for a single prediction. + + Args: + do_overlap: Whether to perform overlapping predictions. + small_size: The size of the patches. + + Returns: + The weight array for a single prediction, or None if no overlap. 
+ """ + if do_overlap: + # The weight (confidence/expected quality) of the predictions: + # Low at the surface of the predicted cube, high in the center + pred_weight_helper = np.pad(np.ones((small_size,) * 3), 1, mode='constant') + return distance_transform_cdt(pred_weight_helper).astype(np.float32)[1:-1, 1:-1, 1:-1] + else: + return None -def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: - """Scale sigmoid to avoid numerical issues in high confidence fp16.""" - return sigmoid(0.2 * x) - - -def measure_stats(func): - import os - import time - from datetime import timedelta - import tracemalloc - import threading - import psutil - - def monitor_memory(interval=0.1, result=None): - proc = psutil.Process(os.getpid()) - peak = 0 - while not getattr(monitor_memory, "stop", False): - rss = proc.memory_info().rss - peak = max(peak, rss) - time.sleep(interval) - if result is not None: - result["peak"] = peak # Save peak memory to shared dict - - def wrapper(*args, **kwargs): - memory_stats = {} - thread = threading.Thread(target=monitor_memory, kwargs={"interval": 0.1, "result": memory_stats}) - thread.start() - torch.cuda.reset_peak_memory_stats() - tracemalloc.start() - start = time.time() - - result = func(*args, **kwargs) - - end = time.time() - elapsed = timedelta(seconds=end - start) - current, peak = tracemalloc.get_traced_memory() - max_mem = torch.cuda.max_memory_reserved() - monitor_memory.stop = True - thread.join() - - stats = { - "time": f"{elapsed}", - "peak_python_mem": f"{peak / 1024**2:.2f} MB", - "max_cuda_mem": f"{max_mem / 1024 ** 2:.2f} MB", - "rss_mem": f"{memory_stats['peak'] / 1024 ** 2:.2f} MB" - } - - return result, stats - - return wrapper - - -@jit(nopython=True) -def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray: - """ - Compute connected components from affinities. - - Args: - hard_aff: The (thresholded, boolean) short range affinities. Shape: (3, x, y, z). - - Returns: - The segmentation. Shape: (x, y, z). 
- """ - visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean) - seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) - cur_id = 1 - for i in range(visited.shape[0]): - for j in range(visited.shape[1]): - for k in range(visited.shape[2]): - if hard_aff[:, i, j, k].any() and not visited[i, j, k]: # If foreground - cur_to_visit = [(i, j, k)] - visited[i, j, k] = True - while cur_to_visit: - x, y, z = cur_to_visit.pop() - seg[x, y, z] = cur_id - - # Check all neighbors - if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]: - cur_to_visit.append((x + 1, y, z)) - visited[x + 1, y, z] = True - if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]: - cur_to_visit.append((x, y + 1, z)) - visited[x, y + 1, z] = True - if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]: - cur_to_visit.append((x, y, z + 1)) - visited[x, y, z + 1] = True - if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]: - cur_to_visit.append((x - 1, y, z)) - visited[x - 1, y, z] = True - if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]: - cur_to_visit.append((x, y - 1, z)) - visited[x, y - 1, z] = True - if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]: - cur_to_visit.append((x, y, z - 1)) - visited[x, y, z - 1] = True - cur_id += 1 - return seg - - -@torch.no_grad() -@autocast(device_type="cuda") -def predict_aff( - img: Union[np.ndarray, zarr.Array], - model: torch.nn.Module = None, - model_path: str = None, - zarr_path: str = "aff_prediction.zarr", - small_size: int = 128, - do_overlap: bool = True, - prediction_channels: int = 6, - divide: int = 1, - chunk_cube_size: int = 1024, - compute_backend: str = "local" -): - """ - Perform patched affinity prediction with a model on an image. - - Args: - img: The input image. Shape: (x, y, z, channel). - model: The model to use for predictions (only for local prediction). 
- model_path: Path to the model checkpoint to use for predictions (if model not specified). - zarr_path: Output path to save the prediction in zarr format. - small_size: The size of the patches. Defaults to 128. - do_overlap: Whether to perform overlapping predictions. Defaults to True: - half of patch size for all 3 axes. - prediction_channels: The number of channels in the output (additional model output - dimensions are discarded). Defaults to 6 (3 short + 3 long range affinities). - divide: The divisor for the image. Typically, 1 or 255 if img in [0, 255] - chunk_cube_size: The maximal side length of a cube held in memory. - compute_backend: Type of computation / dask backend. One of: - - - "local": uses a cycle on the local machine (default) - - "local_cluster": uses a localGPUcluster to utilize all local GPUs without SLURM - - "slurm": uses a slurm cluster with all available nodes - - Returns: - The full prediction. Shape: (channel, x, y, z). - """ - print( - f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") - print(f"Parameters: cube size {chunk_cube_size}, compute backend {compute_backend}.") - - all_patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) - chunked_patch_coordinates = chunk_xyzs(all_patch_coordinates, chunk_cube_size) - - z = zarr.open_group(zarr_path + "_tmp", mode='w') - zarr_chunk_size = min(chunk_cube_size, 512) - z.create_dataset('sum_pred', shape=(prediction_channels, *img.shape[:3]), - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - - if compute_backend == "local": - if not model: - from BANIS import BANIS - model = BANIS.load_from_checkpoint(model_path) - for chunk in tqdm(chunked_patch_coordinates): - predict_aff_patches_chunked(chunk, img, model, zarr_path + "_tmp", small_size, 
do_overlap, prediction_channels, divide) - torch.cuda.empty_cache() # TODO: does this help? - else: - if compute_backend == "local_cluster": - from dask_cuda import LocalCUDACluster - cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU - elif compute_backend == "slurm": - from dask_jobqueue import SLURMCluster - cluster = SLURMCluster( - cores=8, - memory="400GB", - processes=1, - worker_extra_args=["--resources processes=1", "--nthreads=1"], - job_extra_directives=["--gres=gpu:1"], - walltime="1-00:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) +class Postprocessing: + def __init__(self, + chunk_cube_size: int = 1024, + compute_backend: str = "local" + ): + self.chunk_cube_size = chunk_cube_size + self.compute_backend = compute_backend + + def aff_to_seg(self, aff, zarr_path): + chunks = Utils.get_coordinates(aff.shape[1:], self.chunk_cube_size, overlap=1, last_has_smaller_overlap=False) + reverse_chunks = {chunk: i for i, chunk in enumerate(chunks)} + patched_zarr_path = zarr_path + "_tmp" + + zarr_chunk_size = min(self.chunk_cube_size, 512) + z_root = zarr.create(shape=(len(chunks), self.chunk_cube_size, self.chunk_cube_size, self.chunk_cube_size), + store=patched_zarr_path, dtype='i4', overwrite=True, + chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size)) + + # SEGMENT AFFINITIES IN CHUNKS THAT FIT IN MEMORY + self.patched_segment_affinities(aff, patched_zarr_path, chunks) + + # FIND GROUPS OF FRAGMENTS THAT SHOULD BE MERGED BETWEEN CHUNKS + fragment_agglomeration, max_id = self.agglomerate_fragments(chunks, reverse_chunks, patched_zarr_path, aff.shape) + + # MERGE AND RELABEL INSTANCES GLOBALLY + self.merge_and_relabel(fragment_agglomeration, max_id, patched_zarr_path, zarr_path, chunks, aff.shape) + + return + def patched_segment_affinities(self, aff, patched_zarr_path, chunks): + if self.compute_backend == "local": + for i, chunk in enumerate(tqdm(chunks)): + self.segment_chunk_wrapped(chunk, i, aff, patched_zarr_path) 
else: - raise NotImplementedError(f"Compute backend {compute_backend} not available.") - - client = Client(cluster) - print(f"Waiting for workers...") - client.wait_for_workers(n_workers=1) - print("Dask Client Dashboard:", client.dashboard_link) - tasks = [dask.delayed(predict_aff_patches_chunked)(chunk, img, model_path, zarr_path + "_tmp", small_size, do_overlap, prediction_channels, divide) - for chunk in chunked_patch_coordinates - ] - futures = persist(tasks) - progress(futures) # progress bar - compute(futures) - - tmp_sum_pred = da.from_zarr(f"{zarr_path}_tmp/sum_pred") - tmp_sum_weight = da.from_zarr(f"{zarr_path}_tmp/sum_weight") - aff = tmp_sum_pred / tmp_sum_weight - aff.to_zarr(zarr_path, overwrite=True) - - shutil.rmtree(zarr_path + "_tmp") - - return zarr.open(zarr_path, mode="r") - - -def get_coordinates( - shape: Tuple[int, int, int], small_size: int, do_overlap: bool -) -> List[Tuple[int, int, int]]: - """ - Get coordinates for cubes to be predicted. - - Args: - shape: The shape of the input image (x, y, z). - small_size: The size of the patches. - do_overlap: Whether to perform overlapping predictions. - - Returns: - List of (x, y, z) coordinates for prediction cubes. 
- """ - offsets = [get_offsets(s, small_size) for s in shape] - xyzs = [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] - if do_overlap: # Add shifted cubes (half cube overlap) - offset = small_size // 2 - - xyzs_shifted = [ - set((x + offset, y, z) for x, y, z in xyzs), - set((x, y + offset, z) for x, y, z in xyzs), - set((x, y, z + offset) for x, y, z in xyzs), - set((x + offset, y + offset, z) for x, y, z in xyzs), - set((x + offset, y, z + offset) for x, y, z in xyzs), - set((x, y + offset, z + offset) for x, y, z in xyzs), - set((x + offset, y + offset, z + offset) for x, y, z in xyzs), - ] - xyzs_shifted = set( - (x, y, z) - for s in xyzs_shifted - for x, y, z in s - if x + small_size <= shape[0] - and y + small_size <= shape[1] - and z + small_size <= shape[2] + if self.compute_backend == "local_cluster": + from dask_cuda import LocalCUDACluster + cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU + elif self.compute_backend == "slurm": + from dask_jobqueue import SLURMCluster + cluster = SLURMCluster( + cores=8, + memory="400GB", + processes=1, + worker_extra_args=["--resources processes=1", "--nthreads=1"], + job_extra_directives=["--gres=gpu:1"], + walltime="1-00:00:00" + ) + cluster.adapt(minimum_jobs=1, maximum_jobs=32) + else: + raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") + + client = Client(cluster) + print(f"Waiting for workers...") + client.wait_for_workers(n_workers=1) + print("Dask Client Dashboard:", client.dashboard_link) + tasks = [dask.delayed(self.segment_chunk_wrapped)(chunk, i, aff, patched_zarr_path) for (i, chunk) in enumerate(chunks)] + futures = persist(tasks) + progress(futures) # progress bar + compute(futures) + + def agglomerate_fragments(self, chunks, reverse_chunks, patched_zarr_path, aff_shape): + if self.compute_backend == "local": + fragment_agglomeration = {} + for i, chunk in enumerate(tqdm(chunks)): + chunk_agglomeration = 
self.agglomerate_chunk(chunk, reverse_chunks, patched_zarr_path, aff_shape) + for node, nbrs in chunk_agglomeration.items(): + for nbr in nbrs: + fragment_agglomeration.setdefault(node, set()).add(nbr) + if len(fragment_agglomeration) > 10_000_000: + print("WARNING: fragment agglomeration too long, might cause problems!") + # TODO: solve this + + curr_id, fragment_agglomeration_flattened = self.flatten_agglomeration(fragment_agglomeration) + #print("MERGING CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) + #fragment_agglomeration_flattened = self.add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, curr_id, chunks, patched_zarr_path) + #print("ALL CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) + + else: + # TODO: add slurm (and measure memory) + raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") + + return fragment_agglomeration_flattened, curr_id + + def agglomerate_chunk(self, chunk, reverse_chunks, patched_zarr_path, aff_shape): + fragment_agglomeration = {} + x, y, z = chunk + x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape) + z_root = zarr.open(patched_zarr_path, mode='r') + + # for (x,y,z) get the last slice of the current cube (l, low) and the first slice of the next cube (h, high) + # these slices overlap, so the voxels should have the same global id + + if x_end < aff_shape[1]: + chunk_l = reverse_chunks[chunk] + chunk_h = reverse_chunks[x + self.chunk_cube_size - 1, y, z] + result_l = z_root[chunk_l, -1:, :, :] + result_h = z_root[chunk_h, :1, :, :] + combined = np.stack([result_l.flatten(), result_h.flatten()]).T + uniques = np.unique(combined, axis=0) + fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) + + if y_end < aff_shape[2]: + chunk_l = reverse_chunks[chunk] + chunk_h = reverse_chunks[x, y + self.chunk_cube_size - 1, z] + result_l = 
z_root[chunk_l, :, -1:, :] + result_h = z_root[chunk_h, :, :1, :] + combined = np.stack([result_l.flatten(), result_h.flatten()]).T + uniques = np.unique(combined, axis=0) + fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) + + if z_end < aff_shape[3]: + chunk_l = reverse_chunks[chunk] + chunk_h = reverse_chunks[x, y, z + self.chunk_cube_size - 1] + result_l = z_root[chunk_l, :, :, -1:] + result_h = z_root[chunk_h, :, :, :1] + combined = np.stack([result_l.flatten(), result_h.flatten()]).T + uniques = np.unique(combined, axis=0) + fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) + + return fragment_agglomeration + + def update_fragment_agglomeration(self, fragment_agglomeration, uniques, chunk_l, chunk_h): + for id_l, id_h in uniques: + if id_l > 0 and id_h > 0: + fragment_agglomeration.setdefault((chunk_h, id_h), set()).add( + (chunk_l, id_l) + ) + fragment_agglomeration.setdefault((chunk_l, id_l), set()).add( + (chunk_h, id_h) + ) + return fragment_agglomeration + + def flatten_agglomeration(self, fragment_agglomeration): + """ + Computes connected components in the fragment agglomeration graph, and assigns the fragments new ids starting from 1. 
def merge_and_relabel(self, fragment_agglomeration, max_id, zarr_patched, zarr_final, chunks, aff_shape):
    """
    Write the final segmentation by translating per-chunk fragment ids into global ids.

    Args:
        fragment_agglomeration: dictionary with keys (chunk_index, fragment_id) and values the global id.
        max_id: running id counter. It is incremented before being assigned to every fragment
            that has no entry in `fragment_agglomeration`.
        zarr_patched: path of the intermediate store holding one labelled cube per chunk; the first
            axis is indexed by the position of the chunk in `chunks`. Removed after a successful run.
        zarr_final: store (path) the relabelled volume is written to; an existing store is overwritten.
        chunks: sequence of (x, y, z) chunk start coordinates.
        aff_shape: shape of the affinity volume; `aff_shape[1:]` is used as the output shape.
    Raises:
        NotImplementedError: if `self.compute_backend` is not "local".
    """
    # Check the backend before touching any store: the output is created with overwrite=True,
    # so failing later would already have clobbered an existing result.
    if self.compute_backend != "local":
        raise NotImplementedError(f"Compute backend {self.compute_backend} not implemented.")

    zarr_chunk_size = min(self.chunk_cube_size, 512)
    z_root = zarr.open(zarr_patched, mode="r")  # only read here
    z_final = zarr.create(
        shape=aff_shape[1:],
        store=zarr_final,
        dtype='i4',
        overwrite=True,
        chunks=(zarr_chunk_size, zarr_chunk_size, zarr_chunk_size),
    )

    for i, chunk in enumerate(tqdm(chunks)):
        x, y, z = chunk
        x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape)
        data = z_root[i, : x_end - x, : y_end - y, : z_end - z]

        # perm[local_id] -> global id; local id 0 stays 0.
        perm = [0]
        for idx in range(1, int(data.max()) + 1):  # assuming each chunk has contiguous indices from 0 to max
            key = (i, idx)
            if key not in fragment_agglomeration:
                # Fragment without an entry in the mapping: give it a fresh id.
                max_id += 1
                perm.append(max_id)
            else:
                perm.append(fragment_agglomeration[key])
        perm = np.array(perm, dtype=np.uint64)
        # NOTE(review): perm is uint64 while the output store is 'i4'; the cast happens on write - confirm ids fit.
        z_final[x: x_end, y: y_end, z: z_end] = perm[data]

    shutil.rmtree(zarr_patched)
class MutexWatershed(Postprocessing):
    """
    Postprocessing that segments affinity cubes with `mwatershed.agglom`.

    Channels 0-2 of the affinities are used as short-range edges (offset 1 along each axis) and
    channels 3-5 as long-range edges (offset `long_range` along each axis).
    """

    def __init__(self, chunk_cube_size, compute_backend, mws_bias_short, mws_bias_long, long_range=10):
        """
        Args:
            chunk_cube_size: forwarded to `Postprocessing`.
            compute_backend: forwarded to `Postprocessing`.
            mws_bias_short: value added to the short-range affinities (channels 0-2).
            mws_bias_long: value added to the long-range affinities (channels 3 and up).
            long_range: voxel offset of the long-range affinities.
        """
        super().__init__(chunk_cube_size, compute_backend)
        self.mws_bias_short = mws_bias_short
        self.mws_bias_long = mws_bias_long
        self.long_range = long_range

    def compute_mws_segmentation(self, cur_aff):
        """
        In-memory mutex watershed segmentation of one cube of affinities.

        Args:
            cur_aff: affinities, short-range channels first; the input is not modified.
        Returns:
            uint32 segmentation in which 0 stays 0 and single-voxel objects have been removed.
        Raises:
            ValueError: if the number of segments does not fit into uint32.
        """
        # One float64 working copy (np.array copies by default); the caller's array stays untouched.
        cur_aff = np.array(cur_aff, dtype=np.float64)
        cur_aff[:3] += self.mws_bias_short
        cur_aff[3:] += self.mws_bias_long

        # Clip in place to avoid further cube-sized temporaries.
        np.clip(cur_aff[:3], 0, 1, out=cur_aff[:3])  # short-range attractive edges
        np.clip(cur_aff[3:], -1, 0, out=cur_aff[3:])  # long-range repulsive edges (see the Mutex Watershed paper)

        offsets = [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [self.long_range, 0, 0],
            [0, self.long_range, 0],
            [0, 0, self.long_range],
        ]
        mws_pred = mwatershed.agglom(affinities=cur_aff, offsets=offsets)

        # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again.
        # Size filter: single voxel objects are irrelevant for merging, take ~95% of IDs in an example cube,
        # causing OOM when creating fragment_agglomeration.
        dusted = cc3d.dust(  # does a cc first (reducing false mergers in add_to_agglomeration)
            mws_pred,
            threshold=2,
            connectivity=6,
            in_place=False,
        )
        # Relabeling to save IDs.
        pred_relabeled, N = cc3d.connected_components(dusted, return_N=True, connectivity=6)

        assert (pred_relabeled[mws_pred == 0] == 0).all()  # internal sanity check: 0 stays 0
        # Explicit check instead of assert, so the guard also runs under `python -O`.
        if N > np.iinfo(np.uint32).max:
            raise ValueError(f"Number of segments ({N}) does not fit into uint32.")

        return pred_relabeled.astype(np.uint32)

    def segment_chunk(self, curr_aff):
        """Segment one chunk of affinities with mutex watershed (see `compute_mws_segmentation`)."""
        return self.compute_mws_segmentation(curr_aff)
+ """ + visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean) + seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) + cur_id = 1 + for i in range(visited.shape[0]): + for j in range(visited.shape[1]): + for k in range(visited.shape[2]): + if hard_aff[:, i, j, k].any() and not visited[i, j, k]: # If foreground + cur_to_visit = [(i, j, k)] + visited[i, j, k] = True + while cur_to_visit: + x, y, z = cur_to_visit.pop() + seg[x, y, z] = cur_id + + # Check all neighbors + if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]: + cur_to_visit.append((x + 1, y, z)) + visited[x + 1, y, z] = True + if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]: + cur_to_visit.append((x, y + 1, z)) + visited[x, y + 1, z] = True + if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]: + cur_to_visit.append((x, y, z + 1)) + visited[x, y, z + 1] = True + if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]: + cur_to_visit.append((x - 1, y, z)) + visited[x - 1, y, z] = True + if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]: + cur_to_visit.append((x, y - 1, z)) + visited[x, y - 1, z] = True + if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]: + cur_to_visit.append((x, y, z - 1)) + visited[x, y, z - 1] = True + cur_id += 1 + return seg + + def segment_chunk(self, curr_aff): + return self.compute_connected_component_segmentation(curr_aff[:3] > self.thr) def full_inference( + # RESOURCES ARGUMENTS: + chunk_cube_size: int = 3000, + compute_backend: str = "local", # AFFINITY PREDICTION ARGUMENTS: - img: Union[np.ndarray, zarr.Array], - model_path: str, + img: Union[np.ndarray, zarr.Array] = None, + model_path: str = None, aff_zarr_path: str = "aff_prediction.zarr", small_size: int = 128, do_overlap: bool = True, prediction_channels: int = 6, divide: int = 1, - chunk_cube_size: int = 1024, - compute_backend: str = "local", # 
def str_to_bool(value):
    """
    Parse a command-line boolean such as "true"/"false", "yes"/"no", "1"/"0" (case-insensitive).

    argparse's `type=bool` calls `bool()` on the raw string, so every non-empty string, including
    "False", would become True. The empty string is kept as False, which is what `bool("")` gave.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--chunk_cube_size", type=int, default=3000, help="The maximal side length of a cube held in memory.")
    parser.add_argument("--compute_backend", type=str, default="local", help="Compute backend to use: local, slurm, or local_cluster.")
    # Both paths are mandatory: without them the script would only fail later on a None path.
    parser.add_argument("--img_path", type=str, required=True, help="The image to segment (path to zarr).")
    parser.add_argument("--model_path", type=str, required=True, help="The path to the trained model.")
    parser.add_argument("--aff_zarr_path", type=str, default="aff_prediction.zarr", help="Where to save the predicted affinities.")
    parser.add_argument("--small_size", type=int, default=128, help="Size of the small patches for affinity prediction (model parameter).")
    # str_to_bool instead of bool, so that "--do_overlap False" really disables overlapping.
    parser.add_argument("--do_overlap", type=str_to_bool, default=True, help="Use overlapping patches for affinity prediction for better precision.")
    parser.add_argument("--prediction_channels", type=int, default=6, help="The number of prediction channels. Defaults to 6 (3 short + 3 long range affinities).")
    parser.add_argument("--divide", type=int, default=255, help="The divisor for the image. Typically, 1 or 255 if img in [0, 255].")
    parser.add_argument("--postprocessing_type", type=str, default="thresholding", help="Type of postprocessing to use: thresholding, or mws (mutex watershed).")
    parser.add_argument("--seg_zarr_path", type=str, default="seg_prediction.zarr", help="Where to save the final segmentation.")
    parser.add_argument("--thr", type=float, default=0.5, help="Threshold in case of thresholding.")
    parser.add_argument("--mws_bias_short", type=float, default=-0.5, help="Short-range bias for mutex watershed.")
    parser.add_argument("--mws_bias_long", type=float, default=-0.5, help="Long-range bias for mutex watershed.")

    args = parser.parse_args()

    # The image is read from the "img" array of the given zarr group.
    img = zarr.open(args.img_path, mode="r")["img"]
    full_inference(
        chunk_cube_size=args.chunk_cube_size,
        compute_backend=args.compute_backend,
        img=img,
        model_path=args.model_path,
        aff_zarr_path=args.aff_zarr_path,
        small_size=args.small_size,
        do_overlap=args.do_overlap,
        prediction_channels=args.prediction_channels,
        divide=args.divide,
        postprocessing_type=args.postprocessing_type,
        seg_zarr_path=args.seg_zarr_path,
        thr=args.thr,
        mws_bias_short=args.mws_bias_short,
        mws_bias_long=args.mws_bias_long,
    )