Skip to content

Commit 55aca42

Browse files
committed
add missing param to auto-run inference
1 parent 4d92a40 commit 55aca42

5 files changed

Lines changed: 34 additions & 8 deletions

File tree

asparagus/modules/lightning_modules/segmentation_module.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,10 @@ def on_test_epoch_start(self):
219219
def test_step(self, batch, batch_idx):
220220
x = batch["image"]
221221

222-
logits = self.model.predict(
223-
mode=self.inference_mode,
222+
logits = self.model.sliding_window_predict(
224223
data=x,
225224
patch_size=fit_patch_size_to_image_size(self.inference_patch_size, list(x.shape[2:])),
226225
overlap=0.5,
227-
sliding_window_prediction=True,
228226
)
229227

230228
src_logits = reverse_preprocessing(logits, batch["properties"])
@@ -249,12 +247,10 @@ def on_test_epoch_end(self):
249247

250248
def predict_step(self, batch, batch_idx):
251249
x = batch["image"]
252-
logits = self.model.predict(
253-
mode=self.inference_mode,
250+
logits = self.model.sliding_window_predict(
254251
data=x,
255252
patch_size=self.inference_patch_size,
256253
overlap=0.5,
257-
sliding_window_prediction=True,
258254
)
259255
logits = reverse_preprocessing(
260256
array=logits,

asparagus/pipeline/run/test_seg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def main(cfg: DictConfig) -> None:
5151
ckpt_cfg.lightning._lightning_module,
5252
model=model,
5353
weights=path_store.ckpt_path,
54-
inference_mode=ckpt_cfg.model.dimensions,
5554
inference_patch_size=ckpt_cfg.training.patch_size,
5655
test_output_path=output_path,
5756
)

asparagus/pipeline/run/train_seg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def main(cfg: DictConfig) -> None:
115115
optimizer=cfg.model.train_optim,
116116
learning_rate=cfg.model.train_lr,
117117
deep_supervision=cfg.model.deep_supervision,
118+
inference_patch_size=cfg.training.patch_size,
118119
test_output_path=os.path.join(
119120
path_store.run_dir,
120121
"predictions",
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# @package _global_
2+
defaults:
3+
- /core/base@
4+
- /default_train_seg@
5+
- /hardware/cpu@hardware
6+
- /model/unet_tiny@model
7+
- _self_
8+
9+
task:
10+
root: base
11+
stem: debug
12+
13+
model:
14+
dimensions: 3D
15+
16+
training:
17+
patch_size: [32, 32, 32]
18+
batch_size: 2
19+
train_batches_per_epoch_per_device: 10
20+
val_batches_per_epoch_per_device: 5
21+
epochs: 5
22+
23+
logger:
24+
progress_bar: True
25+
profile: False
26+
wandb_log_model: False
27+
wandb_logging: False
28+
mlflow_logging: False
29+
log_every_n_steps: 1
30+
log_images_every_n_epoch: 999999

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ keywords = ['deep learning', 'medical image analysis','foundation models']
1616

1717

1818
dependencies = [
19-
"gardening_tools>=0.1.1",
19+
"gardening_tools>=0.2.0",
2020
"lightning==2.4.0",
2121
"nibabel>=5.3.2",
2222
"numpy>=1.23.1",

0 commit comments

Comments
 (0)