Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions yucca/documentation/templates/FOMO25/FOMO25_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from yucca.pipeline.managers.YuccaManagerV2 import YuccaManagerV2

manager = YuccaManagerV2
modelfile = "YOUR_PATH_HERE.ckpt"
experiment = "default"
model_name = "UNet"
model_dimensions = "3D"
source_task = "YOUR_TASK_HERE"
planner = "YuccaPlannerV2"
74 changes: 74 additions & 0 deletions yucca/documentation/templates/FOMO25/FOMO25_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from yucca.modules.callbacks.prediction_writer import WriteSinglePredictionFromLogits


def predict(input_path, output_path):
from config import (
manager,
modelfile,
experiment,
model_name,
model_dimensions,
source_task,
planner,
)
from yucca.modules.data.datasets.alternative_datasets.SingleFileDataset import (
SingleFileTestDataset,
)

output_path = output_path.replace(".nii.gz", "")
manager = manager(
ckpt_path=modelfile,
enable_logging=False,
experiment=experiment,
num_workers=0,
model_name=model_name,
model_dimensions=model_dimensions,
task=source_task,
planner=planner,
)

manager.batch_size = 1
manager.test_dataset_class = SingleFileTestDataset
manager.initialize(
stage="predict",
disable_tta=True,
disable_inference_preprocessing=False,
overwrite_predictions=True,
pred_data_dir=input_path,
prediction_output_dir=output_path,
save_softmax=False,
)

manager.trainer.callbacks = [
WriteSinglePredictionFromLogits(
output_path=output_path,
multilabel=False,
save_softmax=False,
write_interval="batch",
)
]

manager.trainer.predict(
model=manager.model_module,
dataloaders=manager.data_module,
ckpt_path=modelfile,
return_predictions=False,
)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
"--inputpath",
help="The path to a SINGLE file",
default="",
)
parser.add_argument(
"--outputpath",
default="",
)
args = parser.parse_args()

predict(args.inputpath, args.outputpath)
34 changes: 5 additions & 29 deletions yucca/documentation/templates/functional_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import torch
from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p as ensure_dir_exists
from yucca.paths import (
get_models_path,
get_results_path,
get_preprocessed_data_path,
get_raw_data_path,
)
Expand All @@ -14,37 +12,15 @@
from yucca.modules.data.data_modules.YuccaDataModule import YuccaDataModule
from yucca.modules.data.datasets.YuccaDataset import YuccaTestPreprocessedDataset
from yucca.pipeline.evaluation.YuccaEvaluator import YuccaEvaluator
from yucca.documentation.templates.template_config import config

ckpt_path = os.path.join(
get_models_path(),
config["task"],
config["model_name"] + "__" + config["model_dimensions"],
"__" + config["config_name"],
"default",
"kfold_5_fold_0",
"version_0",
"checkpoints",
"last.ckpt",
)
from yucca.documentation.templates.template_config import config, ckpt_path, inference_save_path

gt_path = os.path.join(get_raw_data_path(), config["task"], "labelsTs")
target_data_path = os.path.join(get_preprocessed_data_path(), config["task"] + "_test", "demo")

save_path = os.path.join(
get_results_path(),
config["task"],
config["task"],
config["model_name"] + "__" + config["model_dimensions"],
"__" + config["config_name"],
"kfold_5_fold_0",
"version_0",
"best",
)
ensure_dir_exists(save_path)
ensure_dir_exists(inference_save_path)

ckpt = torch.load(ckpt_path, map_location="cpu")
pred_writer = WritePredictionFromLogits(output_dir=save_path, save_softmax=False, write_interval="batch")
pred_writer = WritePredictionFromLogits(output_dir=inference_save_path, save_softmax=False, write_interval="batch")

model_module = BaseLightningModule(
model=config["model"],
Expand All @@ -59,7 +35,7 @@
data_module = YuccaDataModule(
batch_size=config["batch_size"],
patch_size=config["patch_size"],
pred_save_dir=save_path,
pred_save_dir=inference_save_path,
pred_data_dir=target_data_path,
overwrite_predictions=True,
image_extension=".nii.gz",
Expand All @@ -82,7 +58,7 @@
evaluator = YuccaEvaluator(
labels=config["classes"],
folder_with_ground_truth=gt_path,
folder_with_predictions=save_path,
folder_with_predictions=inference_save_path,
use_wandb=False,
)
evaluator.run()
44 changes: 38 additions & 6 deletions yucca/documentation/templates/template_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import os
from yucca.modules.networks.networks import TinyUNet
from yucca.modules.optimization.loss_functions.nnUNet_losses import DiceCE
from yucca.paths import (
get_models_path,
get_results_path,
get_preprocessed_data_path,
get_raw_data_path,
)

model = TinyUNet
classes = [0, 1]
modalities = ("MRI",)

config = {
"batch_size": 2,
"classes": [0, 1],
"classes": classes,
"config_name": "demo",
"crop_to_nonzero": True,
"continue_from_most_recent": True,
Expand All @@ -15,11 +24,14 @@
"learning_rate": 1e-3,
"loss_fn": DiceCE,
"max_epochs": 2,
"modalities": ("MRI",),
"modalities": modalities,
"model_dimensions": "2D",
"model": TinyUNet,
"model": model,
"model_name": model.__name__,
"momentum": 0.99,
"norm_op": "volume_wise_znorm",
"num_classes": len(classes),
"num_modalities": len(modalities),
"patch_size": (32, 32),
"plans": None,
"split_idx": 0,
Expand All @@ -32,6 +44,26 @@
"task_type": "segmentation",
}

config["model_name"] = config["model"].__name__
config["num_classes"] = len(config["classes"])
config["num_modalities"] = len(config["modalities"])

ckpt_path = os.path.join(
get_models_path(),
config["task"],
config["model_name"] + "__" + config["model_dimensions"],
"__" + config["config_name"],
"default",
"kfold_5_fold_0",
"version_0",
"checkpoints",
"last.ckpt",
)

inference_save_path = os.path.join(
get_results_path(),
config["task"],
config["task"],
config["model_name"] + "__" + config["model_dimensions"],
"__" + config["config_name"],
"kfold_5_fold_0",
"version_0",
"best",
)
48 changes: 48 additions & 0 deletions yucca/modules/callbacks/prediction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,51 @@ def write_on_batch_end(self, _trainer, _pl_module, data_dict, _batch_indices, _b
save_softmax=self.save_softmax,
)
del data_dict


class WriteSinglePredictionFromLogits(BasePredictionWriter):
# Saves input at the exact specified path. Does not take an output directory and does not use
# the filename stored in the data_dict object.
def __init__(
self,
output_path,
multilabel: bool = False,
save_softmax: bool = False,
write_interval: str = "batch",
):
super().__init__(write_interval)
self.output_path = output_path
self.multilabel = multilabel
self.save_softmax = save_softmax

def write_on_batch_end(
self,
_trainer,
_pl_module,
data_dict,
_batch_indices,
_batch,
_batch_idx,
_dataloader_idx,
):
# this will create N (num processes) files in `output_dir` each containing
# the predictions of it's respective rank
logits, properties, case_id = (
data_dict["logits"],
data_dict["properties"],
data_dict["case_id"],
)
if self.multilabel:
save_multilabel_prediction_from_logits(
logits,
self.output_path,
properties=properties,
)
else:
save_prediction_from_logits(
logits,
self.output_path,
properties=properties,
save_softmax=self.save_softmax,
)
del data_dict
Loading