|
4 | 4 | import torch |
5 | 5 | from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p as ensure_dir_exists |
6 | 6 | from yucca.paths import ( |
7 | | - get_models_path, |
8 | | - get_results_path, |
9 | 7 | get_preprocessed_data_path, |
10 | 8 | get_raw_data_path, |
11 | 9 | ) |
|
14 | 12 | from yucca.modules.data.data_modules.YuccaDataModule import YuccaDataModule |
15 | 13 | from yucca.modules.data.datasets.YuccaDataset import YuccaTestPreprocessedDataset |
16 | 14 | from yucca.pipeline.evaluation.YuccaEvaluator import YuccaEvaluator |
17 | | - from yucca.documentation.templates.template_config import config |
18 | | - |
19 | | - ckpt_path = os.path.join( |
20 | | - get_models_path(), |
21 | | - config["task"], |
22 | | - config["model_name"] + "__" + config["model_dimensions"], |
23 | | - "__" + config["config_name"], |
24 | | - "default", |
25 | | - "kfold_5_fold_0", |
26 | | - "version_0", |
27 | | - "checkpoints", |
28 | | - "last.ckpt", |
29 | | - ) |
| 15 | + from yucca.documentation.templates.template_config import config, ckpt_path, inference_save_path |
30 | 16 |
|
31 | 17 | gt_path = os.path.join(get_raw_data_path(), config["task"], "labelsTs") |
32 | 18 | target_data_path = os.path.join(get_preprocessed_data_path(), config["task"] + "_test", "demo") |
33 | 19 |
|
34 | | - save_path = os.path.join( |
35 | | - get_results_path(), |
36 | | - config["task"], |
37 | | - config["task"], |
38 | | - config["model_name"] + "__" + config["model_dimensions"], |
39 | | - "__" + config["config_name"], |
40 | | - "kfold_5_fold_0", |
41 | | - "version_0", |
42 | | - "best", |
43 | | - ) |
44 | | - ensure_dir_exists(save_path) |
| 20 | + ensure_dir_exists(inference_save_path) |
45 | 21 |
|
46 | 22 | ckpt = torch.load(ckpt_path, map_location="cpu") |
47 | | - pred_writer = WritePredictionFromLogits(output_dir=save_path, save_softmax=False, write_interval="batch") |
| 23 | + pred_writer = WritePredictionFromLogits(output_dir=inference_save_path, save_softmax=False, write_interval="batch") |
48 | 24 |
|
49 | 25 | model_module = BaseLightningModule( |
50 | 26 | model=config["model"], |
|
59 | 35 | data_module = YuccaDataModule( |
60 | 36 | batch_size=config["batch_size"], |
61 | 37 | patch_size=config["patch_size"], |
62 | | - pred_save_dir=save_path, |
| 38 | + pred_save_dir=inference_save_path, |
63 | 39 | pred_data_dir=target_data_path, |
64 | 40 | overwrite_predictions=True, |
65 | 41 | image_extension=".nii.gz", |
|
82 | 58 | evaluator = YuccaEvaluator( |
83 | 59 | labels=config["classes"], |
84 | 60 | folder_with_ground_truth=gt_path, |
85 | | - folder_with_predictions=save_path, |
| 61 | + folder_with_predictions=inference_save_path, |
86 | 62 | use_wandb=False, |
87 | 63 | ) |
88 | 64 | evaluator.run() |
0 commit comments