Skip to content

Commit cc37244

Browse files
committed
Add FOMO25 Templates and polish current functional templates
1 parent 1b810fd commit cc37244

5 files changed

Lines changed: 174 additions & 35 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from yucca.pipeline.managers.YuccaManagerV2 import YuccaManagerV2
2+
3+
manager = YuccaManagerV2
4+
modelfile = "YOUR_PATH_HERE.ckpt"
5+
experiment = "default"
6+
model_name = "UNet"
7+
model_dimensions = "3D"
8+
source_task = "YOUR_TASK_HERE"
9+
planner = "YuccaPlannerV2"
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from yucca.modules.callbacks.prediction_writer import WriteSinglePredictionFromLogits
2+
3+
4+
def predict(input_path, output_path):
5+
from config import (
6+
manager,
7+
modelfile,
8+
experiment,
9+
model_name,
10+
model_dimensions,
11+
source_task,
12+
planner,
13+
)
14+
from yucca.modules.data.datasets.alternative_datasets.SingleFileDataset import (
15+
SingleFileTestDataset,
16+
)
17+
18+
output_path = output_path.replace(".nii.gz", "")
19+
manager = manager(
20+
ckpt_path=modelfile,
21+
enable_logging=False,
22+
experiment=experiment,
23+
num_workers=0,
24+
model_name=model_name,
25+
model_dimensions=model_dimensions,
26+
task=source_task,
27+
planner=planner,
28+
)
29+
30+
manager.batch_size = 1
31+
manager.test_dataset_class = SingleFileTestDataset
32+
manager.initialize(
33+
stage="predict",
34+
disable_tta=True,
35+
disable_inference_preprocessing=False,
36+
overwrite_predictions=True,
37+
pred_data_dir=input_path,
38+
prediction_output_dir=output_path,
39+
save_softmax=False,
40+
)
41+
42+
manager.trainer.callbacks = [
43+
WriteSinglePredictionFromLogits(
44+
output_path=output_path,
45+
multilabel=False,
46+
save_softmax=False,
47+
write_interval="batch",
48+
)
49+
]
50+
51+
manager.trainer.predict(
52+
model=manager.model_module,
53+
dataloaders=manager.data_module,
54+
ckpt_path=modelfile,
55+
return_predictions=False,
56+
)
57+
58+
59+
if __name__ == "__main__":
60+
import argparse
61+
62+
parser = argparse.ArgumentParser()
63+
parser.add_argument(
64+
"--inputpath",
65+
help="The path to a SINGLE file",
66+
default="",
67+
)
68+
parser.add_argument(
69+
"--outputpath",
70+
default="",
71+
)
72+
args = parser.parse_args()
73+
74+
predict(args.inputpath, args.outputpath)

yucca/documentation/templates/functional_inference.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import torch
55
from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p as ensure_dir_exists
66
from yucca.paths import (
7-
get_models_path,
8-
get_results_path,
97
get_preprocessed_data_path,
108
get_raw_data_path,
119
)
@@ -14,37 +12,15 @@
1412
from yucca.modules.data.data_modules.YuccaDataModule import YuccaDataModule
1513
from yucca.modules.data.datasets.YuccaDataset import YuccaTestPreprocessedDataset
1614
from yucca.pipeline.evaluation.YuccaEvaluator import YuccaEvaluator
17-
from yucca.documentation.templates.template_config import config
18-
19-
ckpt_path = os.path.join(
20-
get_models_path(),
21-
config["task"],
22-
config["model_name"] + "__" + config["model_dimensions"],
23-
"__" + config["config_name"],
24-
"default",
25-
"kfold_5_fold_0",
26-
"version_0",
27-
"checkpoints",
28-
"last.ckpt",
29-
)
15+
from yucca.documentation.templates.template_config import config, ckpt_path, inference_save_path
3016

3117
gt_path = os.path.join(get_raw_data_path(), config["task"], "labelsTs")
3218
target_data_path = os.path.join(get_preprocessed_data_path(), config["task"] + "_test", "demo")
3319

34-
save_path = os.path.join(
35-
get_results_path(),
36-
config["task"],
37-
config["task"],
38-
config["model_name"] + "__" + config["model_dimensions"],
39-
"__" + config["config_name"],
40-
"kfold_5_fold_0",
41-
"version_0",
42-
"best",
43-
)
44-
ensure_dir_exists(save_path)
20+
ensure_dir_exists(inference_save_path)
4521

4622
ckpt = torch.load(ckpt_path, map_location="cpu")
47-
pred_writer = WritePredictionFromLogits(output_dir=save_path, save_softmax=False, write_interval="batch")
23+
pred_writer = WritePredictionFromLogits(output_dir=inference_save_path, save_softmax=False, write_interval="batch")
4824

