diff --git a/pyproject.toml b/pyproject.toml index 966e2b00..5c7f9c30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,25 +23,25 @@ dependencies = [ "flake8-unused-arguments>=0.0.13", "fvcore>=0.1.5", "lightning>=2.2.1", - "matplotlib>=3.8.3", - "monai>=1.3.0", "nibabel>=5.2.1", "numpy>=1.26.4", - "pandas>=2.2.1", "python-dotenv==1.0.0", "scikit_image>=0.22.0", "scikit_learn>=1.4.1.post1", - "seaborn>=0.13.2", - "SimpleITK>=2.3.1", "tqdm>=4.66.2", "timm>=0.9.8", "torchmetrics>=1.4.0.post0", "wandb>=0.16.3", - "weave>=0.39.0", ] [project.optional-dependencies] +extras = [ + "matplotlib>=3.8.3", + "monai>=1.3.0", + "SimpleITK>=2.3.1", + "pandas>=2.2.1", +] test = [ 'pytest>=7.3', 'flake8>=6.1.0', diff --git a/yucca/documentation/templates/functional_inference.py b/yucca/documentation/templates/functional_inference.py index 424d1822..f00c7e1c 100644 --- a/yucca/documentation/templates/functional_inference.py +++ b/yucca/documentation/templates/functional_inference.py @@ -2,7 +2,6 @@ import lightning as L import os import torch - from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p as ensure_dir_exists from yucca.paths import ( get_models_path, get_results_path, @@ -41,7 +40,7 @@ "version_0", "best", ) - ensure_dir_exists(save_path) + os.makedirs(save_path, exist_ok=True) ckpt = torch.load(ckpt_path, map_location="cpu") pred_writer = WritePredictionFromLogits(output_dir=save_path, save_softmax=False, write_interval="batch") diff --git a/yucca/documentation/templates/functional_preprocessing.py b/yucca/documentation/templates/functional_preprocessing.py index 64d78ced..5333ce30 100644 --- a/yucca/documentation/templates/functional_preprocessing.py +++ b/yucca/documentation/templates/functional_preprocessing.py @@ -3,26 +3,22 @@ import os import numpy as np import torch - from batchgenerators.utilities.file_and_folder_operations import ( - subfiles, - join, - save_pickle, - maybe_mkdir_p as ensure_dir_exists, - ) + from 
yucca.functional.utils.files_and_folders import subfiles + from yucca.functional.utils.saving import save_pickle from yucca.paths import get_raw_data_path, get_preprocessed_data_path from yucca.documentation.templates.template_config import config from yucca.functional.preprocessing import preprocess_case_for_training_with_label, preprocess_case_for_inference from yucca.functional.utils.loading import read_file_to_nifti_or_np - raw_images_dir = join(get_raw_data_path(), config["task"], "imagesTr") - raw_labels_dir = join(get_raw_data_path(), config["task"], "labelsTr") - test_raw_images_dir = join(get_raw_data_path(), config["task"], "imagesTs") + raw_images_dir = os.path.join(get_raw_data_path(), config["task"], "imagesTr") + raw_labels_dir = os.path.join(get_raw_data_path(), config["task"], "labelsTr") + test_raw_images_dir = os.path.join(get_raw_data_path(), config["task"], "imagesTs") - target_dir = join(get_preprocessed_data_path(), config["task"], config["config_name"]) - test_target_dir = join(get_preprocessed_data_path(), config["task"] + "_test", config["config_name"]) + target_dir = os.path.join(get_preprocessed_data_path(), config["task"], config["config_name"]) + test_target_dir = os.path.join(get_preprocessed_data_path(), config["task"] + "_test", config["config_name"]) - ensure_dir_exists(target_dir) - ensure_dir_exists(test_target_dir) + os.makedirs(target_dir, exist_ok=True) + os.makedirs(test_target_dir, exist_ok=True) # Preprocess the training data subjects = [file[: -len(config["extension"])] for file in subfiles(raw_labels_dir, join=False) if not file.startswith(".")] @@ -34,7 +30,7 @@ if re.search(re.escape(sub) + "_" + r"\d{3}" + ".", os.path.split(image_path)[-1]) ] images = [read_file_to_nifti_or_np(image) for image in images] - label = read_file_to_nifti_or_np(join(raw_labels_dir, sub + config["extension"])) + label = read_file_to_nifti_or_np(os.path.join(raw_labels_dir, sub + config["extension"])) images, label, image_props = 
preprocess_case_for_training_with_label( images=images, label=label, @@ -45,7 +41,7 @@ ) images = np.vstack((np.array(images), np.array(label)[np.newaxis]), dtype=np.float32) - save_path = join(target_dir, sub) + save_path = os.path.join(target_dir, sub) np.save(save_path + ".npy", images) save_pickle(image_props, save_path + ".pkl") @@ -73,6 +69,6 @@ target_spacing=config["target_spacing"], target_orientation=config["target_coordinate_system"], ) - save_path = join(test_target_dir, sub) + save_path = os.path.join(test_target_dir, sub) torch.save(images, save_path + ".pt") save_pickle(image_props, save_path + ".pkl") diff --git a/yucca/documentation/tests/transforms/dataloader.ipynb b/yucca/documentation/tests/transforms/dataloader.ipynb index 41ac220d..391b6c19 100644 --- a/yucca/documentation/tests/transforms/dataloader.ipynb +++ b/yucca/documentation/tests/transforms/dataloader.ipynb @@ -2,19 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from yucca.modules.data.datasets.YuccaDataset import YuccaTrainDataset\n", "from yucca.paths import yucca_preprocessed_data\n", - "from batchgenerators.utilities.file_and_folder_operations import join, subfiles\n", + "from yucca.functional.utils.files_and_folders import subfiles\n", + "import os\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -28,7 +29,7 @@ "source": [ "from matplotlib import pyplot as plt\n", "\n", - "samples = subfiles(join(yucca_preprocessed_data(), \"Task299_Combine\", \"UnsupervisedPlanner\"), suffix=\".npy\")\n", + "samples = subfiles(os.path.join(yucca_preprocessed_data(), \"Task299_Combine\", \"UnsupervisedPlanner\"), suffix=\".npy\")\n", "\n", "\n", "dataset = YuccaTrainDataset(samples=samples, patch_size=(96,) * 3, composed_transforms=None, task_type=\"contrastive\")\n", @@ -93,7 +94,7 @@ "metadata": 
{}, "outputs": [], "source": [ - "samples = subfiles(join(yucca_preprocessed_data(), \"Task001_OASIS\", \"YuccaPlanner\"), suffix=\".npy\")\n", + "samples = subfiles(os.path.join(yucca_preprocessed_data(), \"Task001_OASIS\", \"YuccaPlanner\"), suffix=\".npy\")\n", "\n", "print(samples)\n", "\n", diff --git a/yucca/functional/evaluation/__init__.py b/yucca/functional/evaluation/__init__.py index 1203c213..ebffa940 100644 --- a/yucca/functional/evaluation/__init__.py +++ b/yucca/functional/evaluation/__init__.py @@ -1,4 +1,3 @@ -from .confusion_matrix import torch_confusion_matrix_from_logits, torch_get_tp_fp_tn_fn from .metrics import ( dice, dice_per_label, @@ -17,5 +16,4 @@ total_pos_gt, total_pos_pred, ) -from .obj_metrics import get_obj_stats_for_label, obj_get_tp_fp_fn_gtvols_predvols from .surface_metrics import get_surface_metrics_for_label diff --git a/yucca/functional/evaluation/evaluate_folder.py b/yucca/functional/evaluation/evaluate_folder.py index 9dd989b4..1a04b222 100644 --- a/yucca/functional/evaluation/evaluate_folder.py +++ b/yucca/functional/evaluation/evaluate_folder.py @@ -2,6 +2,7 @@ import numpy as np import nibabel as nib import logging +import os from typing import Optional from yucca.functional.transforms.label_transforms import convert_labels_to_regions, translate_region_labels from yucca.functional.utils.nib_utils import get_nib_spacing @@ -9,7 +10,6 @@ from yucca.functional.evaluation.obj_metrics import get_obj_stats_for_label from yucca.functional.evaluation.surface_metrics import get_surface_metrics_for_label from tqdm import tqdm -from batchgenerators.utilities.file_and_folder_operations import join from sklearn.metrics import confusion_matrix from yucca.functional.evaluation.metrics import auroc @@ -97,8 +97,8 @@ def evaluate_multilabel_case_segm( assert regions is not None case_dict = {} - predpath = join(folder_with_predictions, case) - gtpath = join(folder_with_ground_truth, case) + predpath = os.path.join(folder_with_predictions, 
case) + gtpath = os.path.join(folder_with_ground_truth, case) case_dict["prediction_path"] = predpath case_dict["ground_truth_path"] = gtpath @@ -184,8 +184,8 @@ def evaluate_case_segm( surface_tol: int = 1, ): case_dict = {} - predpath = join(folder_with_predictions, case) - gtpath = join(folder_with_ground_truth, case) + predpath = os.path.join(folder_with_predictions, case) + gtpath = os.path.join(folder_with_ground_truth, case) case_dict["prediction_path"] = predpath case_dict["ground_truth_path"] = gtpath @@ -269,8 +269,8 @@ def evaluate_folder_cls( # load predictions and ground truths for case in tqdm(subjects, desc="Evaluating"): - predpath = join(folder_with_predictions, case) - gtpath = join(folder_with_ground_truth, case) + predpath = os.path.join(folder_with_predictions, case) + gtpath = os.path.join(folder_with_ground_truth, case) pred: int = np.loadtxt(predpath) gt: int = np.loadtxt(gtpath) diff --git a/yucca/functional/planning.py b/yucca/functional/planning.py index 4b280a74..28b0f479 100644 --- a/yucca/functional/planning.py +++ b/yucca/functional/planning.py @@ -1,6 +1,7 @@ import numpy as np from typing import Optional, List, Union -from batchgenerators.utilities.file_and_folder_operations import subfiles, load_pickle +from yucca.functional.utils.files_and_folders import subfiles +from yucca.functional.utils.loading import load_pickle def make_plans_file( diff --git a/yucca/functional/utils/files_and_folders.py b/yucca/functional/utils/files_and_folders.py index 276e8e51..d54f1bfa 100644 --- a/yucca/functional/utils/files_and_folders.py +++ b/yucca/functional/utils/files_and_folders.py @@ -4,11 +4,29 @@ import re import shutil import os -from batchgenerators.utilities.file_and_folder_operations import ( - join, - subdirs, -) -from typing import Union, List +from typing import Union, List, Optional + + +def subdirs( + folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True +) -> List[str]: + """ 
+ implementation by: https://github.com/MIC-DKFZ/batchgenerators + """ + subdirectories = [] + with os.scandir(folder) as entries: + for entry in entries: + if ( + entry.is_dir() + and (prefix is None or entry.name.startswith(prefix)) + and (suffix is None or entry.name.endswith(suffix)) + ): + dir_path = entry.path if join else entry.name + subdirectories.append(dir_path) + + if sort: + subdirectories.sort() + return subdirectories def replace_in_file(file_path, pattern_replacement): @@ -139,7 +157,7 @@ def _recursive_find_python_class(folder: list, class_name: str, current_module: if ispkg: next_current_module = current_module + "." + modname tr = _recursive_find_python_class( - [join(folder[0], modname)], + [os.path.join(folder[0], modname)], class_name, current_module=next_current_module, ) diff --git a/yucca/functional/utils/loading.py b/yucca/functional/utils/loading.py index f65de112..c2ea42d0 100644 --- a/yucca/functional/utils/loading.py +++ b/yucca/functional/utils/loading.py @@ -2,6 +2,7 @@ import os import nibabel as nib import numpy as np +import pickle from PIL import Image @@ -21,3 +22,14 @@ def read_file_to_nifti_or_np(imagepath, dtype=np.float32): return np.atleast_1d(np.genfromtxt(imagepath, delimiter=",", dtype=dtype)) else: raise TypeError(f"File type invalid. 
Found extension: {ext} and expected one in [nii, nii.gz, png, csv, txt]") + + +def load_pickle(file: str, mode: str = "rb"): + with open(file, mode) as f: + return pickle.load(f) + + +def load_json(p): + import json + with open(p, "r") as f: + return json.load(f) diff --git a/yucca/functional/utils/saving.py b/yucca/functional/utils/saving.py index 7360b59f..5b8596d3 100644 --- a/yucca/functional/utils/saving.py +++ b/yucca/functional/utils/saving.py @@ -1,13 +1,16 @@ import nibabel as nib import numpy as np +import os +import pickle from yucca.functional.utils.softmax import softmax from yucca.functional.utils.nib_utils import reorient_nib_image +from yucca.functional.utils.files_and_folders import subfiles from PIL import Image -from batchgenerators.utilities.file_and_folder_operations import ( - join, - subfiles, - maybe_mkdir_p as ensure_dir_exists, -) + + +def save_pickle(obj, file: str, mode: str = "wb") -> None: + with open(file, mode) as f: + pickle.dump(obj, f) def save_nifti_from_numpy(pred, outpath, properties, compression=9): @@ -87,7 +90,7 @@ def save_multilabel_prediction_from_logits(logits, outpath, properties, compress def merge_softmax_from_folders(folders: list, outpath, method="sum"): - ensure_dir_exists(outpath) + os.makedirs(outpath, exist_ok=True) cases = subfiles(folders[0], suffix=".npz", join=False) for folder in folders: assert cases == subfiles(folder, suffix=".npz", join=False), ( @@ -98,7 +101,7 @@ def merge_softmax_from_folders(folders: list, outpath, method="sum"): ) for case in cases: - files_for_case = [np.load(join(folder, case), allow_pickle=True) for folder in folders] + files_for_case = [np.load(os.path.join(folder, case), allow_pickle=True) for folder in folders] properties_for_case = files_for_case[0]["properties"] files_for_case = [file["data"].astype(np.float32) for file in files_for_case] @@ -108,7 +111,7 @@ def merge_softmax_from_folders(folders: list, outpath, method="sum"): files_for_case = np.argmax(files_for_case, 0)
save_nifti_from_numpy( files_for_case, - join(outpath, case[:-4]), + os.path.join(outpath, case[:-4]), properties=properties_for_case.item(), ) diff --git a/yucca/modules/callbacks/loggers.py b/yucca/modules/callbacks/loggers.py index 3e52f323..346ec959 100644 --- a/yucca/modules/callbacks/loggers.py +++ b/yucca/modules/callbacks/loggers.py @@ -3,15 +3,10 @@ import logging from argparse import Namespace from lightning.pytorch.loggers.logger import Logger -from pytorch_lightning.utilities.rank_zero import rank_zero_only -from pytorch_lightning.core.saving import save_hparams_to_yaml -from lightning_fabric.utilities.logger import _convert_params +from lightning.pytorch.utilities.rank_zero import rank_zero_only +from lightning.pytorch.core.saving import save_hparams_to_yaml +from lightning.fabric.utilities.logger import _convert_params from time import localtime, strftime, time -from batchgenerators.utilities.file_and_folder_operations import ( - join, - maybe_mkdir_p as ensure_dir_exists, - isdir, -) from typing import Any, Dict, Optional, Union @@ -60,18 +55,18 @@ def root_dir(self): def log_dir(self): log_dir = self.root_dir if self.name is not None: - log_dir = join(log_dir, self.name) + log_dir = os.path.join(log_dir, self.name) if self.version is not None: version = self.version if isinstance(self.version, str) else f"version_{self.version}" - log_dir = join(log_dir, version) - if not isdir(log_dir): - ensure_dir_exists(log_dir) + log_dir = os.path.join(log_dir, version) + if not os.path.isdir(log_dir): + os.makedirs(log_dir, exist_ok=True) return log_dir @rank_zero_only def create_logfile(self): - ensure_dir_exists(self.log_dir) - self.log_file = join( + os.makedirs(self.log_dir, exist_ok=True) + self.log_file = os.path.join( self.log_dir, "training_log.txt", ) diff --git a/yucca/modules/callbacks/prediction_writer.py b/yucca/modules/callbacks/prediction_writer.py index 74445f1d..6427e7a0 100644 --- a/yucca/modules/callbacks/prediction_writer.py +++ 
b/yucca/modules/callbacks/prediction_writer.py @@ -1,6 +1,6 @@ from lightning.pytorch.callbacks import BasePredictionWriter from yucca.functional.utils.saving import save_prediction_from_logits, save_multilabel_prediction_from_logits -from batchgenerators.utilities.file_and_folder_operations import join +import os class WritePredictionFromLogits(BasePredictionWriter): @@ -21,13 +21,13 @@ def write_on_batch_end(self, _trainer, _pl_module, data_dict, _batch_indices, _b if self.multilabel: save_multilabel_prediction_from_logits( logits, - join(self.output_dir, case_id), + os.path.join(self.output_dir, case_id), properties=properties, ) else: save_prediction_from_logits( logits, - join(self.output_dir, case_id), + os.path.join(self.output_dir, case_id), properties=properties, save_softmax=self.save_softmax, ) diff --git a/yucca/modules/data/augmentation/transforms/YuccaTransform.py b/yucca/modules/data/augmentation/transforms/YuccaTransform.py index be0ce6e2..845fba79 100644 --- a/yucca/modules/data/augmentation/transforms/YuccaTransform.py +++ b/yucca/modules/data/augmentation/transforms/YuccaTransform.py @@ -1,9 +1,22 @@ -from batchgenerators.transforms.abstract_transforms import AbstractTransform -from abc import abstractmethod +import abc + + +class AbstractTransform(object): + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def __call__(self, **data_dict): + raise NotImplementedError("Abstract, so implement") + + def __repr__(self): + ret_str = ( + str(type(self).__name__) + "( " + ", ".join([key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )" + ) + return ret_str class YuccaTransform(AbstractTransform): - @abstractmethod + @abc.abstractmethod def get_params(self): """ This will return a random value between @@ -35,7 +48,7 @@ def get_params(self): data modalities and labels remain registered. 
""" - @abstractmethod + @abc.abstractmethod def __call__(self): """ This will be of the form __call__(self, packed_dict: dict = None, **unpacked_dict): diff --git a/yucca/modules/data/data_modules/YuccaDataModule.py b/yucca/modules/data/data_modules/YuccaDataModule.py index 86fb2969..ccbf2989 100644 --- a/yucca/modules/data/data_modules/YuccaDataModule.py +++ b/yucca/modules/data/data_modules/YuccaDataModule.py @@ -2,9 +2,9 @@ import torchvision import logging import torch +import os from typing import Literal, Optional, Union from torch.utils.data import DataLoader, Sampler -from batchgenerators.utilities.file_and_folder_operations import join from yucca.pipeline.configuration.split_data import SplitConfig from yucca.modules.data.datasets.YuccaDataset import YuccaTestDataset, YuccaTrainDataset from yucca.modules.data.samplers import InfiniteRandomSampler @@ -116,8 +116,8 @@ def setup(self, stage: Literal["fit", "test", "predict"]): assert self.splits_config is not None assert self.task_type is not None - self.train_samples = [join(self.train_data_dir, i) for i in self.splits_config.train(self.split_idx)] - self.val_samples = [join(self.train_data_dir, i) for i in self.splits_config.val(self.split_idx)] + self.train_samples = [os.path.join(self.train_data_dir, i) for i in self.splits_config.train(self.split_idx)] + self.val_samples = [os.path.join(self.train_data_dir, i) for i in self.splits_config.val(self.split_idx)] if len(self.train_samples) < 100: logging.info(f"Training on samples: {self.train_samples}") diff --git a/yucca/modules/data/datasets/YuccaCompressedDataset.py b/yucca/modules/data/datasets/YuccaCompressedDataset.py index 760636f5..ce4c1ec4 100644 --- a/yucca/modules/data/datasets/YuccaCompressedDataset.py +++ b/yucca/modules/data/datasets/YuccaCompressedDataset.py @@ -1,12 +1,12 @@ import numpy as np -from batchgenerators.utilities.file_and_folder_operations import isfile +import os from yucca.modules.data.datasets.YuccaDataset import 
YuccaTrainDataset class YuccaCompressedTrainDataset(YuccaTrainDataset): def load_and_maybe_keep_volume(self, path: str): path = path + ".npz" - if isfile(path): + if os.path.isfile(path): try: return np.load(path, "r")["data"] except ValueError: diff --git a/yucca/modules/data/datasets/YuccaDataset.py b/yucca/modules/data/datasets/YuccaDataset.py index 54ffcf49..e52f0dec 100644 --- a/yucca/modules/data/datasets/YuccaDataset.py +++ b/yucca/modules/data/datasets/YuccaDataset.py @@ -3,7 +3,8 @@ import os import logging from typing import Union, Literal, Optional -from batchgenerators.utilities.file_and_folder_operations import subfiles, load_pickle, isfile +from yucca.functional.utils.files_and_folders import subfiles +from yucca.functional.utils.loading import load_pickle from yucca.modules.data.augmentation.transforms.cropping_and_padding import CropPad from yucca.modules.data.augmentation.transforms.formatting import NumpyToTorch @@ -64,7 +65,7 @@ def load_and_maybe_keep_pickle(self, path): def load_and_maybe_keep_volume(self, path): path = path + ".npy" if not self.keep_in_ram: - if isfile(path): + if os.path.isfile(path): try: return np.load(path, "r") except ValueError: @@ -72,7 +73,7 @@ def load_and_maybe_keep_volume(self, path): else: print("uncompressed data was not found.") - if isfile(path): + if os.path.isfile(path): if path in self.already_loaded_cases: return self.already_loaded_cases[path] try: @@ -242,10 +243,9 @@ def __getitem__(self, idx): if __name__ == "__main__": import torch from yucca.paths import get_preprocessed_data_path - from batchgenerators.utilities.file_and_folder_operations import join from yucca.modules.data.samplers import InfiniteRandomSampler - files = subfiles(join(get_preprocessed_data_path(), "Task001_OASIS/YuccaPlanner"), suffix="npy") + files = subfiles(os.path.join(get_preprocessed_data_path(), "Task001_OASIS/YuccaPlanner"), suffix="npy") ds = YuccaTrainDataset(files, patch_size=(12, 12, 12)) sampler = 
InfiniteRandomSampler(ds) dl = torch.utils.data.DataLoader(ds, num_workers=2, batch_size=2, sampler=sampler) diff --git a/yucca/modules/networks/networks/3D_resnet.py b/yucca/modules/networks/networks/3D_resnet.py index 5ec788e3..795601f3 100644 --- a/yucca/modules/networks/networks/3D_resnet.py +++ b/yucca/modules/networks/networks/3D_resnet.py @@ -1,5 +1,4 @@ -from typing import Union, List, Optional, Callable, Type -from pytorchvideo.models.resnet import create_resnet +from typing import List, Optional, Type from torch import nn, Tensor import torch from yucca.modules.networks.blocks_and_layers.res_blocks import BasicBlock, Bottleneck, conv_k1 diff --git a/yucca/modules/networks/utils/model_memory_estimation.py b/yucca/modules/networks/utils/model_memory_estimation.py index 9cd93c56..acd9a8bc 100644 --- a/yucca/modules/networks/utils/model_memory_estimation.py +++ b/yucca/modules/networks/utils/model_memory_estimation.py @@ -26,11 +26,10 @@ import yucca import math import logging +import os from yucca.functional.utils.torch_utils import get_available_device, flush_and_get_torch_memory_allocated from yucca.functional.utils.files_and_folders import recursive_find_python_class from yucca.functional.utils.kwargs import filter_kwargs - -from batchgenerators.utilities.file_and_folder_operations import join from torch import nn @@ -154,7 +153,7 @@ def find_optimal_tensor_dims( model_name = model_name.split("_")[0] model = recursive_find_python_class( - folder=[join(yucca.__path__[0], "modules", "networks")], + folder=[os.path.join(yucca.__path__[0], "modules", "networks")], class_name=model_name, current_module="yucca.modules.networks", ) diff --git a/yucca/modules/optimization/utils/LossTree.py b/yucca/modules/optimization/utils/LossTree.py index 0fd0dc22..3497a17a 100644 --- a/yucca/modules/optimization/utils/LossTree.py +++ b/yucca/modules/optimization/utils/LossTree.py @@ -1,4 +1,4 @@ -from batchgenerators.utilities.file_and_folder_operations import load_json +from 
yucca.functional.utils.loading import load_json class node: diff --git a/yucca/paths.py b/yucca/paths.py index a6ba66c0..32ca007d 100644 --- a/yucca/paths.py +++ b/yucca/paths.py @@ -5,8 +5,6 @@ import os from dotenv import load_dotenv -from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p as ensure_dir_exists - def var_is_set(var): return var in os.environ.keys() @@ -18,7 +16,7 @@ def get_environment_variable(var): raise ValueError(f"Missing required environment variable {var}.") path = os.environ[var] - ensure_dir_exists(path) + os.makedirs(path, exist_ok=True) return path diff --git a/yucca/pipeline/configuration/configure_checkpoint.py b/yucca/pipeline/configuration/configure_checkpoint.py index d1b9dbe5..a27e9e93 100644 --- a/yucca/pipeline/configuration/configure_checkpoint.py +++ b/yucca/pipeline/configuration/configure_checkpoint.py @@ -1,5 +1,5 @@ +import os import torch -from batchgenerators.utilities.file_and_folder_operations import join, isfile from dataclasses import dataclass from typing import Union from yucca.pipeline.configuration.configure_paths import PathConfig @@ -96,12 +96,16 @@ def get_checkpoint_config_from_ckpt(ckpt_path: str): def find_checkpoint_path(ckpt_path: Union[str, None], continue_from_most_recent: bool, version: int, version_dir: str): if ckpt_path: - assert isfile(ckpt_path), f"Checkpoint was not found. Looked in: {ckpt_path}" + assert os.path.isfile(ckpt_path), f"Checkpoint was not found. 
Looked in: {ckpt_path}" logging.info(f"Using ckpt file: {ckpt_path}") return ckpt_path - elif version is not None and continue_from_most_recent and isfile(join(version_dir, "checkpoints", "last.ckpt")): + elif ( + version is not None + and continue_from_most_recent + and os.path.isfile(os.path.join(version_dir, "checkpoints", "last.ckpt")) + ): logging.info("Using last checkpoint and continuing training") - return join(version_dir, "checkpoints", "last.ckpt") + return os.path.join(version_dir, "checkpoints", "last.ckpt") else: return None diff --git a/yucca/pipeline/configuration/configure_paths.py b/yucca/pipeline/configuration/configure_paths.py index f70677d6..6b8cd37b 100644 --- a/yucca/pipeline/configuration/configure_paths.py +++ b/yucca/pipeline/configuration/configure_paths.py @@ -1,4 +1,5 @@ -from batchgenerators.utilities.file_and_folder_operations import join, isdir, subdirs, maybe_mkdir_p as ensure_dir_exists +import os +from yucca.functional.utils.files_and_folders import subdirs from dataclasses import dataclass from typing import Union, Literal from yucca.paths import get_models_path, get_preprocessed_data_path @@ -25,9 +26,9 @@ def lm_hparams(self): def get_path_config(task_config: TaskConfig, stage: Literal["fit", "test", "predict"]): - task_dir = join(get_preprocessed_data_path(), task_config.task) - train_data_dir = join(task_dir, task_config.planner_name) - save_dir = join( + task_dir = os.path.join(get_preprocessed_data_path(), task_config.task) + train_data_dir = os.path.join(task_dir, task_config.planner_name) + save_dir = os.path.join( get_models_path(), task_config.task, task_config.model_name + "__" + task_config.model_dimensions, @@ -37,16 +38,16 @@ def get_path_config(task_config: TaskConfig, stage: Literal["fit", "test", "pred ) version = detect_version(save_dir, task_config.continue_from_most_recent) - version_dir = join(save_dir, f"version_{version}") - ensure_dir_exists(version_dir) + version_dir = os.path.join(save_dir, 
f"version_{version}") + os.makedirs(version_dir, exist_ok=True) # First try to load torch checkpoints and extract plans and carry-over information from there. if stage == "fit": - plans_path = join(task_dir, task_config.planner_name, task_config.planner_name + "_plans.json") + plans_path = os.path.join(task_dir, task_config.planner_name, task_config.planner_name + "_plans.json") if stage == "test": raise NotImplementedError if stage == "predict": - plans_path = join(version_dir, "hparams.yaml") + plans_path = os.path.join(version_dir, "hparams.yaml") return PathConfig( plans_path=plans_path, @@ -60,7 +61,7 @@ def get_path_config(task_config: TaskConfig, stage: Literal["fit", "test", "pred def detect_version(save_dir, continue_from_most_recent) -> Union[None, int]: # If the dir doesn't exist we return version 0 - if not isdir(save_dir): + if not os.path.isdir(save_dir): return 0 # The dir exists. Check if any previous version exists in dir. diff --git a/yucca/pipeline/configuration/configure_plans.py b/yucca/pipeline/configuration/configure_plans.py index a3e7277b..bdca99c5 100644 --- a/yucca/pipeline/configuration/configure_plans.py +++ b/yucca/pipeline/configuration/configure_plans.py @@ -1,12 +1,12 @@ import yucca -from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile +import os from dataclasses import dataclass from typing import Union, Literal from yucca.pipeline.preprocessing.UnsupervisedPreprocessor import UnsupervisedPreprocessor from yucca.pipeline.preprocessing.ClassificationPreprocessor import ClassificationPreprocessor from yucca.functional.utils.dict import without_keys from yucca.functional.utils.files_and_folders import recursive_find_python_class -from yucca.functional.utils.loading import load_yaml +from yucca.functional.utils.loading import load_yaml, load_json import logging @@ -54,7 +54,7 @@ def get_plan_config( if stage == "predict": # In this case we don't want to rely on plans being found in the preprocessed 
folder, # as it might not exist. - if isfile(plans_path): + if os.path.isfile(plans_path): plans = load_yaml(plans_path)["config"]["plans"] else: assert ckpt_plans is not None @@ -109,7 +109,7 @@ def setup_task_type(plans): # If key is not present in plan then we try to infer the task_type from the Type of Preprocessor preprocessor_class = recursive_find_python_class( - folder=[join(yucca.__path__[0], "pipeline", "preprocessing")], + folder=[os.path.join(yucca.__path__[0], "pipeline", "preprocessing")], class_name=plans["preprocessor"], current_module="yucca.pipeline.preprocessing", ) diff --git a/yucca/pipeline/evaluation/YuccaEvaluator.py b/yucca/pipeline/evaluation/YuccaEvaluator.py index 375fcee6..956d556f 100644 --- a/yucca/pipeline/evaluation/YuccaEvaluator.py +++ b/yucca/pipeline/evaluation/YuccaEvaluator.py @@ -231,37 +231,3 @@ def save_as_json(self, dict): print("Saving results.json \n \n ########################################################################") with open(self.outpath, "w") as f: json.dump(dict, f, default=float, indent=4) - - def update_streamtable(self, results_dict): - """ - Save evaluation results to a wandb StreamTable - - :param results_dict: dictionary with evaluation results - """ - from weave.monitoring import StreamTable - - task = self.outpath.split(os.path.sep)[-5] - target = self.outpath.split(os.path.sep)[-6] - model_name = "/".join(self.outpath.split(os.path.sep)[-4:]) - - st = StreamTable(table_name=task, entity_name=wandb.api.viewer()["entity"], project_name="Yucca") - - stream_dict = {"0. Experiment": model_name, "0. Target Task": target} - - if self.task_type == "classification": - stream_dict = {**stream_dict, **results_dict} - - elif self.task_type == "segmentation": - for key, _ in results_dict.items(): - if key == "0": - continue - else: - stream_dict.update( - {f"{key}. 
" + k: v for k, v in results_dict[key].items() if k in self.metrics_included_in_streamtable} - ) - - else: - raise NotImplementedError("Task type not supported") - - st.log(stream_dict) - st.finish()