diff --git a/yucca/documentation/templates/FOMO25/FOMO25_config.py b/yucca/documentation/templates/FOMO25/FOMO25_config.py new file mode 100644 index 00000000..c879793b --- /dev/null +++ b/yucca/documentation/templates/FOMO25/FOMO25_config.py @@ -0,0 +1,9 @@ +from yucca.pipeline.managers.YuccaManagerV2 import YuccaManagerV2 + +manager = YuccaManagerV2 +modelfile = "YOUR_PATH_HERE.ckpt" +experiment = "default" +model_name = "UNet" +model_dimensions = "3D" +source_task = "YOUR_TASK_HERE" +planner = "YuccaPlannerV2" diff --git a/yucca/documentation/templates/FOMO25/FOMO25_predict.py b/yucca/documentation/templates/FOMO25/FOMO25_predict.py new file mode 100644 index 00000000..030b23ea --- /dev/null +++ b/yucca/documentation/templates/FOMO25/FOMO25_predict.py @@ -0,0 +1,74 @@ +from yucca.modules.callbacks.prediction_writer import WriteSinglePredictionFromLogits + + +def predict(input_path, output_path): + from config import ( + manager, + modelfile, + experiment, + model_name, + model_dimensions, + source_task, + planner, + ) + from yucca.modules.data.datasets.alternative_datasets.SingleFileDataset import ( + SingleFileTestDataset, + ) + + output_path = output_path.replace(".nii.gz", "") + manager = manager( + ckpt_path=modelfile, + enable_logging=False, + experiment=experiment, + num_workers=0, + model_name=model_name, + model_dimensions=model_dimensions, + task=source_task, + planner=planner, + ) + + manager.batch_size = 1 + manager.test_dataset_class = SingleFileTestDataset + manager.initialize( + stage="predict", + disable_tta=True, + disable_inference_preprocessing=False, + overwrite_predictions=True, + pred_data_dir=input_path, + prediction_output_dir=output_path, + save_softmax=False, + ) + + manager.trainer.callbacks = [ + WriteSinglePredictionFromLogits( + output_path=output_path, + multilabel=False, + save_softmax=False, + write_interval="batch", + ) + ] + + manager.trainer.predict( + model=manager.model_module, + dataloaders=manager.data_module, + ckpt_path=modelfile, + return_predictions=False, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--inputpath", + help="The path to a SINGLE file", + default="", + ) + parser.add_argument( + "--outputpath", + default="", + ) + args = parser.parse_args() + + predict(args.inputpath, args.outputpath) diff --git a/yucca/documentation/templates/functional_inference.py b/yucca/documentation/templates/functional_inference.py index 424d1822..534b7468 100644 --- a/yucca/documentation/templates/functional_inference.py +++ b/yucca/documentation/templates/functional_inference.py @@ -4,8 +4,6 @@ import torch from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p as ensure_dir_exists from yucca.paths import ( - get_models_path, - get_results_path, get_preprocessed_data_path, get_raw_data_path, ) @@ -14,37 +12,15 @@ from yucca.modules.data.data_modules.YuccaDataModule import YuccaDataModule from yucca.modules.data.datasets.YuccaDataset import YuccaTestPreprocessedDataset from yucca.pipeline.evaluation.YuccaEvaluator import YuccaEvaluator - from yucca.documentation.templates.template_config import config - - ckpt_path = os.path.join( - get_models_path(), - config["task"], - config["model_name"] + "__" + config["model_dimensions"], - "__" + config["config_name"], - "default", - "kfold_5_fold_0", - "version_0", - "checkpoints", - "last.ckpt", - ) + from yucca.documentation.templates.template_config import config, ckpt_path, inference_save_path gt_path = os.path.join(get_raw_data_path(), config["task"], "labelsTs") target_data_path = os.path.join(get_preprocessed_data_path(), config["task"] + "_test", "demo") - save_path = os.path.join( - get_results_path(), - config["task"], - config["task"], - config["model_name"] + "__" + config["model_dimensions"], - "__" + config["config_name"], - "kfold_5_fold_0", - "version_0", - "best", - ) - ensure_dir_exists(save_path) + ensure_dir_exists(inference_save_path) ckpt = torch.load(ckpt_path, map_location="cpu") - pred_writer = WritePredictionFromLogits(output_dir=save_path, save_softmax=False, write_interval="batch") + pred_writer = WritePredictionFromLogits(output_dir=inference_save_path, save_softmax=False, write_interval="batch") model_module = BaseLightningModule( model=config["model"], @@ -59,7 +35,7 @@ data_module = YuccaDataModule( batch_size=config["batch_size"], patch_size=config["patch_size"], - pred_save_dir=save_path, + pred_save_dir=inference_save_path, pred_data_dir=target_data_path, overwrite_predictions=True, image_extension=".nii.gz", @@ -82,7 +58,7 @@ evaluator = YuccaEvaluator( labels=config["classes"], folder_with_ground_truth=gt_path, - folder_with_predictions=save_path, + folder_with_predictions=inference_save_path, use_wandb=False, ) evaluator.run() diff --git a/yucca/documentation/templates/template_config.py b/yucca/documentation/templates/template_config.py index 2670bc25..0c8df68f 100644 --- a/yucca/documentation/templates/template_config.py +++ b/yucca/documentation/templates/template_config.py @@ -1,11 +1,20 @@ +import os from yucca.modules.networks.networks import TinyUNet from yucca.modules.optimization.loss_functions.nnUNet_losses import DiceCE +from yucca.paths import ( + get_models_path, + get_results_path, + get_preprocessed_data_path, + get_raw_data_path, +) model = TinyUNet +classes = [0, 1] +modalities = ("MRI",) config = { "batch_size": 2, - "classes": [0, 1], + "classes": classes, "config_name": "demo", "crop_to_nonzero": True, "continue_from_most_recent": True, @@ -15,11 +24,14 @@ "learning_rate": 1e-3, "loss_fn": DiceCE, "max_epochs": 2, - "modalities": ("MRI",), + "modalities": modalities, "model_dimensions": "2D", - "model": TinyUNet, + "model": model, + "model_name": model.__name__, "momentum": 0.99, "norm_op": "volume_wise_znorm", + "num_classes": len(classes), + "num_modalities": len(modalities), "patch_size": (32, 32), "plans": None, "split_idx": 0, @@ -32,6 +44,26 @@ "task_type": "segmentation", } -config["model_name"] = config["model"].__name__ -config["num_classes"] = len(config["classes"]) -config["num_modalities"] = len(config["modalities"]) + +ckpt_path = os.path.join( + get_models_path(), + config["task"], + config["model_name"] + "__" + config["model_dimensions"], + "__" + config["config_name"], + "default", + "kfold_5_fold_0", + "version_0", + "checkpoints", + "last.ckpt", +) + +inference_save_path = os.path.join( + get_results_path(), + config["task"], + config["task"], + config["model_name"] + "__" + config["model_dimensions"], + "__" + config["config_name"], + "kfold_5_fold_0", + "version_0", + "best", +) diff --git a/yucca/modules/callbacks/prediction_writer.py b/yucca/modules/callbacks/prediction_writer.py index 74445f1d..8f2f72f5 100644 --- a/yucca/modules/callbacks/prediction_writer.py +++ b/yucca/modules/callbacks/prediction_writer.py @@ -32,3 +32,51 @@ def write_on_batch_end(self, _trainer, _pl_module, data_dict, _batch_indices, _b save_softmax=self.save_softmax, ) del data_dict + + +class WriteSinglePredictionFromLogits(BasePredictionWriter): + # Saves input at the exact specified path. Does not take an output directory and does not use + # the filename stored in the data_dict object. + def __init__( + self, + output_path, + multilabel: bool = False, + save_softmax: bool = False, + write_interval: str = "batch", + ): + super().__init__(write_interval) + self.output_path = output_path + self.multilabel = multilabel + self.save_softmax = save_softmax + + def write_on_batch_end( + self, + _trainer, + _pl_module, + data_dict, + _batch_indices, + _batch, + _batch_idx, + _dataloader_idx, + ): + # this will create N (num processes) files in `output_dir` each containing + # the predictions of it's respective rank + logits, properties, case_id = ( + data_dict["logits"], + data_dict["properties"], + data_dict["case_id"], + ) + if self.multilabel: + save_multilabel_prediction_from_logits( + logits, + self.output_path, + properties=properties, + ) + else: + save_prediction_from_logits( + logits, + self.output_path, + properties=properties, + save_softmax=self.save_softmax, + ) + del data_dict