4925
model_module = BaseLightningModule(
5026
model=config["model"],
@@ -59,7 +35,7 @@
5935
data_module = YuccaDataModule(
6036
batch_size=config["batch_size"],
6137
patch_size=config["patch_size"],
62-
pred_save_dir=save_path,
38+
pred_save_dir=inference_save_path,
6339
pred_data_dir=target_data_path,
6440
overwrite_predictions=True,
6541
image_extension=".nii.gz",
@@ -82,7 +58,7 @@
8258
evaluator = YuccaEvaluator(
8359
labels=config["classes"],
8460
folder_with_ground_truth=gt_path,
85-
folder_with_predictions=save_path,
61+
folder_with_predictions=inference_save_path,
8662
use_wandb=False,
8763
)
8864
evaluator.run()

yucca/documentation/templates/template_config.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
import os
12
from yucca.modules.networks.networks import TinyUNet
23
from yucca.modules.optimization.loss_functions.nnUNet_losses import DiceCE
4+
from yucca.paths import (
5+
get_models_path,
6+
get_results_path,
7+
get_preprocessed_data_path,
8+
get_raw_data_path,
9+
)
310

411
model = TinyUNet
12+
classes = [0, 1]
13+
modalities = ("MRI",)
514

615
config = {
716
"batch_size": 2,
8-
"classes": [0, 1],
17+
"classes": classes,
918
"config_name": "demo",
1019
"crop_to_nonzero": True,
1120
"continue_from_most_recent": True,
@@ -15,11 +24,14 @@
1524
"learning_rate": 1e-3,
1625
"loss_fn": DiceCE,
1726
"max_epochs": 2,
18-
"modalities": ("MRI",),
27+
"modalities": modalities,
1928
"model_dimensions": "2D",
20-
"model": TinyUNet,
29+
"model": model,
30+
"model_name": model.__name__,
2131
"momentum": 0.99,
2232
"norm_op": "volume_wise_znorm",
33+
"num_classes": len(classes),
34+
"num_modalities": len(modalities),
2335
"patch_size": (32, 32),
2436
"plans": None,
2537
"split_idx": 0,
@@ -32,6 +44,26 @@
3244
"task_type": "segmentation",
3345
}
3446

35-
config["model_name"] = config["model"].__name__
36-
config["num_classes"] = len(config["classes"])
37-
config["num_modalities"] = len(config["modalities"])
47+
48+
ckpt_path = os.path.join(
49+
get_models_path(),
50+
config["task"],
51+
config["model_name"] + "__" + config["model_dimensions"],
52+
"__" + config["config_name"],
53+
"default",
54+
"kfold_5_fold_0",
55+
"version_0",
56+
"checkpoints",
57+
"last.ckpt",
58+
)
59+
60+
inference_save_path = os.path.join(
61+
get_results_path(),
62+
config["task"],
63+
config["task"],
64+
config["model_name"] + "__" + config["model_dimensions"],
65+
"__" + config["config_name"],
66+
"kfold_5_fold_0",
67+
"version_0",
68+
"best",
69+
)

yucca/modules/callbacks/prediction_writer.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,51 @@ def write_on_batch_end(self, _trainer, _pl_module, data_dict, _batch_indices, _b
3232
save_softmax=self.save_softmax,
3333
)
3434
del data_dict
35+
36+
37+
class WriteSinglePredictionFromLogits(BasePredictionWriter):
38+
# Saves input at the exact specified path. Does not take an output directory and does not use
39+
# the filename stored in the data_dict object.
40+
def __init__(
41+
self,
42+
output_path,
43+
multilabel: bool = False,
44+
save_softmax: bool = False,
45+
write_interval: str = "batch",
46+
):
47+
super().__init__(write_interval)
48+
self.output_path = output_path
49+
self.multilabel = multilabel
50+
self.save_softmax = save_softmax
51+
52+
def write_on_batch_end(
53+
self,
54+
_trainer,
55+
_pl_module,
56+
data_dict,
57+
_batch_indices,
58+
_batch,
59+
_batch_idx,
60+
_dataloader_idx,
61+
):
62+
# this will create N (num processes) files in `output_dir` each containing
63+
# the predictions of it's respective rank
64+
logits, properties, case_id = (
65+
data_dict["logits"],
66+
data_dict["properties"],
67+
data_dict["case_id"],
68+
)
69+
if self.multilabel:
70+
save_multilabel_prediction_from_logits(
71+
logits,
72+
self.output_path,
73+
properties=properties,
74+
)
75+
else:
76+
save_prediction_from_logits(
77+
logits,
78+
self.output_path,
79+
properties=properties,
80+
save_softmax=self.save_softmax,
81+
)
82+
del data_dict

0 commit comments

Comments
 (0)