Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a8710fe
add long training option
Nov 12, 2024
ecc27ed
validate in another process in parallel, fix all bugs
Dec 9, 2024
421abb1
forgotten global_step
Dec 9, 2024
10b6eab
fix time for validation
Dec 11, 2024
99f1cd1
separate validation processes for each checkpoint
Dec 11, 2024
a24c2c1
fix validation run
Dec 18, 2024
b1d78c0
problems with child job
Dec 18, 2024
18ba725
debug OOM, better resubmit
Feb 5, 2025
7f51dc4
sort BANIS arguments
May 23, 2025
6bf836d
ignore slurm logs
May 23, 2025
7d78676
one script to rule them all (long training also possible from slurm_j…
May 23, 2025
45f2ecf
fix exp_name
May 24, 2025
488ac20
monai determinism
May 24, 2025
083e49e
worker init fn
May 24, 2025
6954b51
xl long distributed training
May 24, 2025
edf4f0b
fix bugs
May 24, 2025
c34142f
gradient clippinng, logging activations etc, deactivate augmentations
Jun 13, 2025
7cc87fd
correct worker init function (default)
Jun 30, 2025
5abdc75
config
Jul 2, 2025
585b42b
wip: LocalCluster
Jul 14, 2025
82a832f
add adapted nerl
Aug 1, 2025
2732401
inference
Aug 9, 2025
bebf90d
inference: local or slurm computation, memory management via cube size
Aug 14, 2025
d9f2670
testing memory-efficient segmentation
Aug 29, 2025
4ced80f
class-based inference
ZuzkaU Aug 29, 2025
1c62b46
fix bug
ZuzkaU Aug 29, 2025
6141a84
dont jit class
ZuzkaU Aug 29, 2025
807b2e8
autocomplete wrong
ZuzkaU Aug 29, 2025
7387843
update and debug and profile
ZuzkaU Aug 31, 2025
03dc977
object-based inference
Sep 10, 2025
11a9075
forgotten file
Sep 10, 2025
288a6c7
return to original config
Sep 10, 2025
1936764
merge back main
Sep 10, 2025
cdbfb37
new inference modules
ZuzkaU Sep 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
slurm-*.out
*.zarr

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
35 changes: 23 additions & 12 deletions BANIS.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import argparse
import gc
import os
import shutil

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict
import random

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
Expand All @@ -23,7 +23,7 @@
from tqdm import tqdm

from data import load_data
from inference import scale_sigmoid, predict_aff, compute_connected_component_segmentation
from inference import AffinityPredictor, Thresholding
from metrics import compute_metrics


Expand All @@ -35,7 +35,7 @@ class BANIS(LightningModule):
def __init__(self, **kwargs: Any):
super().__init__()
self.save_hyperparameters()
print(f"hparams: \n{self.hparams}")
# print(f"hparams: \n{self.hparams}")

self.model = create_mednext_v1(
num_input_channels=self.hparams.num_input_channels,
Expand Down Expand Up @@ -163,8 +163,17 @@ def full_cube_inference(self, mode: str, global_step=None):

img_data = zarr.open(os.path.join(seed_path, "data.zarr"), mode="r")["img"]

aff_pred = predict_aff(img_data, model=self, zarr_path=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", do_overlap=True, prediction_channels=3, divide=255,
small_size=self.hparams.small_size, compute_backend="local")
affinity_predictor = AffinityPredictor(
chunk_cube_size=3000, # can be adjusted
compute_backend="local",
model=self,
small_size=self.hparams.small_size,
do_overlap=True,
prediction_channels=3,
divide=255,
)
affinity_predictor.img_to_aff(img_data, zarr_path=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr")
aff_pred = zarr.open(f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", mode="r")

self._evaluate_thresholds(aff_pred, os.path.join(seed_path, "skeleton.pkl"), mode, global_step)

Expand All @@ -179,9 +188,9 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str,
torch.cuda.empty_cache()
print(f"threshold {thr}")

pred_seg = compute_connected_component_segmentation(
aff_pred[:3] > thr # hard affinities
)
postprocessor = Thresholding(3000, "local", thr)
postprocessor.aff_to_seg(aff_pred, f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr")
pred_seg = zarr.open(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr", mode="r")

metrics = compute_metrics(pred_seg, skel_path)

Expand All @@ -201,9 +210,11 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str,
self.best_thr_so_far[mode] = thr
with open(f"{self.hparams.save_dir}/best_thr_{mode}.txt", "w") as f:
f.write(str(self.best_thr_so_far[mode]))
seg_pred = zarr.array(pred_seg, dtype=np.uint32,
store=f"{self.hparams.save_dir}/pred_seg_{mode}.zarr",
chunks=(512, 512, 512), overwrite=True)
if os.path.exists(f"{self.hparams.save_dir}/pred_seg_{mode}.zarr"):
shutil.rmtree(f"{self.hparams.save_dir}/pred_seg_{mode}.zarr")
os.replace(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr", f"{self.hparams.save_dir}/pred_seg_{mode}.zarr")
else:
shutil.rmtree(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr")
best_voi = min(best_voi, metrics["voi_sum"])

self.safe_add_scalar(f"{mode}_best_nerl", best_nerl, global_step)
Expand Down
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ python slurm_job_scheduler.py

Adding an `auto_resubmit` argument to `config.yaml` allows Slurm to automatically resubmit jobs that reach the Slurm time limit (see `aff_train.sh`).

## Prediction

To predict segmentation from an image:

```bash
python inference --img_path /path/to/image.zarr --model_path /path/to/model.ckpt --chunk_cube_size 3000
```

The `chunk_cube_size` parameter sets the maximum cube size that can be loaded in memory.
If you have enough memory available, set it to a bigger value, if you are tight with memory, set a lower value (in exchange for increased computation time).
See [inference.py](inference.py) for other parameters.

## Evaluation

To evaluate a predicted segmentation (`.zarr` or `.npy`):
Expand Down
12 changes: 6 additions & 6 deletions aff_train.sh
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#!/bin/bash -l

#SBATCH --nodes=2
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --ntasks-per-node=1
#SBATCH --time=7-00
#SBATCH --cpus-per-task=16
#SBATCH --mem=1000G
#SBATCH --cpus-per-task=32
#SBATCH --mem=500G
#SBATCH --signal=B:USR1@300
#SBATCH --open-mode=append
#SBATCH --partition=p.large
#SBATCH --partition=p.share

mamba activate nisb

Expand Down
49 changes: 25 additions & 24 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ params:
- 1e-2
seed:
- 0
#- 1
#- 2
#- 3
#- 4
- 1
- 2
- 3
- 4
long_range:
- 10
batch_size:
- 1
- 8
scheduler:
- true
model_id:
- "L"
- "S"
kernel_size:
- 5
- 3
synthetic:
- 1.0
drop_slice_prob:
Expand All @@ -32,33 +32,34 @@ params:
affine:
- 0.5
n_steps:
- 1_000_000
- 50000
small_size:
- 256
- 128
data_setting:
#- "base"
#- "liconn"
#- "multichannel"
#- "neg_guidance"
#- "no_touch_thick"
#- "pos_guidance"
#- "slice_perturbed"
#- "touching_thin"
- "base"
- "liconn"
- "multichannel"
- "neg_guidance"
- "no_touch_thick"
- "pos_guidance"
- "slice_perturbed"
- "touching_thin"
- "train_100"
base_data_path:
- "/cajal/nvmescratch/projects/NISB/"
save_path:
#- "/cajal/scratch/projects/misc/riegerfr/aff_nis/"
- "/cajal/scratch/projects/misc/zuzur/xl_banis"
- "/cajal/scratch/projects/misc/riegerfr/aff_nis/"
exp_name:
- "xl_test"
- "exp"
real_data_path: #https://colab.research.google.com/github/funkelab/lsd/blob/master/lsd/tutorial/notebooks/lsd_data_download.ipynb
- "/cajal/scratch/projects/misc/mdraw/data/funke/zebrafinch/training/"
auto_resubmit:
- True
- False
distributed:
- True
compile:
- False
compile:
- True
validate_extern:
- True
- True
augment:
- True
5 changes: 5 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- bzip2=1.0.8
- ca-certificates=2024.8.30
- cython==3.0.11
- dask=2025.7.0
- ld_impl_linux-64=2.43
- libexpat=2.6.3
- libffi=3.4.2
Expand Down Expand Up @@ -40,9 +41,11 @@ dependencies:
- batchgenerators==0.25
- certifi==2024.8.30
- charset-normalizer==3.4.0
- cloud_volume==12.4.1
- connected-components-3d==3.19.0
- contourpy==1.3.0
- cycler==0.12.1
- dask_jobqueue==0.9.0
- dicom2nifti==2.5.0
- fasteners==0.19
- filelock==3.16.1
Expand All @@ -68,7 +71,9 @@ dependencies:
- monai==1.3.2
- mpmath==1.3.0
- multidict==6.1.0
- mwatershed==0.5.3
- networkx==3.3
- neuroglancer==2.40.1
- nibabel==5.3.0
- numba==0.60.0
- numcodecs==0.13.1
Expand Down
Loading