From 55aca42ff6b48921f72eace08f4e14acd9aceb48 Mon Sep 17 00:00:00 2001 From: LlambiasMBP Date: Thu, 29 Jan 2026 15:24:27 +0100 Subject: [PATCH] add missing param to auto-run inference --- .../lightning_modules/segmentation_module.py | 8 ++--- asparagus/pipeline/run/test_seg.py | 1 - asparagus/pipeline/run/train_seg.py | 1 + .../development/runs_train/DEBUG_SEG.yaml | 30 +++++++++++++++++++ pyproject.toml | 2 +- 5 files changed, 34 insertions(+), 8 deletions(-) create mode 100644 configs/projects/development/runs_train/DEBUG_SEG.yaml diff --git a/asparagus/modules/lightning_modules/segmentation_module.py b/asparagus/modules/lightning_modules/segmentation_module.py index bdb01c0..b682bcf 100644 --- a/asparagus/modules/lightning_modules/segmentation_module.py +++ b/asparagus/modules/lightning_modules/segmentation_module.py @@ -219,12 +219,10 @@ def on_test_epoch_start(self): def test_step(self, batch, batch_idx): x = batch["image"] - logits = self.model.predict( - mode=self.inference_mode, + logits = self.model.sliding_window_predict( data=x, patch_size=fit_patch_size_to_image_size(self.inference_patch_size, list(x.shape[2:])), overlap=0.5, - sliding_window_prediction=True, ) src_logits = reverse_preprocessing(logits, batch["properties"]) @@ -249,12 +247,10 @@ def on_test_epoch_end(self): def predict_step(self, batch, batch_idx): x = batch["image"] - logits = self.model.predict( - mode=self.inference_mode, + logits = self.model.sliding_window_predict( data=x, patch_size=self.inference_patch_size, overlap=0.5, - sliding_window_prediction=True, ) logits = reverse_preprocessing( array=logits, diff --git a/asparagus/pipeline/run/test_seg.py b/asparagus/pipeline/run/test_seg.py index 0458f00..de8aaf2 100644 --- a/asparagus/pipeline/run/test_seg.py +++ b/asparagus/pipeline/run/test_seg.py @@ -51,7 +51,6 @@ def main(cfg: DictConfig) -> None: ckpt_cfg.lightning._lightning_module, model=model, weights=path_store.ckpt_path, - inference_mode=ckpt_cfg.model.dimensions, inference_patch_size=ckpt_cfg.training.patch_size, test_output_path=output_path, ) diff --git a/asparagus/pipeline/run/train_seg.py b/asparagus/pipeline/run/train_seg.py index b352549..90c7d7c 100644 --- a/asparagus/pipeline/run/train_seg.py +++ b/asparagus/pipeline/run/train_seg.py @@ -115,6 +115,7 @@ def main(cfg: DictConfig) -> None: optimizer=cfg.model.train_optim, learning_rate=cfg.model.train_lr, deep_supervision=cfg.model.deep_supervision, + inference_patch_size=cfg.training.patch_size, test_output_path=os.path.join( path_store.run_dir, "predictions", diff --git a/configs/projects/development/runs_train/DEBUG_SEG.yaml b/configs/projects/development/runs_train/DEBUG_SEG.yaml new file mode 100644 index 0000000..80cfdd7 --- /dev/null +++ b/configs/projects/development/runs_train/DEBUG_SEG.yaml @@ -0,0 +1,30 @@ +# @package _global_ +defaults: + - /core/base@ + - /default_train_seg@ + - /hardware/cpu@hardware + - /model/unet_tiny@model + - _self_ + +task: +root: base +stem: debug + +model: + dimensions: 3D + +training: + patch_size: [32, 32, 32] + batch_size: 2 + train_batches_per_epoch_per_device: 10 + val_batches_per_epoch_per_device: 5 + epochs: 5 + +logger: + progress_bar: True + profile: False + wandb_log_model: False + wandb_logging: False + mlflow_logging: False + log_every_n_steps: 1 + log_images_every_n_epoch: 999999 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 46fcdce..0624e10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ keywords = ['deep learning', 'medical image analysis','foundation models'] dependencies = [ - "gardening_tools>=0.1.1", + "gardening_tools>=0.2.0", "lightning==2.4.0", "nibabel>=5.3.2", "numpy>=1.23.1",