diff --git a/pyproject.toml b/pyproject.toml index 7e5dea97..06c0577a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "yucca" -version = "2.2.1" +version = "2.2.4" authors = [ { name="Sebastian Llambias", email="llambias@live.com" }, { name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" }, @@ -29,6 +29,7 @@ dependencies = [ "numpy>=1.26.4", "pandas>=2.2.1", "python-dotenv==1.0.0", + "pytorchvideo==0.1.5", "scikit_image>=0.22.0", "scikit_learn>=1.4.1.post1", "seaborn>=0.13.2", diff --git a/yucca/functional/evaluation/confusion_matrix.py b/yucca/functional/evaluation/confusion_matrix.py index ac4afe24..6abc698a 100644 --- a/yucca/functional/evaluation/confusion_matrix.py +++ b/yucca/functional/evaluation/confusion_matrix.py @@ -35,3 +35,12 @@ def torch_get_tp_fp_tn_fn(confusion_matrix, ignore_label=0): TN.append(tn.cpu().numpy()) FN.append(fn.cpu().numpy()) return TP, FP, TN, FN + + +def convert_confusion_matrix_to_dict(confusion_matrix): + d = {} + for true_label, row in enumerate(confusion_matrix): + d[str(true_label)] = {} + for predicted_label, value in enumerate(row): + d[str(true_label)][str(predicted_label)] = value + return d diff --git a/yucca/functional/evaluation/evaluate_folder.py b/yucca/functional/evaluation/evaluate_folder.py index f35f342e..e83ab3b7 100644 --- a/yucca/functional/evaluation/evaluate_folder.py +++ b/yucca/functional/evaluation/evaluate_folder.py @@ -1,4 +1,5 @@ import sys +import os import numpy as np import nibabel as nib import logging @@ -11,6 +12,7 @@ from batchgenerators.utilities.file_and_folder_operations import join from sklearn.metrics import confusion_matrix from yucca.functional.evaluation.metrics import auroc +from yucca.functional.evaluation.confusion_matrix import convert_confusion_matrix_to_dict def evaluate_folder_segm( @@ -257,9 +259,6 @@ def evaluate_folder_cls( prediction_probs = [] ground_truths = [] - # Flag to check if we have prediction probabilities to 
calculate AUROC - use_probs = False - # load predictions and ground truths for case in tqdm(subjects, desc="Evaluating"): predpath = join(folder_with_predictions, case) @@ -268,15 +267,9 @@ def evaluate_folder_cls( pred: int = np.loadtxt(predpath) gt: int = np.loadtxt(gtpath) - try: - if len(prediction_probs) == 0: - print("Prediction probabilities found. Will use them for evaluation.") - use_probs = True - + if os.path.isfile(predpath.replace(".txt", ".npz")): pred_probs = np.load(predpath.replace(".txt", ".npz"))["data"] # contains output probabilities prediction_probs.append(pred_probs) - except FileNotFoundError: - pred_probs = None predictions.append(pred) ground_truths.append(gt) @@ -287,6 +280,8 @@ def evaluate_folder_cls( # calculate per-class metrics cmat = confusion_matrix(ground_truths, predictions, labels=labels) + cmat_dict = convert_confusion_matrix_to_dict(cmat) + resultdict["confusion_matrix"] = cmat_dict resultdict["per_class"] = {} for label in labels: @@ -303,7 +298,7 @@ def evaluate_folder_cls( resultdict["per_class"][str(label)] = labeldict # calculate AUROC - if use_probs: + if len(prediction_probs) > 0: auroc_per_class: list[float] = auroc(ground_truths, prediction_probs) for label, score in zip(labels, auroc_per_class): resultdict["per_class"][str(label)]["AUROC"] = round(score, 4) diff --git a/yucca/functional/preprocessing.py b/yucca/functional/preprocessing.py index 852a785d..ba1acadb 100644 --- a/yucca/functional/preprocessing.py +++ b/yucca/functional/preprocessing.py @@ -402,6 +402,7 @@ def preprocess_case_for_inference( target_size, target_spacing, target_orientation, + background_pixel_value: int = 0, allow_missing_modalities: bool = False, ext=".nii.gz", keep_aspect_ratio: bool = True, @@ -425,7 +426,9 @@ def preprocess_case_for_inference( image_properties["uncropped_shape"] = np.array(images[0].shape) if crop_to_nonzero: - nonzero_box = get_bbox_for_foreground(images[0], background_label=0) + if np.max(images[0]) <= 
background_pixel_value: + background_pixel_value = np.min(images[0]) + nonzero_box = get_bbox_for_foreground(images[0], background_label=background_pixel_value) for i in range(len(images)): images[i] = crop_to_box(images[i], nonzero_box) image_properties["nonzero_box"] = nonzero_box diff --git a/yucca/functional/testing/data/nifti.py b/yucca/functional/testing/data/nifti.py index c2f3f64a..fa39c789 100644 --- a/yucca/functional/testing/data/nifti.py +++ b/yucca/functional/testing/data/nifti.py @@ -1,6 +1,7 @@ import nibabel as nib import numpy as np import nibabel.orientations as nio +import logging from yucca.functional.utils.nib_utils import get_nib_orientation, get_nib_spacing @@ -48,4 +49,5 @@ def verify_orientation_is_LR_PA_IS(image: nib.Nifti1Image): if np.all(nio.axcodes2ornt(orientation)[:, 0] == expected_orientation_code): return True else: + logging.info(f"Found orientation {orientation}") return False diff --git a/yucca/functional/transforms/gamma.py b/yucca/functional/transforms/gamma.py index 50150414..3a7c1c02 100644 --- a/yucca/functional/transforms/gamma.py +++ b/yucca/functional/transforms/gamma.py @@ -9,6 +9,7 @@ def augment_gamma( invert_image=False, epsilon=1e-7, per_channel=False, + p_per_channel=None, clip_to_input_range=False, ): if invert_image: @@ -27,19 +28,20 @@ def augment_gamma( data_sample = np.clip(data_sample, a_min=img_min, a_max=img_max) else: for c in range(data_sample.shape[0]): - if np.random.random() < 0.5 and gamma_range[0] < 1: - gamma = np.random.uniform(gamma_range[0], 1) - else: - gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1]) - img_min = data_sample[c].min() - img_max = data_sample[c].max() - img_range = img_max - img_min - data_sample[c] = ( - np.power(((data_sample[c] - img_min) / float(img_range + epsilon)), gamma) * float(img_range + epsilon) - + img_min - ) - if clip_to_input_range: - data_sample[c] = np.clip(data_sample[c], a_min=img_min, a_max=img_max) + if np.random.uniform() < p_per_channel[c]: 
+ if np.random.random() < 0.5 and gamma_range[0] < 1: + gamma = np.random.uniform(gamma_range[0], 1) + else: + gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1]) + img_min = data_sample[c].min() + img_max = data_sample[c].max() + img_range = img_max - img_min + data_sample[c] = ( + np.power(((data_sample[c] - img_min) / float(img_range + epsilon)), gamma) * float(img_range + epsilon) + + img_min + ) + if clip_to_input_range: + data_sample[c] = np.clip(data_sample[c], a_min=img_min, a_max=img_max) if invert_image: data_sample = -data_sample return data_sample diff --git a/yucca/functional/transforms/spatial.py b/yucca/functional/transforms/spatial.py index 2f902a82..d011d8b0 100644 --- a/yucca/functional/transforms/spatial.py +++ b/yucca/functional/transforms/spatial.py @@ -26,6 +26,7 @@ def spatial( scale_factor, clip_to_input_range, label: Optional[np.ndarray] = None, + linear_interpolation_channel=None, skip_label: bool = False, do_crop: bool = True, random_crop: bool = True, @@ -40,6 +41,9 @@ def spatial( cval = cval assert isinstance(cval, (int, float)), f"got {cval} of type {type(cval)}" + if isinstance(order, (int, float)): + order = [order for _ in range(image.shape[1])] + coords = create_zero_centered_coordinate_matrix(patch_size) image_canvas = np.zeros((image.shape[0], image.shape[1], *patch_size), dtype=np.float32) @@ -92,7 +96,7 @@ def spatial( image_canvas[b, c] = map_coordinates( image[b, c].astype(float), coords, - order=order, + order=order[c], mode="constant", cval=cval, ).astype(image.dtype) @@ -106,7 +110,7 @@ def spatial( dtype=np.float32, ) - # Mapping the labelmentations to the distorted coordinates + # Mapping the label to the distorted coordinates for b in range(label.shape[0]): for c in range(label.shape[1]): label_canvas[b, c] = map_coordinates(label[b, c], coords, order=0, mode="constant", cval=0.0).astype( diff --git a/yucca/functional/visualization/__init__.py b/yucca/functional/visualization/__init__.py index 
772de7e8..e10d3cb0 100644 --- a/yucca/functional/visualization/__init__.py +++ b/yucca/functional/visualization/__init__.py @@ -1 +1,3 @@ -from yucca.functional.visualization.imshow import get_train_fig_with_inp_out_tar +from yucca.functional.visualization.imshow import get_segm_train_fig_with_inp_out_tar +from yucca.functional.visualization.imshow import get_cls_train_fig_with_inp_out_tar +from yucca.functional.visualization.imshow import get_ssl_train_fig_with_inp_out_tar diff --git a/yucca/functional/visualization/imshow.py b/yucca/functional/visualization/imshow.py index 9832197d..6491cb34 100644 --- a/yucca/functional/visualization/imshow.py +++ b/yucca/functional/visualization/imshow.py @@ -3,18 +3,14 @@ import matplotlib.pyplot as plt -def get_train_fig_with_inp_out_tar(input, output, target, fig_title, task_type: str = "segmentation"): +def get_segm_train_fig_with_inp_out_tar(input, output, target, fig_title): # This needs to handle the following cases: # Segmentation : {"input": (m,x,y(,z)), "target": (1,x,y(,z)), "output": (c,x,y(,z))} - # Self-supervised : {"input": (m,x,y(,z)), "target": (m,x,y(,z)), "output": (m,x,y(,z))} - # Classification : {"input": (m,x,y(,z)), "target": (1,x), "output": (c,x)} channel_idx = np.random.randint(0, input.shape[0]) - if len(input.shape) == 4: # 3D images. - # We need to select a slice to visualize. - if task_type == "segmentation" and len(target[0].nonzero()[0]) > 0: - # Select a foreground slice if any exist. 
+ if len(input.shape) == 4: + if len(target[0].nonzero()[0]) > 0: foreground_locations = target[0].nonzero() slice_to_visualize = foreground_locations[0][np.random.randint(0, len(foreground_locations[0]))] else: @@ -27,30 +23,67 @@ def get_train_fig_with_inp_out_tar(input, output, target, fig_title, task_type: output = output[:, slice_to_visualize] image = input[channel_idx] + target = target[0] + output = output.argmax(0) + + fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), dpi=100, constrained_layout=True) + axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99)) + axes[0].set_title("input") + axes[1].imshow(target, cmap="gray") + axes[1].set_title("target") + axes[2].imshow(output, cmap="gray") + axes[2].set_title("output") + fig.suptitle(fig_title, fontsize=16) + return fig + + +def get_cls_train_fig_with_inp_out_tar(input, output, target, fig_title): + # This needs to handle the following case: + # Classification : {"input": (m,x,y(,z)), "target": (n_classes), "output": (n_classes)} + + channel_idx = np.random.randint(0, input.shape[0]) + + slice_to_visualize = np.random.randint(0, input.shape[1]) + + if len(input.shape) == 4: # 3D images. + input = input[:, slice_to_visualize] + + image = input[channel_idx] + + output = output.argmax(0) + + fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), dpi=100, constrained_layout=True) + axes.imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99)) + axes.set_title(f"Input: {fig_title}", fontsize=12) + fig.suptitle(f"Target: {target} | Output: {output}", fontsize=12) + return fig + + +def get_ssl_train_fig_with_inp_out_tar(input, output, target, fig_title): + # This needs to handle the following cases: + # Self-supervised : {"input": (m,x,y(,z)), "target": (m,x,y(,z)), "output": (m,x,y(,z))} + + channel_idx = np.random.randint(0, input.shape[0]) + + if len(input.shape) == 4: # 3D images. 
+ slice_to_visualize = np.random.randint(0, input.shape[1]) + input = input[:, slice_to_visualize] + if len(target.shape) == 4: + target = target[:, slice_to_visualize] + if len(output.shape) == 4: + output = output[:, slice_to_visualize] + + image = input[channel_idx] + + target = target[channel_idx] + output = output[channel_idx] - if task_type in ["segmentation", "classification"]: - target = target[0] - output = output.argmax(0) - elif task_type == "self-supervised": - target = target[channel_idx] - output = output[channel_idx] - else: - logging.warn( - f"Unknown task type. Found {task_type} and expected one in ['classification', 'segmentation', 'self-supervised']" - ) - - if len(target.shape) == 1: - fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), dpi=100, constrained_layout=True) - axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99)) - axes[0].set_title("input") - fig.suptitle(f"{fig_title}. Target: {target} | Output: {output}", fontsize=16) - else: - fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), dpi=100, constrained_layout=True) - axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99)) - axes[0].set_title("input") - axes[1].imshow(target, cmap="gray") - axes[1].set_title("target") - axes[2].imshow(output, cmap="gray") - axes[2].set_title("output") - fig.suptitle(fig_title, fontsize=16) + fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), dpi=100, constrained_layout=True) + axes[0].imshow(image, cmap="gray", vmin=np.quantile(image, 0.01), vmax=np.quantile(image, 0.99)) + axes[0].set_title("input") + axes[1].imshow(target, cmap="gray") + axes[1].set_title("target") + axes[2].imshow(output, cmap="gray") + axes[2].set_title("output") + fig.suptitle(fig_title, fontsize=16) return fig diff --git a/yucca/modules/callbacks/loggers.py b/yucca/modules/callbacks/loggers.py index 74cb242b..5647c30c 100644 --- a/yucca/modules/callbacks/loggers.py +++ 
b/yucca/modules/callbacks/loggers.py @@ -74,7 +74,7 @@ def create_logfile(self): self.log_dir, "training_log.txt", ) - with open(self.log_file, "w") as f: + with open(self.log_file, "a+") as f: f.write("Starting model training") logging.info("Starting model training \n" f'{"log file:":20} {self.log_file} \n') f.write("\n") diff --git a/yucca/modules/data/augmentation/YuccaAugmentationComposer.py b/yucca/modules/data/augmentation/YuccaAugmentationComposer.py index 60d9ce1a..dc5d7c30 100644 --- a/yucca/modules/data/augmentation/YuccaAugmentationComposer.py +++ b/yucca/modules/data/augmentation/YuccaAugmentationComposer.py @@ -69,6 +69,7 @@ def setup_default_params(self, is_2D, patch_size): self.cval = "min" # can be an int, float or a str in ['min', 'max'] self.clip_to_input_range = False # ensures no augmentations go beyond the input range of the image/patch self.normalize = False + self.interpolation_order = 3 # label/segmentation transforms self.skip_label = False @@ -78,17 +79,23 @@ def setup_default_params(self, is_2D, patch_size): # default augmentation probabilities self.additive_noise_p_per_sample = 0.2 + self.additive_noise_p_per_channel = 1.0 self.biasfield_p_per_sample = 0.33 + self.biasfield_p_per_channel = 1 self.blurring_p_per_sample = 0.2 self.blurring_p_per_channel = 0.5 self.elastic_deform_p_per_sample = 0.33 self.gamma_p_per_sample = 0.2 + self.gamma_p_per_channel = 1.0 self.gamma_p_invert_image = 0.05 self.gibbs_ringing_p_per_sample = 0.2 + self.gibbs_ringing_p_per_channel = 1.0 self.mirror_p_per_sample = 0.0 self.mirror_p_per_axis = 0.33 self.motion_ghosting_p_per_sample = 0.2 + self.motion_ghosting_p_per_channel = 1.0 self.multiplicative_noise_p_per_sample = 0.2 + self.multiplicative_noise_p_per_channel = 1.0 self.rotation_p_per_sample = 0.2 self.rotation_p_per_axis = 0.66 self.scale_p_per_sample = 0.2 @@ -170,6 +177,7 @@ def compose_train_transforms(self): random_crop=self.random_crop, cval=self.cval, 
clip_to_input_range=self.clip_to_input_range, + order=self.interpolation_order, p_deform_per_sample=self.elastic_deform_p_per_sample, deform_sigma=self.elastic_deform_sigma, deform_alpha=self.elastic_deform_alpha, @@ -184,6 +192,7 @@ def compose_train_transforms(self): ), AdditiveNoise( p_per_sample=self.additive_noise_p_per_sample, + p_per_channel=self.additive_noise_p_per_channel, mean=self.additive_noise_mean, sigma=self.additive_noise_sigma, clip_to_input_range=self.clip_to_input_range, @@ -196,12 +205,14 @@ def compose_train_transforms(self): ), MultiplicativeNoise( p_per_sample=self.multiplicative_noise_p_per_sample, + p_per_channel=self.multiplicative_noise_p_per_channel, mean=self.multiplicative_noise_mean, sigma=self.multiplicative_noise_sigma, clip_to_input_range=self.clip_to_input_range, ), MotionGhosting( p_per_sample=self.motion_ghosting_p_per_sample, + p_per_channel=self.motion_ghosting_p_per_channel, alpha=self.motion_ghosting_alpha, num_reps=self.motion_ghosting_num_reps, axes=self.motion_ghosting_axes, @@ -209,6 +220,7 @@ def compose_train_transforms(self): ), GibbsRinging( p_per_sample=self.gibbs_ringing_p_per_sample, + p_per_channel=self.gibbs_ringing_p_per_channel, cut_freq=self.gibbs_ringing_cut_freq, axes=self.gibbs_ringing_axes, clip_to_input_range=self.clip_to_input_range, @@ -222,10 +234,12 @@ def compose_train_transforms(self): ), BiasField( p_per_sample=self.biasfield_p_per_sample, + p_per_channel=self.biasfield_p_per_channel, clip_to_input_range=self.clip_to_input_range, ), Gamma( p_per_sample=self.gamma_p_per_sample, + p_per_channel=self.gamma_p_per_channel, p_invert_image=self.gamma_p_invert_image, gamma_range=self.gamma_range, clip_to_input_range=self.clip_to_input_range, @@ -282,19 +296,24 @@ def lm_hparams(self): "copy_image_to_label": self.copy_image_to_label, "convert_labels_to_regions": self.convert_labels_to_regions, "additive_noise_p_per_sample": self.additive_noise_p_per_sample, + "additive_noise_p_per_channel": 
self.additive_noise_p_per_channel, "additive_noise_mean": self.additive_noise_mean, "additive_noise_sigma": self.additive_noise_sigma, "biasfield_p_per_sample": self.biasfield_p_per_sample, + "biasfield_p_per_channel": self.biasfield_p_per_channel, "blurring_p_per_sample": self.blurring_p_per_sample, + "blurring_p_per_channel": self.blurring_p_per_channel, "blurring_sigma": self.blurring_sigma, "blurring_p_per_channel": self.blurring_p_per_channel, "elastic_deform_p_per_sample": self.elastic_deform_p_per_sample, "elastic_deform_alpha": self.elastic_deform_alpha, "elastic_deform_sigma": self.elastic_deform_sigma, "gamma_p_per_sample": self.gamma_p_per_sample, + "gamma_p_per_channel": self.gamma_p_per_channel, "gamma_p_invert_image": self.gamma_p_invert_image, "gamma_range": self.gamma_range, "gibbs_ringing_p_per_sample": self.gibbs_ringing_p_per_sample, + "gibbs_ringing_p_per_channel": self.gibbs_ringing_p_per_channel, "gibbs_ringing_cut_freq": self.gibbs_ringing_cut_freq, "gibbs_ringing_axes": self.gibbs_ringing_axes, "mask_ratio": self.mask_ratio, @@ -302,10 +321,12 @@ def lm_hparams(self): "mirror_p_per_axis": self.mirror_p_per_axis, "mirror_axes": self.mirror_axes, "motion_ghosting_p_per_sample": self.motion_ghosting_p_per_sample, + "motion_ghosting_p_per_channel": self.motion_ghosting_p_per_channel, "motion_ghosting_alpha": self.motion_ghosting_alpha, "motion_ghosting_num_reps": self.motion_ghosting_num_reps, "motion_ghosting_axes": self.motion_ghosting_axes, "multiplicative_noise_p_per_sample": self.multiplicative_noise_p_per_sample, + "multiplicative_noise_p_per_channel": self.multiplicative_noise_p_per_channel, "multiplicative_noise_mean": self.multiplicative_noise_mean, "multiplicative_noise_sigma": self.multiplicative_noise_sigma, "rotation_p_per_sample": self.rotation_p_per_sample, diff --git a/yucca/modules/data/augmentation/augmentation_presets.py b/yucca/modules/data/augmentation/augmentation_presets.py index f3cb52e8..aa02915a 100644 --- 
a/yucca/modules/data/augmentation/augmentation_presets.py +++ b/yucca/modules/data/augmentation/augmentation_presets.py @@ -122,3 +122,16 @@ "simulate_lowres_p_per_axis": 0.66, # default augmentation values } + +channel_specific_probas = { + "interpolation_order": [3, 0], + "additive_noise_p_per_channel": [1.0, 0.0], + "biasfield_p_per_channel": [1.0, 0.0], + "blurring_p_per_channel": [1.0, 0.0], + "gamma_p_per_channel": [0.2, 0.0], + "gibbs_ringing_p_per_channel": [0.0, 0.0], + "motion_ghosting_p_per_channel": [0.0, 0.0], + "multiplicative_noise_p_per_channel": [0.2, 0.0], + "simulate_lowres_p_per_channel": [0.5, 0], + # default augmentation values +} diff --git a/yucca/modules/data/augmentation/transforms/BiasField.py b/yucca/modules/data/augmentation/transforms/BiasField.py index 603060af..bbb32c2e 100644 --- a/yucca/modules/data/augmentation/transforms/BiasField.py +++ b/yucca/modules/data/augmentation/transforms/BiasField.py @@ -4,14 +4,17 @@ class BiasField(YuccaTransform): + def __init__( self, data_key="image", p_per_sample: float = 1.0, + p_per_channel=1.0, clip_to_input_range=False, ): self.data_key = data_key self.p_per_sample = p_per_sample + self.p_per_channel = p_per_channel self.clip_to_input_range = clip_to_input_range @staticmethod @@ -30,8 +33,13 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), f"Incorrect data size or shape. 
\nShould be (b, c, x, y, z) or (b, c, x, y) and is:\ {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): - for c in range(data_dict[self.data_key][b].shape[0]): - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b, c] = self.__biasField__(data_dict[self.data_key][b, c]) + if np.random.uniform() < self.p_per_sample: + for c in range(data_dict[self.data_key][b].shape[0]): + if np.random.uniform() < self.p_per_channel[c]: + data_dict[self.data_key][b, c] = self.__biasField__(data_dict[self.data_key][b, c]) return data_dict diff --git a/yucca/modules/data/augmentation/transforms/Blur.py b/yucca/modules/data/augmentation/transforms/Blur.py index d39badae..790dfe8a 100644 --- a/yucca/modules/data/augmentation/transforms/Blur.py +++ b/yucca/modules/data/augmentation/transforms/Blur.py @@ -26,7 +26,7 @@ def get_params(sigma: Tuple[float]): def __blur__(self, image, sigma): for c in range(image.shape[0]): - if np.random.uniform() < self.p_per_channel: + if np.random.uniform() < self.p_per_channel[c]: image[c] = blur(image[c], sigma, clip_to_input_range=self.clip_to_input_range) return image @@ -37,6 +37,10 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), f"Incorrect data size or shape.\ \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): if np.random.uniform() < self.p_per_sample: sigma = self.get_params(self.sigma) diff --git a/yucca/modules/data/augmentation/transforms/Gamma.py b/yucca/modules/data/augmentation/transforms/Gamma.py index 6a7d10a2..92aadcfc 100644 --- a/yucca/modules/data/augmentation/transforms/Gamma.py +++ 
b/yucca/modules/data/augmentation/transforms/Gamma.py @@ -23,6 +23,7 @@ def __init__( self, data_key="image", p_per_sample: float = 1.0, + p_per_channel=1.0, p_invert_image: float = 0.05, gamma_range=(0.5, 2.0), per_channel=True, @@ -30,6 +31,7 @@ def __init__( ): self.data_key = data_key self.p_per_sample = p_per_sample + self.p_per_channel = p_per_channel self.gamma_range = gamma_range self.p_invert_image = p_invert_image self.per_channel = per_channel @@ -43,12 +45,13 @@ def get_params(p_invert_image): do_invert = True return do_invert - def __gamma__(self, image, gamma_range, invert_image, per_channel): + def __gamma__(self, image, gamma_range, invert_image, per_channel, p_per_channel): return augment_gamma( image, gamma_range, invert_image, - per_channel, + per_channel=per_channel, + p_per_channel=p_per_channel, clip_to_input_range=self.clip_to_input_range, ) @@ -59,6 +62,10 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), f"Incorrect data size or shape.\ \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): if np.random.uniform() < self.p_per_sample: do_invert = self.get_params(self.p_invert_image) @@ -67,5 +74,6 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): self.gamma_range, do_invert, per_channel=self.per_channel, + p_per_channel=self.p_per_channel, ) return data_dict diff --git a/yucca/modules/data/augmentation/transforms/Ghosting.py b/yucca/modules/data/augmentation/transforms/Ghosting.py index 5b34f6a1..dee774ec 100644 --- a/yucca/modules/data/augmentation/transforms/Ghosting.py +++ b/yucca/modules/data/augmentation/transforms/Ghosting.py @@ -5,10 +5,12 @@ class MotionGhosting(YuccaTransform): + def __init__( self, data_key="image", p_per_sample: float = 1.0, + p_per_channel=1.0, alpha=(0.85, 
0.95), num_reps=(2, 5), axes=(0, 3), @@ -16,6 +18,7 @@ def __init__( ): self.data_key = data_key self.p_per_sample = p_per_sample + self.p_per_channel = p_per_channel self.alpha = alpha self.num_reps = num_reps self.axes = axes @@ -41,11 +44,16 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), f"Incorrect data size or shape.\ \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): - for c in range(data_dict[self.data_key][b].shape[0]): - if np.random.uniform() < self.p_per_sample: - alpha, num_reps, axis = self.get_params(self.alpha, self.num_reps, self.axes) - data_dict[self.data_key][b, c] = self.__motionGhosting__( - data_dict[self.data_key][b, c], alpha, num_reps, axis - ) + if np.random.uniform() < self.p_per_sample: + for c in range(data_dict[self.data_key][b].shape[0]): + if np.random.uniform() < self.p_per_channel[c]: + alpha, num_reps, axis = self.get_params(self.alpha, self.num_reps, self.axes) + data_dict[self.data_key][b, c] = self.__motionGhosting__( + data_dict[self.data_key][b, c], alpha, num_reps, axis + ) return data_dict diff --git a/yucca/modules/data/augmentation/transforms/Noise.py b/yucca/modules/data/augmentation/transforms/Noise.py index f10669fd..3896e1fd 100644 --- a/yucca/modules/data/augmentation/transforms/Noise.py +++ b/yucca/modules/data/augmentation/transforms/Noise.py @@ -9,12 +9,14 @@ def __init__( self, data_key="image", p_per_sample: float = 1.0, + p_per_channel=1.0, mean=(0.0, 0.0), sigma=(1e-3, 1e-4), clip_to_input_range=False, ): self.data_key = data_key self.p_per_sample = p_per_sample + self.p_per_channel = p_per_channel self.mean = mean self.sigma = sigma self.clip_to_input_range = clip_to_input_range @@ -36,11 +38,16 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), 
f"Incorrect data size or shape.\ \nShould be (c, x, y, z) or (c, x, y) and is: {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): - for c in range(data_dict[self.data_key][b].shape[0]): - mean, sigma = self.get_params(self.mean, self.sigma) - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b, c] = self.__additiveNoise__(data_dict[self.data_key][b, c], mean, sigma) + if np.random.uniform() < self.p_per_sample: + for c in range(data_dict[self.data_key][b].shape[0]): + if np.random.uniform() < self.p_per_channel[c]: + mean, sigma = self.get_params(self.mean, self.sigma) + data_dict[self.data_key][b, c] = self.__additiveNoise__(data_dict[self.data_key][b, c], mean, sigma) return data_dict @@ -57,12 +64,14 @@ def __init__( self, data_key="image", p_per_sample: float = 1.0, + p_per_channel=1.0, mean=(0.0, 0.0), sigma=(1e-3, 1e-4), clip_to_input_range=False, ): self.data_key = data_key self.p_per_sample = p_per_sample + self.p_per_channel = p_per_channel self.mean = mean self.sigma = sigma self.clip_to_input_range = clip_to_input_range @@ -84,9 +93,16 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), f"Incorrect data size or shape.\ \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): - for c in range(data_dict[self.data_key][b].shape[0]): - if np.random.uniform() < self.p_per_sample: - mean, sigma = self.get_params(self.mean, self.sigma) - data_dict[self.data_key][b, c] = self.__multiplicativeNoise__(data_dict[self.data_key][b, c], mean, sigma) + if np.random.uniform() < self.p_per_sample: + for c in 
range(data_dict[self.data_key][b].shape[0]): + if np.random.uniform() < self.p_per_channel[c]: + mean, sigma = self.get_params(self.mean, self.sigma) + data_dict[self.data_key][b, c] = self.__multiplicativeNoise__( + data_dict[self.data_key][b, c], mean, sigma + ) return data_dict diff --git a/yucca/modules/data/augmentation/transforms/Ringing.py b/yucca/modules/data/augmentation/transforms/Ringing.py index b3c848d8..1d477a7f 100644 --- a/yucca/modules/data/augmentation/transforms/Ringing.py +++ b/yucca/modules/data/augmentation/transforms/Ringing.py @@ -4,16 +4,19 @@ class GibbsRinging(YuccaTransform): + def __init__( self, data_key="image", p_per_sample: float = 1.0, + p_per_channel=1.0, cut_freq=(96, 129), axes=(0, 3), clip_to_input_range=False, ): self.data_key = data_key self.p_per_sample = p_per_sample + self.p_per_channel = p_per_channel self.cut_freq = cut_freq self.axes = axes self.clip_to_input_range = clip_to_input_range @@ -35,9 +38,14 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), f"Incorrect data size or shape.\ \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): - for c in range(data_dict[self.data_key][b].shape[0]): - if np.random.uniform() < self.p_per_sample: - cut_freq, axis = self.get_params(self.cut_freq, self.axes) - data_dict[self.data_key][b, c] = self.__gibbsRinging__(data_dict[self.data_key][b, c], cut_freq, axis) + if np.random.uniform() < self.p_per_sample: + for c in range(data_dict[self.data_key][b].shape[0]): + if np.random.uniform() < self.p_per_channel[c]: + cut_freq, axis = self.get_params(self.cut_freq, self.axes) + data_dict[self.data_key][b, c] = self.__gibbsRinging__(data_dict[self.data_key][b, c], cut_freq, axis) return data_dict diff --git 
a/yucca/modules/data/augmentation/transforms/SimulateLowres.py b/yucca/modules/data/augmentation/transforms/SimulateLowres.py index dd92087b..9b75ff00 100644 --- a/yucca/modules/data/augmentation/transforms/SimulateLowres.py +++ b/yucca/modules/data/augmentation/transforms/SimulateLowres.py @@ -44,10 +44,14 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict): ), f"Incorrect data size or shape.\ \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" + self.p_per_channel = self.__ensure_p_per_channel_is_iterable__( + self.p_per_channel, n_channels=data_dict[self.data_key].shape[1] + ) + for b in range(data_dict[self.data_key].shape[0]): if np.random.uniform() < self.p_per_sample: for c in range(data_dict[self.data_key][b].shape[0]): - if np.random.uniform() < self.p_per_channel: + if np.random.uniform() < self.p_per_channel[c]: target_shape = self.get_params( self.zoom_range, data_dict[self.data_key][b, c].shape, diff --git a/yucca/modules/data/augmentation/transforms/YuccaTransform.py b/yucca/modules/data/augmentation/transforms/YuccaTransform.py index be0ce6e2..c153b055 100644 --- a/yucca/modules/data/augmentation/transforms/YuccaTransform.py +++ b/yucca/modules/data/augmentation/transforms/YuccaTransform.py @@ -42,3 +42,9 @@ def __call__(self): which allows calling it as either transform(data_dict) or transform(**data_dict), supporting both Torch pipelines and batchgenerators. 
""" + + @staticmethod + def __ensure_p_per_channel_is_iterable__(p_per_channel, n_channels): + if not isinstance(p_per_channel, (list, tuple)): + p_per_channel = [p_per_channel for _ in range(n_channels)] + return p_per_channel diff --git a/yucca/modules/data/augmentation/transforms/copy_image_to_label.py b/yucca/modules/data/augmentation/transforms/copy_image_to_label.py index a1e15a55..d4011b44 100644 --- a/yucca/modules/data/augmentation/transforms/copy_image_to_label.py +++ b/yucca/modules/data/augmentation/transforms/copy_image_to_label.py @@ -18,10 +18,10 @@ def __copy__(self, data_dict): def __call__(self, packed_data_dict=None, **unpacked_data_dict): data_dict = packed_data_dict if packed_data_dict else unpacked_data_dict - assert ( - len(data_dict[self.data_key].shape) == 5 or len(data_dict[self.data_key].shape) == 4 - ), f"Incorrect data size or shape.\ - \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" if self.copy: + assert ( + len(data_dict[self.data_key].shape) == 5 or len(data_dict[self.data_key].shape) == 4 + ), f"Incorrect data size or shape.\ + \nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}" data_dict = self.__copy__(data_dict) return data_dict diff --git a/yucca/modules/data/datasets/ClassificationDataset.py b/yucca/modules/data/datasets/ClassificationDataset.py new file mode 100644 index 00000000..736adb76 --- /dev/null +++ b/yucca/modules/data/datasets/ClassificationDataset.py @@ -0,0 +1,172 @@ +import numpy as np +import torch +import os +from typing import Union, Optional +from batchgenerators.utilities.file_and_folder_operations import subfiles +from yucca.modules.data.augmentation.transforms.cropping_and_padding import CropPad +from yucca.modules.data.augmentation.transforms.formatting import NumpyToTorch +from yucca.modules.data.datasets.YuccaDataset import YuccaTrainDataset, YuccaTestDataset + + +class ClassificationTrainDataset(YuccaTrainDataset): + def __init__( 
+ self, + samples: list, + patch_size: list | tuple, + keep_in_ram: Union[bool, None] = None, + label_dtype: Optional[Union[int, float]] = torch.int32, + task_type: str = "classification", + composed_transforms=None, + allow_missing_modalities=False, + p_oversample_foreground=0.33, + ): + self.all_cases = samples + self.composed_transforms = composed_transforms + self.patch_size = patch_size + self.label_dtype = label_dtype + self.allow_missing_modalities = allow_missing_modalities + self.already_loaded_cases = {} + + self.croppad = CropPad(patch_size=self.patch_size, label_key=None, p_oversample_foreground=p_oversample_foreground) + self.to_torch = NumpyToTorch(label_dtype=self.label_dtype) + + self._keep_in_ram = keep_in_ram + + def __getitem__(self, idx): + # remove extension if file splits include extensions + case, _ = os.path.splitext(self.all_cases[idx]) + data = self.load_and_maybe_keep_volume(case) + metadata = self.load_and_maybe_keep_pickle(case) + + if self.allow_missing_modalities: + image, label = self.unpack_with_zeros(data) + else: + image, label = self.unpack(data) + + data_dict = {"file_path": case} + data_dict.update({"image": image, "label": label}) + + return self._transform(data_dict, metadata) + + def unpack(self, data): + return data[0], data[-1][0] + + def unpack_with_zeros(self, data): + assert data.dtype == "object", "allow missing modalities is true but dtype is not object" + + # First find the array with the largest array. 
+ # in classification this avoids setting the zero array to the 1d array with classes + sizes = [i.size for i in data] + idx_largest_array = np.where(sizes == np.max(sizes))[0][0] + + # replace missing modalities with zero-filed arrays + for idx, i in enumerate(data): + if i.size == 0: + data[idx] = np.zeros(data[idx_largest_array].squeeze().shape) + + # unpack array into images and labels + images = np.array([mod for mod in data[:-1]]) + label = data[-1:][0] + + return images, label + + +class ClassificationTestDataset(YuccaTestDataset): + def __init__( + self, + raw_data_dir: str, + pred_save_dir: str, + overwrite_predictions: bool = False, + suffix="nii.gz", + prediction_suffix=None, + pred_include_cases: list = None, + ): + super().__init__( + raw_data_dir=raw_data_dir, + pred_save_dir=pred_save_dir, + overwrite_predictions=overwrite_predictions, + suffix=suffix, + prediction_suffix="txt", + pred_include_cases=pred_include_cases, + ) + + +class ClassificationTrainDatasetWithCovariates(ClassificationTrainDataset): + def __init__( + self, + samples: list, + patch_size: list | tuple, + keep_in_ram: Union[bool, None] = None, + label_dtype: Optional[Union[int, float]] = torch.int32, + task_type: str = "classification", + composed_transforms=None, + allow_missing_modalities=False, + p_oversample_foreground=0.33, + ): + super().__init__( + samples=samples, + patch_size=patch_size, + keep_in_ram=keep_in_ram, + label_dtype=label_dtype, + task_type=task_type, + composed_transforms=composed_transforms, + allow_missing_modalities=allow_missing_modalities, + p_oversample_foreground=p_oversample_foreground, + ) + + def __getitem__(self, idx): + # remove extension if file splits include extensions + case, _ = os.path.splitext(self.all_cases[idx]) + data = self.load_and_maybe_keep_volume(case) + metadata = self.load_and_maybe_keep_pickle(case) + + image, covariates, label = self.unpack(data) + data_dict = {"file_path": case} + data_dict.update({"image": image, "covariates": 
covariates, "label": label}) + + return self._transform(data_dict, metadata) + + def unpack(self, data): + return data[0], data[-2], data[-1][0] + + +class ClassificationTestDatasetWithCovariates(YuccaTestDataset): + def __init__( + self, + raw_data_dir: str, + pred_save_dir: str, + overwrite_predictions: bool = False, + suffix="nii.gz", + prediction_suffix=None, + pred_include_cases: list = None, + ): + super().__init__( + raw_data_dir=raw_data_dir, + pred_save_dir=pred_save_dir, + overwrite_predictions=overwrite_predictions, + suffix=suffix, + prediction_suffix="txt", + pred_include_cases=pred_include_cases, + ) + + def __getitem__(self, idx): + # Here we generate the paths to the cases along with their ID which they will be saved as. + # we pass "case" as a list of strings and case_id as a string to the dataloader which + # will convert them to a list of tuples of strings and a tuple of a string. + # i.e. ['path1', 'path2'] -> [('path1',), ('path2',)] + case_id = self.unique_cases[idx] + image_paths = [ + impath + for impath in subfiles(self.data_path, suffix=self.suffix) + if os.path.split(impath)[-1][: -len("_000." 
+ self.suffix)] == case_id + ] + covariatepath = self.data_path.replace("imagesTs", "covariatesTs") + covariates = torch.tensor(np.loadtxt(os.path.join(covariatepath, case_id + "_COV.txt"))).unsqueeze(0) + return {"data_paths": image_paths, "covariates": covariates, "extension": self.suffix, "case_id": case_id} + + +if __name__ == "__main__": + from batchgenerators.utilities.file_and_folder_operations import subfiles + + files = subfiles("/home/zcr545/yuccadata/yucca_preprocessed/Task503_ADNI300_MRI/ClassificationV2_112x224x224") + data = ClassificationDataset(files, patch_size=(12, 12, 12)) diff --git a/yucca/modules/data/datasets/YuccaDataset.py b/yucca/modules/data/datasets/YuccaDataset.py index 54ffcf49..3ca0b7d6 100644 --- a/yucca/modules/data/datasets/YuccaDataset.py +++ b/yucca/modules/data/datasets/YuccaDataset.py @@ -99,7 +99,7 @@ def __getitem__(self, idx): image, label = self.unpack(data, supervised=self.supervised) data_dict = {"file_path": case} # metadata that can be very useful for debugging. - if self.task_type in ["classification", "segmentation"]: + if self.task_type == "segmentation": data_dict.update({"image": image, "label": label}) elif self.task_type == "self-supervised": data_dict.update({"image": image}) @@ -153,12 +153,14 @@ def __init__( pred_save_dir: str, overwrite_predictions: bool = False, suffix="nii.gz", + prediction_suffix="nii.gz", pred_include_cases: list = None, ): self.data_path = raw_data_dir self.pred_save_dir = pred_save_dir self.overwrite = overwrite_predictions self.suffix = suffix + self.prediction_suffix = prediction_suffix self.pred_include_cases = pred_include_cases self.unique_cases = np.unique( [i[: -len("_000." + suffix)] for i in subfiles(self.data_path, suffix=self.suffix, join=False)] @@ -166,7 +168,10 @@ def __init__( assert len(self.unique_cases) > 0, f"No cases found in {self.data_path}. Looking for files with suffix: {self.suffix}" self.cases_already_predicted = np.unique( - [i[: -len("." 
+ suffix)] for i in subfiles(self.pred_save_dir, suffix=self.suffix, join=False)] + [ + i[: -len("." + self.prediction_suffix)] + for i in subfiles(self.pred_save_dir, suffix=self.prediction_suffix, join=False) + ] ) logging.info(f"Found {len(self.cases_already_predicted)} already predicted cases. Overwrite: {self.overwrite}") if not self.overwrite: diff --git a/yucca/modules/lightning_modules/BaseLightningModule.py b/yucca/modules/lightning_modules/BaseLightningModule.py index 5d53bcd2..ecd8c4a6 100644 --- a/yucca/modules/lightning_modules/BaseLightningModule.py +++ b/yucca/modules/lightning_modules/BaseLightningModule.py @@ -20,6 +20,7 @@ def __init__( num_classes: int, num_modalities: int, patch_size: tuple, + config: dict = None, crop_to_nonzero: bool = True, deep_supervision: bool = False, disable_inference_preprocessing: bool = False, diff --git a/yucca/modules/lightning_modules/ClassificationLightningModule.py b/yucca/modules/lightning_modules/ClassificationLightningModule.py new file mode 100644 index 00000000..0fbd68ff --- /dev/null +++ b/yucca/modules/lightning_modules/ClassificationLightningModule.py @@ -0,0 +1,156 @@ +import torch +import wandb +import logging +from torchmetrics import MetricCollection +from yucca.functional.utils.kwargs import filter_kwargs +from yucca.modules.metrics.training_metrics import Accuracy +from yucca.modules.lightning_modules.YuccaLightningModule import YuccaLightningModule +from yucca.modules.optimization.loss_functions.nnUNet_losses import DiceCE +from yucca.functional.visualization import get_cls_train_fig_with_inp_out_tar + + +class ClassificationLightningModule(YuccaLightningModule): + """ + The YuccaLightningModule class is an implementation of the PyTorch Lightning module designed for the Yucca project. + It extends the LightningModule class and encapsulates the neural network model, loss functions, and optimization logic. 
+ This class is responsible for handling training, validation, and inference steps within the Yucca machine learning pipeline. + """ + + def __init__( + self, + config: dict, + model: torch.nn.Module, + deep_supervision: bool = False, + disable_inference_preprocessing: bool = False, + loss_fn: torch.nn.Module = DiceCE, + loss_kwargs: dict = { + "soft_dice_kwargs": {"apply_softmax": True}, + }, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = torch.optim.lr_scheduler.CosineAnnealingLR, + lr_scheduler_kwargs: dict = { + "eta_min": 1e-9, + }, + model_kwargs: dict = {}, + optimizer: torch.optim.Optimizer = torch.optim.SGD, + optimizer_kwargs={ + "lr": 1e-3, + }, + sliding_window_overlap: float = 0.5, + step_logging: bool = False, + test_time_augmentation: bool = False, + preprocessor=None, + progress_bar: bool = False, + log_image_every_n_epochs: int = None, + ): + super().__init__( + config=config, + model=model, + deep_supervision=deep_supervision, + disable_inference_preprocessing=disable_inference_preprocessing, + loss_fn=loss_fn, + loss_kwargs=loss_kwargs, + lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs, + model_kwargs=model_kwargs, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + preprocessor=preprocessor, + progress_bar=progress_bar, + sliding_window_overlap=sliding_window_overlap, + step_logging=step_logging, + test_time_augmentation=test_time_augmentation, + ) + + self.config = config + self.log_image_every_n_epochs = log_image_every_n_epochs + self.get_train_fig_fn = get_cls_train_fig_with_inp_out_tar + self.save_hyperparameters(ignore=["model", "loss_fn", "lr_scheduler", "optimizer", "preprocessor"]) + + def setup(self, stage): # noqa: U100 + logging.info(f"Loading Model: {self.model_dimensions} {self.model.__name__}") + if self.model_dimensions == "3D": + conv_op = torch.nn.Conv3d + norm_op = torch.nn.InstanceNorm3d + else: + conv_op = torch.nn.Conv2d + norm_op = torch.nn.BatchNorm2d + + model_kwargs = {"conv_op": 
conv_op, "norm_op": norm_op} + model_kwargs.update(self.model_kwargs) + model_kwargs = filter_kwargs(self.model, model_kwargs) + self.model = self.model(input_channels=self.num_modalities, num_classes=self.num_classes, **model_kwargs) + self.visualize_model_with_FLOPs() + + def configure_metrics(self): + tmetrics_task = "multiclass" # if self.num_classes > 2 else "binary" + self.train_metrics = MetricCollection( + { + "train/acc": Accuracy(task=tmetrics_task, num_classes=self.num_classes), + # "train/roc_auc": AUROC(task=tmetrics_task, num_classes=self.num_classes), + } + ) + self.val_metrics = MetricCollection( + { + "val/acc": Accuracy(task=tmetrics_task, num_classes=self.num_classes), + # "val/roc_auc": AUROC(task=tmetrics_task, num_classes=self.num_classes), + } + ) + + def training_step(self, batch, batch_idx): + inputs, target, file_path = batch["image"], batch["label"], batch["file_path"] + output = self(inputs) + loss = self.loss_fn_train(output, target) + + if self.deep_supervision: + # If deep_supervision is enabled output and target will be a list of (downsampled) tensors. + # We only need the original ground truth and its corresponding prediction which is always the first entry in each list. 
+ output = output[0] + target = target[0] + + metrics = self.compute_metrics(self.train_metrics, output, target, ignore_index=None) + self.log_dict( + {"train/loss": loss} | metrics, + on_step=self.step_logging, + on_epoch=self.epoch_logging, + prog_bar=self.progress_bar, + logger=True, + ) + + if batch_idx == 0 and wandb.run is not None and self.log_image_this_epoch is True: + self._log_dict_of_images_to_wandb( + { + "input": inputs.detach().cpu().to(torch.float32).numpy(), + "target": target.detach().cpu().to(torch.float32).numpy(), + "output": output.detach().cpu().to(torch.float32).numpy(), + "file_path": file_path, + }, + log_key="train", + ) + + return loss + + def validation_step(self, batch, batch_idx): + inputs, target, file_path = batch["image"], batch["label"], batch["file_path"] + output = self(inputs) + + loss = self.loss_fn_val(output, target) + + metrics = self.compute_metrics(self.val_metrics, output, target, ignore_index=None) + self.log_dict( + {"val/loss": loss} | metrics, + on_step=self.step_logging, + on_epoch=self.epoch_logging, + prog_bar=self.progress_bar, + logger=True, + ) + + if batch_idx == 0 and wandb.run is not None and self.log_image_this_epoch is True: + self._log_dict_of_images_to_wandb( + { + "input": inputs.detach().cpu().to(torch.float32).numpy(), + "target": target.detach().cpu().to(torch.float32).numpy(), + "output": output.detach().cpu().to(torch.float32).numpy(), + "file_path": file_path, + }, + log_key="val", + ) diff --git a/yucca/modules/lightning_modules/ClassificationLightningModule_Covariates.py b/yucca/modules/lightning_modules/ClassificationLightningModule_Covariates.py new file mode 100644 index 00000000..2bcf2901 --- /dev/null +++ b/yucca/modules/lightning_modules/ClassificationLightningModule_Covariates.py @@ -0,0 +1,101 @@ +from yucca.modules.lightning_modules.ClassificationLightningModule import ClassificationLightningModule +from yucca.functional.preprocessing import reverse_preprocessing +import wandb +import 
torch +import logging +from yucca.functional.utils.torch_utils import measure_FLOPs +from fvcore.nn import flop_count_table + + +class ClassificationLightningModule_Covariates(ClassificationLightningModule): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, inputs, cov): + return self.model(inputs, cov) + + def training_step(self, batch, batch_idx): + inputs, cov, target, file_path = batch["image"], batch["covariates"], batch["label"], batch["file_path"] + output = self(inputs, cov) + loss = self.loss_fn_train(output, target) + + if self.deep_supervision: + # If deep_supervision is enabled output and target will be a list of (downsampled) tensors. + # We only need the original ground truth and its corresponding prediction which is always the first entry in each list. + output = output[0] + target = target[0] + + metrics = self.compute_metrics(self.train_metrics, output, target, ignore_index=None) + self.log_dict( + {"train/loss": loss} | metrics, + on_step=self.step_logging, + on_epoch=self.epoch_logging, + prog_bar=self.progress_bar, + logger=True, + ) + + if batch_idx == 0 and wandb.run is not None and self.log_image_this_epoch is True: + self._log_dict_of_images_to_wandb( + { + "input": inputs.detach().cpu().to(torch.float32).numpy(), + "target": target.detach().cpu().to(torch.float32).numpy(), + "output": output.detach().cpu().to(torch.float32).numpy(), + "file_path": file_path, + }, + log_key="train", + ) + + return loss + + def validation_step(self, batch, batch_idx): + inputs, cov, target, file_path = batch["image"], batch["covariates"], batch["label"], batch["file_path"] + output = self(inputs, cov) + + loss = self.loss_fn_val(output, target) + + metrics = self.compute_metrics(self.val_metrics, output, target, ignore_index=None) + self.log_dict( + {"val/loss": loss} | metrics, + on_step=self.step_logging, + on_epoch=self.epoch_logging, + prog_bar=self.progress_bar, + logger=True, + ) + + if batch_idx == 0 and 
wandb.run is not None and self.log_image_this_epoch is True: + self._log_dict_of_images_to_wandb( + { + "input": inputs.detach().cpu().to(torch.float32).numpy(), + "target": target.detach().cpu().to(torch.float32).numpy(), + "output": output.detach().cpu().to(torch.float32).numpy(), + "file_path": file_path, + }, + log_key="val", + ) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): # noqa: U100 + logits = self.model.predict(x=batch["data"], cov=batch["covariates"]) + if self.disable_inference_preprocessing: + logits, data_properties = reverse_preprocessing( + crop_to_nonzero=self.crop_to_nonzero, + images=logits, + image_properties=batch["data_properties"], + n_classes=self.num_classes, + transpose_forward=self.transpose_forward, + transpose_backward=self.transpose_backward, + ) + else: + logits, data_properties = self.preprocessor.reverse_preprocessing( + logits, batch["data_properties"], num_classes=self.num_classes + ) + return {"logits": logits, "properties": data_properties, "case_id": batch["case_id"]} + + def visualize_model_with_FLOPs(self): + try: + data = torch.randn((self.config["batch_size"], self.num_modalities, *self.patch_size)) + cov = torch.randn((2)) + flops = measure_FLOPs(self.model, (data, cov)) + del data + logging.info("\n" + flop_count_table(flops)) + except RuntimeError: + logging.info("\n Model architecture could not be visualized.") diff --git a/yucca/modules/lightning_modules/YuccaLightningModule.py b/yucca/modules/lightning_modules/YuccaLightningModule.py index bb8b7f2c..8fb9c5f3 100644 --- a/yucca/modules/lightning_modules/YuccaLightningModule.py +++ b/yucca/modules/lightning_modules/YuccaLightningModule.py @@ -9,7 +9,7 @@ from yucca.modules.optimization.loss_functions.deep_supervision import DeepSupervisionLoss from yucca.functional.utils.kwargs import filter_kwargs from yucca.modules.metrics.training_metrics import Accuracy, AUROC, GeneralizedDiceScore -from yucca.functional.visualization import 
get_train_fig_with_inp_out_tar +from yucca.functional.visualization import get_segm_train_fig_with_inp_out_tar from yucca.modules.lightning_modules.BaseLightningModule import BaseLightningModule from yucca.functional.utils.torch_utils import measure_FLOPs from fvcore.nn import flop_count_table @@ -49,7 +49,6 @@ def __init__( progress_bar: bool = False, log_image_every_n_epochs: int = None, ): - self.task_type = config["task_type"] self.use_label_regions = "use_label_regions" in config.keys() and config["use_label_regions"] super().__init__( model=model, @@ -79,6 +78,7 @@ def __init__( ) self.config = config self.log_image_every_n_epochs = log_image_every_n_epochs + self.get_train_fig_fn = get_segm_train_fig_with_inp_out_tar self.save_hyperparameters(ignore=["model", "loss_fn", "lr_scheduler", "optimizer", "preprocessor"]) def setup(self, stage): # noqa: U100 @@ -115,78 +115,57 @@ def visualize_model_with_FLOPs(self): logging.info("\n Model architecture could not be visualized.") def configure_metrics(self): - if self.task_type == "classification": - tmetrics_task = "multiclass" if self.num_classes > 2 else "binary" - # can we get per-class? 
- self.train_metrics = MetricCollection( - { - "train/acc": Accuracy(task=tmetrics_task, num_classes=self.num_classes), - "train/roc_auc": AUROC(task=tmetrics_task, num_classes=self.num_classes), - } - ) - self.val_metrics = MetricCollection( - { - "val/acc": Accuracy(task=tmetrics_task, num_classes=self.num_classes), - "val/roc_auc": AUROC(task=tmetrics_task, num_classes=self.num_classes), - } - ) - - if self.task_type == "segmentation": - self.train_metrics = MetricCollection( - { - "train/aggregated_dice": GeneralizedDiceScore( - multilabel=self.use_label_regions, - num_classes=self.num_classes, - include_background=self.num_classes == 1 or self.use_label_regions, - weight_type="linear", - per_class=False, - ), - "train/mean_dice": GeneralizedDiceScore( - multilabel=self.use_label_regions, - num_classes=self.num_classes, - include_background=self.num_classes == 1 or self.use_label_regions, - weight_type="linear", - average=True, - ), - "train/dice": GeneralizedDiceScore( - multilabel=self.use_label_regions, - num_classes=self.num_classes, - include_background=self.num_classes == 1 or self.use_label_regions, - weight_type="linear", - per_class=True, - ), - }, - ) - - self.val_metrics = MetricCollection( - { - "val/aggregated_dice": GeneralizedDiceScore( - multilabel=self.use_label_regions, - num_classes=self.num_classes, - include_background=self.num_classes == 1 or self.use_label_regions, - weight_type="linear", - per_class=False, - ), - "val/mean_dice": GeneralizedDiceScore( - multilabel=self.use_label_regions, - num_classes=self.num_classes, - include_background=self.num_classes == 1 or self.use_label_regions, - weight_type="linear", - average=True, - ), - "val/dice": GeneralizedDiceScore( - multilabel=self.use_label_regions, - num_classes=self.num_classes, - include_background=self.num_classes == 1 or self.use_label_regions, - weight_type="linear", - per_class=True, - ), - }, - ) + self.train_metrics = MetricCollection( + { + "train/aggregated_dice": 
GeneralizedDiceScore( + multilabel=self.use_label_regions, + num_classes=self.num_classes, + include_background=self.num_classes == 1 or self.use_label_regions, + weight_type="linear", + per_class=False, + ), + "train/mean_dice": GeneralizedDiceScore( + multilabel=self.use_label_regions, + num_classes=self.num_classes, + include_background=self.num_classes == 1 or self.use_label_regions, + weight_type="linear", + average=True, + ), + "train/dice": GeneralizedDiceScore( + multilabel=self.use_label_regions, + num_classes=self.num_classes, + include_background=self.num_classes == 1 or self.use_label_regions, + weight_type="linear", + per_class=True, + ), + }, + ) - if self.task_type == "self-supervised": - self.train_metrics = MetricCollection({"train/MAE": MeanAbsoluteError()}) - self.val_metrics = MetricCollection({"train/MAE": MeanAbsoluteError()}) + self.val_metrics = MetricCollection( + { + "val/aggregated_dice": GeneralizedDiceScore( + multilabel=self.use_label_regions, + num_classes=self.num_classes, + include_background=self.num_classes == 1 or self.use_label_regions, + weight_type="linear", + per_class=False, + ), + "val/mean_dice": GeneralizedDiceScore( + multilabel=self.use_label_regions, + num_classes=self.num_classes, + include_background=self.num_classes == 1 or self.use_label_regions, + weight_type="linear", + average=True, + ), + "val/dice": GeneralizedDiceScore( + multilabel=self.use_label_regions, + num_classes=self.num_classes, + include_background=self.num_classes == 1 or self.use_label_regions, + weight_type="linear", + per_class=True, + ), + }, + ) def on_fit_start(self): if self.log_image_every_n_epochs is None: @@ -221,7 +200,6 @@ def training_step(self, batch, batch_idx): "file_path": file_path, }, log_key="train", - task_type=self.task_type, ) return loss @@ -249,7 +227,6 @@ def validation_step(self, batch, batch_idx): "file_path": file_path, }, log_key="val", - task_type=self.task_type, ) def configure_optimizers(self): @@ -281,16 +258,15 @@ 
def configure_optimizers(self): # Finally return the optimizer and scheduler - the loss is not returned. return {"optimizer": self.optim, "lr_scheduler": self.lr_scheduler} - def _log_dict_of_images_to_wandb(self, imagedict: {}, log_key: str, task_type: str = "segmentation"): + def _log_dict_of_images_to_wandb(self, imagedict: {}, log_key: str): batch_idx = np.random.randint(0, imagedict["input"].shape[0]) case = os.path.splitext(os.path.split(imagedict["file_path"][batch_idx])[-1])[0] - fig = get_train_fig_with_inp_out_tar( + fig = self.get_train_fig_fn( input=imagedict["input"][batch_idx], output=imagedict["output"][batch_idx], target=imagedict["target"][batch_idx], fig_title=case, - task_type=task_type, ) wandb.log({log_key: wandb.Image(fig)}, commit=False) plt.close(fig) @@ -321,7 +297,6 @@ def get_image_logging_epochs(final_epoch: int = 1000): "patch_size": (32, 32, 32), "plans_path": "", "patch_based_training": True, - "task_type": "segmentation", }, ) data = torch.randn((2, 1, *(32, 32, 32))) diff --git a/yucca/modules/lightning_modules/YuccaLightningModule_skeleton_loss.py b/yucca/modules/lightning_modules/YuccaLightningModule_skeleton_loss.py index 79e51b9a..06626b1e 100644 --- a/yucca/modules/lightning_modules/YuccaLightningModule_skeleton_loss.py +++ b/yucca/modules/lightning_modules/YuccaLightningModule_skeleton_loss.py @@ -47,7 +47,6 @@ def training_step(self, batch, batch_idx): "file_path": file_path, }, log_key="train", - task_type=self.task_type, ) return loss @@ -76,5 +75,4 @@ def validation_step(self, batch, batch_idx): "file_path": file_path, }, log_key="val", - task_type=self.task_type, ) diff --git a/yucca/modules/networks/blocks_and_layers/res_blocks.py b/yucca/modules/networks/blocks_and_layers/res_blocks.py new file mode 100644 index 00000000..2b794394 --- /dev/null +++ b/yucca/modules/networks/blocks_and_layers/res_blocks.py @@ -0,0 +1,144 @@ +from typing import Optional +import torch +from torch import nn, Tensor + + +class 
BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + conv_op=nn.Conv2d, + dropout_op: Optional[nn.Module] = None, + dropout_kwargs={"p": 0.25}, + norm_op=nn.BatchNorm2d, + nonlin=nn.LeakyReLU, + nonlin_kwargs={"inplace": True}, + ) -> None: + super().__init__() + + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv_k3( + conv_op=conv_op, in_planes=inplanes, out_planes=planes, stride=stride, groups=groups, dilation=dilation + ) + self.norm1 = norm_op(planes) + self.relu = nonlin(**nonlin_kwargs) + self.conv2 = conv_k3(conv_op=conv_op, in_planes=planes, out_planes=planes, stride=1, groups=1, dilation=1) + self.norm2 = norm_op(planes) + self.downsample = downsample + if dropout_op is not None: + self.dropout = dropout_op(**dropout_kwargs) + else: + self.dropout = None + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + + if self.dropout is not None: + out = self.dropout(out) + + return out + + +def conv_k1(conv_op, in_planes: int, out_planes: int, stride: int = 1): + return conv_op(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def conv_k3(conv_op, in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1): + return conv_op( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + 
+ +class Bottleneck(nn.Module): + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + conv_op=nn.Conv2d, + norm_op=nn.BatchNorm2d, + dropout_op: Optional[nn.Module] = None, + dropout_kwargs={"p": 0.25}, + nonlin=nn.LeakyReLU, + nonlin_kwargs={"inplace": True}, + ) -> None: + super().__init__() + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv_k1(conv_op=conv_op, in_planes=inplanes, out_planes=planes) + self.bn1 = norm_op(width) + self.conv2 = conv_k3( + conv_op=conv_op, in_planes=width, out_planes=width, stride=stride, groups=groups, dilation=dilation + ) + self.bn2 = norm_op(width) + self.conv3 = conv_k1(conv_op=conv_op, in_planes=width, out_planes=planes * self.expansion) + self.bn3 = norm_op(planes * self.expansion) + self.relu = nonlin(**nonlin_kwargs) + self.downsample = downsample + self.stride = stride + if dropout_op is not None: + self.dropout = dropout_op(**dropout_kwargs) + else: + self.dropout = None + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + if self.dropout is not None: + out = self.dropout(out) + + return out diff --git a/yucca/modules/networks/networks/3D_resnet.py b/yucca/modules/networks/networks/3D_resnet.py new file mode 100644 index 00000000..e8fe9e0b --- /dev/null +++ b/yucca/modules/networks/networks/3D_resnet.py @@ -0,0 +1,457 @@ +from typing import Union, List, Optional, Callable, Type +from pytorchvideo.models.resnet import create_resnet +from torch import nn, Tensor 
class ResNet(YuccaNet):
    """ResNet classifier whose convolution, normalization and activation are pluggable.

    The layout is: 7x7 stem convolution, max-pool, four residual stages with
    64/128/256/512 base channels, global average pooling and a linear head.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        in_channels: number of input image channels.
        num_classes: size of the output of the final linear layer.
        zero_init_residual: zero the weight of the last norm layer of every block so
            each block starts out as an identity mapping.
        groups: groups of the 3x3 convolutions (forwarded to the blocks).
        width_per_group: base width forwarded to the blocks.
        replace_stride_with_dilation: three booleans, one each for stages 2-4. True
            replaces the stage's stride-2 downsampling with dilation.
        dropout_op: optional dropout class forwarded to the blocks.
        dropout_kwargs: keyword arguments for ``dropout_op``. ``None`` falls back to
            ``{"p": 0.25}``.
        conv_op: convolution class. A ``nn.Conv2d`` subclass selects 2D pooling,
            anything else selects 3D pooling.
        norm_op: normalization class, called as ``norm_op(num_channels)``.
        nonlin: activation class.
        nonlin_kwargs: keyword arguments for ``nonlin``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        in_channels,
        num_classes: int = 1,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        dropout_op: Optional[nn.Module] = None,
        dropout_kwargs: Optional[dict] = None,
        conv_op=nn.Conv2d,
        norm_op=nn.BatchNorm2d,
        nonlin=nn.LeakyReLU,
        nonlin_kwargs={"inplace": True},
    ) -> None:
        super().__init__()

        self.inplanes = 64
        self.dilation = 1

        # ``conv_op`` is a class, not an instance, so the dimensionality has to be decided
        # with issubclass. An isinstance check is always False and would select 3D pooling
        # even for 2D networks.
        if isinstance(conv_op, type) and issubclass(conv_op, nn.Conv2d):
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        else:
            self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
            self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]

        if dropout_kwargs is None:
            # The blocks unpack these with ``**``; None would raise as soon as a
            # dropout_op is supplied.
            dropout_kwargs = {"p": 0.25}

        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = conv_op(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1 = norm_op(self.inplanes)
        self.relu = nonlin(**nonlin_kwargs)

        # Arguments that are identical for all four stages.
        stage_kwargs = dict(
            block=block,
            conv_op=conv_op,
            dropout_op=dropout_op,
            dropout_kwargs=dropout_kwargs,
            norm_op=norm_op,
            nonlin=nonlin,
            nonlin_kwargs=nonlin_kwargs,
        )
        # The stages must be built in order: _make_layer updates self.inplanes/self.dilation.
        self.layer1 = self._make_layer(blocks=layers[0], planes=64, **stage_kwargs)
        self.layer2 = self._make_layer(
            blocks=layers[1], planes=128, stride=2, dilate=replace_stride_with_dilation[0], **stage_kwargs
        )
        self.layer3 = self._make_layer(
            blocks=layers[2], planes=256, stride=2, dilate=replace_stride_with_dilation[1], **stage_kwargs
        )
        self.layer4 = self._make_layer(
            blocks=layers[3], planes=512, stride=2, dilate=replace_stride_with_dilation[2], **stage_kwargs
        )

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # The weight is None for norm layers built without affine parameters.
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        blocks: int,
        conv_op,
        dropout_op,
        dropout_kwargs,
        norm_op,
        nonlin,
        nonlin_kwargs,
        planes: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``.

        Only the first block of the stage receives ``stride`` and the shortcut projection.
        With ``dilate=True`` the stride is converted into additional dilation.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the shortcut so it matches the residual branch in shape.
            downsample = nn.Sequential(
                conv_k1(conv_op, self.inplanes, planes * block.expansion, stride),
                norm_op(planes * block.expansion),
            )

        # Arguments shared by every block of this stage.
        block_kwargs = dict(
            base_width=self.base_width,
            conv_op=conv_op,
            dropout_op=dropout_op,
            dropout_kwargs=dropout_kwargs,
            groups=self.groups,
            nonlin=nonlin,
            nonlin_kwargs=nonlin_kwargs,
            norm_op=norm_op,
            planes=planes,
        )

        layers = [
            block(
                dilation=previous_dilation,
                downsample=downsample,
                inplanes=self.inplanes,
                stride=stride,
                **block_kwargs,
            )
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(dilation=self.dilation, inplanes=self.inplanes, **block_kwargs))

        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of the linear head for a batch of images."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
def resnet18(
    input_channels: int,
    num_classes: int = 1,
    conv_op=nn.Conv3d,
    dropout_op=None,
    dropout_kwargs={"p": 0.25},
    norm_op=nn.InstanceNorm3d,
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={"inplace": True},
) -> ResNet:
    """Build a ResNet-18 (``BasicBlock``, two blocks in each of the four stages).

    Architecture from "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385). No pretrained weights are loaded.

    Args:
        input_channels: number of input image channels.
        num_classes: size of the output of the final linear layer.
        conv_op: convolution class.
        dropout_op: optional dropout class applied inside the residual blocks.
        dropout_kwargs: keyword arguments for ``dropout_op``.
        norm_op: normalization class.
        nonlin: activation class.
        nonlin_kwargs: keyword arguments for ``nonlin``.

    Returns:
        The constructed ``ResNet``.
    """

    return ResNet(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        in_channels=input_channels,
        num_classes=num_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        conv_op=conv_op,
        dropout_op=dropout_op,
        dropout_kwargs=dropout_kwargs,
        norm_op=norm_op,
        nonlin=nonlin,
        nonlin_kwargs=nonlin_kwargs,
    )
def resnet50_2cov(
    input_channels: int,
    num_classes: int = 1,
    conv_op=nn.Conv3d,
    dropout_op=None,
    dropout_kwargs={"p": 0.25},
    norm_op=nn.InstanceNorm3d,
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={"inplace": True},
) -> ResNet:
    """Build a ResNet-50 (``Bottleneck``, stages of 3/4/6/3 blocks) that takes two covariates.

    The returned ``ResNet_cov`` expects ``forward(x, cov)``.
    """
    # Fixed architecture of this variant.
    architecture = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        n_covariates=2,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
    )
    # Building blocks chosen by the caller.
    components = dict(
        conv_op=conv_op,
        dropout_op=dropout_op,
        dropout_kwargs=dropout_kwargs,
        norm_op=norm_op,
        nonlin=nonlin,
        nonlin_kwargs=nonlin_kwargs,
    )
    return ResNet_cov(in_channels=input_channels, num_classes=num_classes, **architecture, **components)
class DenseNet(nn.Module):
    """
    Densenet based on "Densely Connected Convolutional Networks" (https://arxiv.org/abs/1608.06993).
    Adapted from the MONAI implementation, which follows the PyTorch Hub 2D version.
    This network is non-deterministic when the network is 3D and CUDA is enabled. Please check the link below
    for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    Args:
        conv_op: convolution class used to pick the dimensionality. A ``nn.Conv2d``
            subclass builds a 2D network, anything else builds a 3D network.
        input_channels: number of the input channel.
        num_classes: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
            (i.e. bn_size * k features in the bottleneck layer)
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.
        dropout_prob: dropout rate after each dense layer.
    """

    def __init__(
        self,
        conv_op: type[nn.Module],
        input_channels: int,
        num_classes: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        act: str | tuple = ("relu", {"inplace": True}),
        norm: str | tuple = "batch",
        dropout_prob: float = 0.0,
    ) -> None:
        super().__init__()

        # ``conv_op`` is a class, not an instance, so issubclass is required here. An
        # isinstance check is always False and would build a 3D network for 2D input.
        if isinstance(conv_op, type) and issubclass(conv_op, nn.Conv2d):
            spatial_dims = 2
        else:
            spatial_dims = 3

        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
        pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
        # Kept on the instance so subclasses can rebuild the classification head.
        self.avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
            Pool.ADAPTIVEAVG, spatial_dims
        ]

        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", conv_type(input_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=init_features)),
                    ("relu0", get_act_layer(name=act)),
                    ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # From here on ``input_channels`` tracks the running feature-map channel count.
        input_channels = init_features
        last_block = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=input_channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
                act=act,
                norm=norm,
            )
            self.features.add_module(f"denseblock{i + 1}", block)
            input_channels += num_layers * growth_rate
            if i == last_block:
                # The final dense block is followed by a norm layer instead of a transition.
                self.features.add_module(
                    "norm5", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=input_channels)
                )
            else:
                # Transitions halve the channel count between dense blocks.
                _out_channels = input_channels // 2
                trans = _Transition(spatial_dims, in_channels=input_channels, out_channels=_out_channels, act=act, norm=norm)
                self.features.add_module(f"transition{i + 1}", trans)
                input_channels = _out_channels

        # pooling and classification
        self.fc_channels = input_channels
        self.class_layers = nn.Sequential(
            OrderedDict(
                [
                    ("relu", get_act_layer(name=act)),
                    ("pool", self.avg_pool_type(1)),
                    ("flatten", nn.Flatten(1)),
                    ("out", nn.Linear(self.fc_channels, num_classes)),
                ]
            )
        )
        for m in self.modules():
            if isinstance(m, conv_type):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight))
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features and return the output of the classification head."""
        x = self.features(x)
        x = self.class_layers(x)
        return x
class DenseNet264(DenseNet):
    """DenseNet264 (dense blocks of 6/12/64/48 layers).

    The first parameter is ``conv_op``, matching ``DenseNet121``/``169``/``201``. The
    previous ``spatial_dims`` parameter was forwarded to ``DenseNet.__init__``, which
    has no such parameter and requires ``conv_op``, so the class could never be
    instantiated.
    """

    def __init__(
        self,
        conv_op,
        input_channels: int,
        num_classes: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 64, 48),
        **kwargs,
    ) -> None:
        super().__init__(
            conv_op=conv_op,
            input_channels=input_channels,
            num_classes=num_classes,
            init_features=init_features,
            growth_rate=growth_rate,
            block_config=block_config,
            **kwargs,
        )
def densenet121_2cov(
    conv_op,
    input_channels: int,
    num_classes: int,
    init_features: int = 64,
    growth_rate: int = 32,
    block_config: Sequence[int] = (6, 12, 24, 16),
):
    """Build a DenseNet-121 style ``DenseNet_cov`` that takes exactly two covariates."""
    # Everything except the covariate count is chosen by the caller.
    settings = {
        "conv_op": conv_op,
        "input_channels": input_channels,
        "num_classes": num_classes,
        "init_features": init_features,
        "growth_rate": growth_rate,
        "block_config": block_config,
    }
    model = DenseNet_cov(n_covariates=2, **settings)
    return model
model_name = "UNet" + elif model_name[-3:].lower() == "cov": + model_name = model_name.split("_")[0] + model = recursive_find_python_class( folder=[join(yucca.__path__[0], "modules", "networks")], class_name=model_name, diff --git a/yucca/modules/optimization/loss_functions/CE.py b/yucca/modules/optimization/loss_functions/CE.py index b5192df3..11ab0be0 100644 --- a/yucca/modules/optimization/loss_functions/CE.py +++ b/yucca/modules/optimization/loss_functions/CE.py @@ -7,6 +7,10 @@ class CE(nn.CrossEntropyLoss): + """ + input is expected to contain the unnormalized logits for each class + """ + def forward(self, input: Tensor, target: Tensor) -> Tensor: if len(target.shape) == len(input.shape): assert target.shape[1] == 1 diff --git a/yucca/pipeline/configuration/configure_callbacks.py b/yucca/pipeline/configuration/configure_callbacks.py index 6510c354..b171b478 100644 --- a/yucca/pipeline/configuration/configure_callbacks.py +++ b/yucca/pipeline/configuration/configure_callbacks.py @@ -110,7 +110,6 @@ def get_loggers( loggers = [ YuccaLogger( - disable_logging=not enable_logging, save_dir=save_dir, name=None, version=version, diff --git a/yucca/pipeline/configuration/configure_input_dims.py b/yucca/pipeline/configuration/configure_input_dims.py index e8edbe96..4dcf98c6 100644 --- a/yucca/pipeline/configuration/configure_input_dims.py +++ b/yucca/pipeline/configuration/configure_input_dims.py @@ -56,12 +56,12 @@ def get_input_dims_config( assert plan.get("new_max_size") == plan.get( "new_min_size" ), "sizes in dataset are not uniform. Non-patch based training only works for datasets with uniform data shapes." - patch_size = tuple(plan.get("new_max_size")) - logging.info(f"Getting patch size for non-patch based training") + patch_size = tuple(map(int, plan.get("new_max_size"))) + logging.info(f"Found patch size: {patch_size} for non-patch based training") else: # B.1. 
class ClassificationManagerV2(YuccaManagerV2):
    """Manager preset for whole-image 3D classification.

    Sets the model name to ``resnet18``, the loss to ``CE``, and switches off
    patch-based training and deep supervision. Label augmentation is skipped.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Copy the preset: ``genericV2`` is a module-level object shared by every manager,
        # so assigning into it directly would set "skip_label" for all other users too.
        self.augmentation_params = dict(genericV2)
        self.augmentation_params["skip_label"] = True
        self.model_name = "resnet18"
        self.loss = CE
        self.lightning_module = ClassificationLightningModule
        self.model_dimensions = "3D"
        self.patch_based_training = False
        self.deep_supervision = False
        self.train_dataset_class = ClassificationTrainDataset
        self.test_dataset_class = ClassificationTestDataset
class ClassificationManagerV10(ClassificationManagerV9):
    """``ClassificationManagerV9`` with the ``channel_specific_probas`` preset merged in."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Imported at call time, as in the original implementation.
        from yucca.modules.data.augmentation.augmentation_presets import (
            channel_specific_probas as channel_overrides,
        )

        # Entries from the preset overwrite existing keys of the same name.
        params = self.augmentation_params
        params.update(channel_overrides)
class ClassificationManagerV10_DenseCov(ClassificationManagerV9_DenseCov):
    """DenseNet-with-covariates manager using the ``genericV2`` augmentation preset.

    Overrides the optimizer's weight decay (1e-2) and learning rate (1e-3) and
    disables label augmentation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Copy the preset: ``genericV2`` is a module-level object shared by every manager,
        # so writing the two keys below into it directly would leak into all other users.
        self.augmentation_params = dict(genericV2)
        self.optim_kwargs["weight_decay"] = 1e-2
        self.optim_kwargs["lr"] = 1e-3
        self.augmentation_params["skip_label"] = True
        self.augmentation_params["label_dtype"] = int
class Classification_PsyBrain128(YuccaPlannerV2):
    """Planner that resizes every volume to 128x128x128 for classification.

    Aspect ratio is not preserved, cropping to nonzero is disabled and axes are
    left in their original order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = str(type(self).__name__)
        self.preprocessor = "ClassificationPreprocessor"
        self.keep_aspect_ratio_when_using_target_size = False
        self.crop_to_nonzero = False

    def determine_target_size_from_fixed_size_or_spacing(self):
        # Fixed cubic target size; no target spacing is used.
        self.fixed_target_spacing = None
        self.fixed_target_size = (128, 128, 128)

    def determine_transpose(self):
        # Identity permutation in both directions (two separate list objects).
        self.transpose_fw = list(range(3))
        self.transpose_bw = list(range(3))
b/yucca/pipeline/planning/YuccaPlanner.py @@ -323,3 +323,4 @@ def __init__(self, *args, **kwargs): def determine_target_size_from_fixed_size_or_spacing(self): self.fixed_target_size = (224, 224) self.fixed_target_spacing = None + diff --git a/yucca/pipeline/planning/dataset_properties.py b/yucca/pipeline/planning/dataset_properties.py index 25c82328..a4c6942e 100644 --- a/yucca/pipeline/planning/dataset_properties.py +++ b/yucca/pipeline/planning/dataset_properties.py @@ -53,7 +53,7 @@ def create_dataset_properties(data_dir, save_dir, suffix=".nii.gz", num_workers= suffix = f"_{mod_id:03}.{image_extension}" subjects = [] images = subfiles(images_dir, suffix=suffix, join=False) - + print(images_dir, suffix) for image in images: image_path = join(images_dir, image) # Remove modality encoding @@ -182,6 +182,8 @@ def process(subject: str, background_pixel_value: int = 0): mask = label > 0 image_msk = image[mask] else: + if np.max(image) <= background_pixel_value: + background_pixel_value = np.min(image) image_msk = image[image > background_pixel_value] mean = np.mean(image_msk) diff --git a/yucca/pipeline/preprocessing/ClassificationPreprocessor.py b/yucca/pipeline/preprocessing/ClassificationPreprocessor.py index b11e21e4..06a0bcbe 100644 --- a/yucca/pipeline/preprocessing/ClassificationPreprocessor.py +++ b/yucca/pipeline/preprocessing/ClassificationPreprocessor.py @@ -1,4 +1,8 @@ import torch +import os +import numpy as np +import re +from typing import Optional from yucca.pipeline.preprocessing.YuccaPreprocessor import YuccaPreprocessor @@ -6,14 +10,44 @@ class ClassificationPreprocessor(YuccaPreprocessor): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # set up for classification - self.classification = True self.label_exists = True self.preprocess_label = False - def reverse_preprocessing(self, images: torch.Tensor, image_properties: dict): + def reverse_preprocessing(self, images: torch.Tensor, image_properties: dict, num_classes: 
class ClassificationPreprocessorWithCovariates(ClassificationPreprocessor):
    """Classification preprocessor that also stores per-subject covariates.

    Covariates are read from ``<input_dir>/covariatesTr/<subject_id>_COV.txt`` and
    saved with each case as ``[images, covariates, label]``.
    """

    def initialize_paths(self):
        super().initialize_paths()
        self.covariatepaths = os.path.join(self.input_dir, "covariatesTr")

    def _preprocess_train_subject(self, subject_id, label_exists: bool, preprocess_label: bool):
        """Preprocess the subject and bundle its covariates with the label.

        Returns:
            ``(images, label, image_props)`` where ``label`` is a 2-element object
            array holding ``[covariates, label]``.
        """
        images, label, image_props = super()._preprocess_train_subject(subject_id, label_exists, preprocess_label)
        # Use the raw subject id: this is a file name, not a regex, so re.escape would
        # insert backslashes for characters such as "." and point at a non-existent file.
        covariate_file = os.path.join(self.covariatepaths, f"{subject_id}_COV.txt")
        covariates = np.loadtxt(covariate_file)
        label = np.array([covariates, label], dtype="object")
        return images, label, image_props

    def cast_to_numpy_array(self, images: list, label=None, classification=False):
        """Pack the case into a 3-element object array ``[images, covariates, label]``.

        ``label`` must be the ``[covariates, label]`` pair produced by
        ``_preprocess_train_subject``. ``classification`` is accepted for signature
        compatibility and is not used.
        """
        # In this scenario the labels will also contain the covariates
        covariates, target = label[0], label[-1]

        canvas = np.empty(3, dtype="object")
        # Stack the modalities along a new leading axis.
        stacked = np.vstack([image[np.newaxis] for image in images])
        canvas[:] = [stacked, covariates, target]
        return canvas
self.preprocess_label = False diff --git a/yucca/pipeline/preprocessing/YuccaPreprocessor.py b/yucca/pipeline/preprocessing/YuccaPreprocessor.py index 80ac0693..febae2b1 100644 --- a/yucca/pipeline/preprocessing/YuccaPreprocessor.py +++ b/yucca/pipeline/preprocessing/YuccaPreprocessor.py @@ -92,7 +92,6 @@ def __init__( self.target_spacing = [] # set up for segmentation - self.classification = False self.label_exists = True self.preprocess_label = True @@ -213,7 +212,7 @@ def preprocess_train_subject(self, subject_id): images, label, image_props = self._preprocess_train_subject( subject_id, label_exists=self.label_exists, preprocess_label=self.preprocess_label ) - images = self.cast_to_numpy_array(images=images, label=label, classification=self.classification) + images = self.cast_to_numpy_array(images=images, label=label) # save the image if self.compress: @@ -277,7 +276,6 @@ def _preprocess_train_subject(self, subject_id, label_exists: bool, preprocess_l # Check if impath is a modality of subject_id (subject_id + _XXX + .) where XXX are three digits if re.search(escaped_subject_id + "_" + r"\d{3}" + ".", os.path.split(impath)[-1]) ] - missing_modalities = self.sanity_check_modalities_and_return_missing( imagepaths=imagepaths, normalization_schemes=self.plans["normalization_scheme"], @@ -468,7 +466,9 @@ def sanity_check_modalities_and_return_missing(imagepaths, normalization_schemes assert len(imagepaths) > 0, "found no images" if not allow_missing_modalities: - assert not len(missing_modalities) > 0, "found missing modalities and allow_missing_modalities is not enabled." + assert ( + not len(missing_modalities) > 0 + ), f"found missing modalities and allow_missing_modalities is not enabled. 
Expected: {expected_modalities} and found: {found_modalities}" return missing_modalities def cast_to_numpy_array(self, images: list, label=None, classification=False): @@ -476,9 +476,6 @@ def cast_to_numpy_array(self, images: list, label=None, classification=False): images = np.array(images, dtype=np.float32) elif label is None and self.allow_missing_modalities: # self-supervised with missing mods images = np.array(images, dtype="object") - elif classification: # Classification is always "object" - images.append(label) - images = np.array(images, dtype="object") elif self.allow_missing_modalities: # segmentation with missing modalities images.append(np.array(label)[np.newaxis]) images = np.array(images, dtype="object") diff --git a/yucca/pipeline/task_conversion/Task503_ADNI300_MRI.py b/yucca/pipeline/task_conversion/Task503_ADNI300_MRI.py new file mode 100644 index 00000000..4abf9b30 --- /dev/null +++ b/yucca/pipeline/task_conversion/Task503_ADNI300_MRI.py @@ -0,0 +1,125 @@ +# Dataset containing 150 AD and 150 controls used to test Classification pipeline. +# In this version we ONLY use the MRI and the labels. 
+ + +import shutil + +import numpy as np +import pandas as pd +from tqdm import tqdm +import gzip +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p as ensure_dir_exists, isfile +from yucca.paths import get_raw_data_path +from yucca.pipeline.task_conversion.utils import generate_dataset_json +from sklearn.model_selection import train_test_split +from yucca.functional.utils.nib_utils import get_nib_orientation, reorient_nib_image +import nibabel as nib + + +def convert(path: str = "/home/zcr545/data/data/projects/PsyBrainPrediction", subdir: str = "ADNI"): + path = join(path, subdir) + + # Train and Test images are in the same folder + images_dir = join(path, "Images") + + # OUTPUT DATA + # Target names + task_name = "Task503_ADNI300_MRI" + prefix = "ADNI300_MRI" + + # Target paths + target_base = join(get_raw_data_path(), task_name) + + target_imagesTr = join(target_base, "imagesTr") + target_labelsTr = join(target_base, "labelsTr") + + target_imagesTs = join(target_base, "imagesTs") + target_labelsTs = join(target_base, "labelsTs") + + ensure_dir_exists(target_imagesTr) + ensure_dir_exists(target_labelsTs) + ensure_dir_exists(target_imagesTs) + ensure_dir_exists(target_labelsTr) + + # collect labels + AD_cases = pd.read_csv(join(path, "AD_group.csv")) + CN_cases = pd.read_csv(join(path, "CN_group.csv")) + + # Check that no subjects appear twice in either group + assert len(set(AD_cases.Subject)) == 150 + assert len(set(CN_cases.Subject)) == 150 + + # Check that no subjects appear in both groups + assert len(np.intersect1d(AD_cases.Subject, CN_cases.Subject)) == 0 + + train_AD, test_AD = train_test_split(AD_cases, random_state=418920) + train_CN, test_CN = train_test_split(CN_cases, random_state=537289) + + all_train = pd.concat([train_AD, train_CN]) + all_test = pd.concat([test_AD, test_CN]) + + # Populate the training folders + for _, row in tqdm(all_train.iterrows(), total=len(all_train)): + subject = row["Subject"] + label = 
row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + else: + print("Found unexpected labels: ", label) + + image_path = join(path, "Data", subject, "T1.nii") + image_file = nib.load(image_path) + + ort = get_nib_orientation(image_file) + image_file = reorient_nib_image(image_file, original_orientation=ort, target_orientation="RAS") + + image_save_path = f"{target_imagesTr}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTr}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + nib.save(image_file, image_save_path) + + # Populate the test folders + for _, row in tqdm(all_test.iterrows(), total=len(all_test)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + else: + print("Found unexpected labels: ", label) + + image_path = join(path, "Data", subject, "T1.nii") + image_file = nib.load(image_path) + + ort = get_nib_orientation(image_file) + image_file = reorient_nib_image(image_file, original_orientation=ort, target_orientation="RAS") + + image_save_path = f"{target_imagesTs}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTs}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + nib.save(image_file, image_save_path) + + generate_dataset_json( + join(target_base, "dataset.json"), + target_imagesTr, + target_imagesTs, + ("T1",), + labels={0: "CN", 1: "AD"}, + dataset_name=task_name, + license="CC-BY 4.0", + dataset_description="ADNI", + dataset_reference="", + ) + + +if __name__ == "__main__": + convert() diff --git a/yucca/pipeline/task_conversion/Task504_ADNI300_AFFSS_MRI.py b/yucca/pipeline/task_conversion/Task504_ADNI300_AFFSS_MRI.py new file mode 100644 index 00000000..cd762ce4 --- /dev/null +++ b/yucca/pipeline/task_conversion/Task504_ADNI300_AFFSS_MRI.py @@ -0,0 +1,117 @@ +# Dataset containing 
150 AD and 150 controls used to test Classification pipeline. +# In this version we ONLY use the MRI and the labels. + + +import shutil + +import numpy as np +import pandas as pd +from tqdm import tqdm +import gzip +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p as ensure_dir_exists, isfile +from yucca.paths import get_raw_data_path +from yucca.pipeline.task_conversion.utils import generate_dataset_json +from sklearn.model_selection import train_test_split +from yucca.functional.utils.nib_utils import get_nib_orientation, reorient_nib_image +import nibabel as nib + + +def convert(path: str = "/home/zcr545/data/data/projects/PsyBrainPrediction", subdir: str = "ADNI"): + path = join(path, subdir) + + # Train and Test images are in the same folder + images_dir = join(path, "Images") + + # OUTPUT DATA + # Target names + task_name = "Task504_ADNI300_AFFSS_MRI" + prefix = "ADNI300_AFFSS_MRI" + + # Target paths + target_base = join(get_raw_data_path(), task_name) + + target_imagesTr = join(target_base, "imagesTr") + target_labelsTr = join(target_base, "labelsTr") + + target_imagesTs = join(target_base, "imagesTs") + target_labelsTs = join(target_base, "labelsTs") + + ensure_dir_exists(target_imagesTr) + ensure_dir_exists(target_labelsTs) + ensure_dir_exists(target_imagesTs) + ensure_dir_exists(target_labelsTr) + + # collect labels + AD_cases = pd.read_csv(join(path, "AD_group.csv")) + CN_cases = pd.read_csv(join(path, "CN_group.csv")) + + # Check that no subjects appear twice in either group + assert len(set(AD_cases.Subject)) == 150 + assert len(set(CN_cases.Subject)) == 150 + + # Check that no subjects appear in both groups + assert len(np.intersect1d(AD_cases.Subject, CN_cases.Subject)) == 0 + + train_AD, test_AD = train_test_split(AD_cases, random_state=418920) + train_CN, test_CN = train_test_split(CN_cases, random_state=537289) + + all_train = pd.concat([train_AD, train_CN]) + all_test = pd.concat([test_AD, test_CN]) + + # 
Populate the training folders + for _, row in tqdm(all_train.iterrows(), total=len(all_train)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + else: + print("Found unexpected labels: ", label) + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTr}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTr}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + shutil.copy2(image_path, image_save_path) + + # Populate the test folders + for _, row in tqdm(all_test.iterrows(), total=len(all_test)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + else: + print("Found unexpected labels: ", label) + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTs}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTs}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + shutil.copy2(image_path, image_save_path) + + generate_dataset_json( + join(target_base, "dataset.json"), + target_imagesTr, + target_imagesTs, + ("T1",), + labels={0: "CN", 1: "AD"}, + dataset_name=task_name, + license="CC-BY 4.0", + dataset_description="ADNI", + dataset_reference="", + ) + + +if __name__ == "__main__": + convert() diff --git a/yucca/pipeline/task_conversion/Task505_ADNI900_MRI.py b/yucca/pipeline/task_conversion/Task505_ADNI900_MRI.py new file mode 100644 index 00000000..22f3f89e --- /dev/null +++ b/yucca/pipeline/task_conversion/Task505_ADNI900_MRI.py @@ -0,0 +1,134 @@ +# Dataset containing 300 AD, 300 MCI and 300 controls used to test Classification pipeline. +# In this version we ONLY use the MRI and the labels.
+ + +import shutil + +import numpy as np +import pandas as pd +from tqdm import tqdm +import gzip +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p as ensure_dir_exists, isfile +from yucca.paths import get_raw_data_path +from yucca.pipeline.task_conversion.utils import generate_dataset_json +from sklearn.model_selection import train_test_split +from yucca.functional.utils.nib_utils import get_nib_orientation, reorient_nib_image +import nibabel as nib + + +def convert(path: str = "/home/zcr545/data/data/projects/PsyBrainPrediction", subdir: str = "ADNI"): + path = join(path, subdir) + + # Train and Test images are in the same folder + images_dir = join(path, "Images") + + # OUTPUT DATA + # Target names + task_name = "Task505_ADNI900_MRI" + prefix = "ADNI900_MRI" + + # Target paths + target_base = join(get_raw_data_path(), task_name) + + target_imagesTr = join(target_base, "imagesTr") + target_labelsTr = join(target_base, "labelsTr") + + target_imagesTs = join(target_base, "imagesTs") + target_labelsTs = join(target_base, "labelsTs") + + ensure_dir_exists(target_imagesTr) + ensure_dir_exists(target_labelsTs) + ensure_dir_exists(target_imagesTs) + ensure_dir_exists(target_labelsTr) + + # collect labels + AD_cases = pd.read_csv(join(path, "AD_group.csv")) + CN_cases = pd.read_csv(join(path, "CN_group.csv")) + MCI_cases = pd.read_csv(join(path, "MCI_group.csv")) + + # Check that no subjects appear twice in either group + assert len(set(AD_cases.Subject)) == 300 + assert len(set(CN_cases.Subject)) == 300 + assert len(set(MCI_cases.Subject)) == 300 + + # Check that no subjects appear in both groups + assert len(np.intersect1d(AD_cases.Subject, CN_cases.Subject)) == 0 + assert len(np.intersect1d(AD_cases.Subject, MCI_cases.Subject)) == 0 + assert len(np.intersect1d(CN_cases.Subject, MCI_cases.Subject)) == 0 + + train_AD, test_AD = train_test_split(AD_cases, random_state=418920) + train_CN, test_CN = train_test_split(CN_cases, 
random_state=537289) + train_MCI, test_MCI = train_test_split(MCI_cases, random_state=958013) + + all_train = pd.concat([train_AD, train_CN, train_MCI]) + all_test = pd.concat([test_AD, test_CN, test_MCI]) + + # Populate the training folders + for _, row in tqdm(all_train.iterrows(), total=len(all_train)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found unexpected labels: '", label, "'") + + image_path = join(path, "Data", subject, "T1.nii") + image_file = nib.load(image_path) + + ort = get_nib_orientation(image_file) + image_file = reorient_nib_image(image_file, original_orientation=ort, target_orientation="RAS") + + image_save_path = f"{target_imagesTr}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTr}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + nib.save(image_file, image_save_path) + + # Populate the test folders + for _, row in tqdm(all_test.iterrows(), total=len(all_test)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found unexpected labels: ", label) + + image_path = join(path, "Data", subject, "T1.nii") + image_file = nib.load(image_path) + + ort = get_nib_orientation(image_file) + image_file = reorient_nib_image(image_file, original_orientation=ort, target_orientation="RAS") + + image_save_path = f"{target_imagesTs}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTs}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + nib.save(image_file, image_save_path) + + generate_dataset_json( + join(target_base, "dataset.json"), + target_imagesTr, + target_imagesTs, + ("T1",), + labels={0: "CN", 1: "AD", 2: "MCI"}, + 
dataset_name=task_name, + license="CC-BY 4.0", + dataset_description="ADNI", + dataset_reference="", + ) + + +if __name__ == "__main__": + convert() diff --git a/yucca/pipeline/task_conversion/Task506_ADNI900_AFFSS_MRI.py b/yucca/pipeline/task_conversion/Task506_ADNI900_AFFSS_MRI.py new file mode 100644 index 00000000..0f5c8bab --- /dev/null +++ b/yucca/pipeline/task_conversion/Task506_ADNI900_AFFSS_MRI.py @@ -0,0 +1,126 @@ +# Dataset containing 300 AD, 300 MCI and 300 controls used to test Classification pipeline. +# In this version we ONLY use the MRI and the labels. + + +import shutil + +import numpy as np +import pandas as pd +from tqdm import tqdm +import gzip +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p as ensure_dir_exists, isfile +from yucca.paths import get_raw_data_path +from yucca.pipeline.task_conversion.utils import generate_dataset_json +from sklearn.model_selection import train_test_split +from yucca.functional.utils.nib_utils import get_nib_orientation, reorient_nib_image +import nibabel as nib + + +def convert(path: str = "/home/zcr545/data/data/projects/PsyBrainPrediction", subdir: str = "ADNI"): + path = join(path, subdir) + + # Train and Test images are in the same folder + images_dir = join(path, "Images") + + # OUTPUT DATA + # Target names + task_name = "Task506_ADNI900_AFFSS_MRI" + prefix = "ADNI900_AFFSS_MRI" + + # Target paths + target_base = join(get_raw_data_path(), task_name) + + target_imagesTr = join(target_base, "imagesTr") + target_labelsTr = join(target_base, "labelsTr") + + target_imagesTs = join(target_base, "imagesTs") + target_labelsTs = join(target_base, "labelsTs") + + ensure_dir_exists(target_imagesTr) + ensure_dir_exists(target_labelsTs) + ensure_dir_exists(target_imagesTs) + ensure_dir_exists(target_labelsTr) + + # collect labels + AD_cases = pd.read_csv(join(path, "AD_group.csv")) + CN_cases = pd.read_csv(join(path, "CN_group.csv")) + MCI_cases = pd.read_csv(join(path, "MCI_group.csv")) + +
# Check that no subjects appear twice in either group + assert len(set(AD_cases.Subject)) == 300 + assert len(set(CN_cases.Subject)) == 300 + assert len(set(MCI_cases.Subject)) == 300 + + # Check that no subjects appear in both groups + assert len(np.intersect1d(AD_cases.Subject, CN_cases.Subject)) == 0 + assert len(np.intersect1d(AD_cases.Subject, MCI_cases.Subject)) == 0 + assert len(np.intersect1d(CN_cases.Subject, MCI_cases.Subject)) == 0 + + train_AD, test_AD = train_test_split(AD_cases, random_state=418920) + train_CN, test_CN = train_test_split(CN_cases, random_state=537289) + train_MCI, test_MCI = train_test_split(MCI_cases, random_state=958013) + + all_train = pd.concat([train_AD, train_CN, train_MCI]) + all_test = pd.concat([test_AD, test_CN, test_MCI]) + + # Populate the training folders + for _, row in tqdm(all_train.iterrows(), total=len(all_train)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found unexpected labels: '", label, "'") + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTr}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTr}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + shutil.copy2(image_path, image_save_path) + + # Populate the test folders + for _, row in tqdm(all_test.iterrows(), total=len(all_test)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found unexpected labels: ", label) + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTs}/{prefix}_{subject}_000.nii.gz" + label_save_path = f"{target_labelsTs}/{prefix}_{subject}.txt" + + 
np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + shutil.copy2(image_path, image_save_path) + + generate_dataset_json( + join(target_base, "dataset.json"), + target_imagesTr, + target_imagesTs, + ("T1",), + labels={0: "CN", 1: "AD", 2: "MCI"}, + dataset_name=task_name, + license="CC-BY 4.0", + dataset_description="ADNI", + dataset_reference="", + ) + + +if __name__ == "__main__": + convert() diff --git a/yucca/pipeline/task_conversion/Task507_ADNI900_AFFSS_MRISEG.py b/yucca/pipeline/task_conversion/Task507_ADNI900_AFFSS_MRISEG.py new file mode 100644 index 00000000..bb31342b --- /dev/null +++ b/yucca/pipeline/task_conversion/Task507_ADNI900_AFFSS_MRISEG.py @@ -0,0 +1,130 @@ +# Dataset containing 300 AD, 300 MCI and 300 controls used to test Classification pipeline. +# In this version we use the MRI + Segmentation and the labels. + + +import shutil + +import numpy as np +import pandas as pd +from tqdm import tqdm +import gzip +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p as ensure_dir_exists, isfile +from yucca.paths import get_raw_data_path +from yucca.pipeline.task_conversion.utils import generate_dataset_json +from sklearn.model_selection import train_test_split +from yucca.functional.utils.nib_utils import get_nib_orientation, reorient_nib_image +import nibabel as nib + + +def convert(path: str = "/home/zcr545/data/data/projects/PsyBrainPrediction", subdir: str = "ADNI"): + path = join(path, subdir) + + # Train and Test images are in the same folder + images_dir = join(path, "Images") + + # OUTPUT DATA + # Target names + task_name = "Task507_ADNI900_AFFSS_MRISEG" + prefix = "ADNI900_AFFSS_MRISEG" + + # Target paths + target_base = join(get_raw_data_path(), task_name) + + target_imagesTr = join(target_base, "imagesTr") + target_labelsTr = join(target_base, "labelsTr") + + target_imagesTs = join(target_base, "imagesTs") + target_labelsTs = join(target_base, "labelsTs") + + 
ensure_dir_exists(target_imagesTr) + ensure_dir_exists(target_labelsTs) + ensure_dir_exists(target_imagesTs) + ensure_dir_exists(target_labelsTr) + + # collect labels + AD_cases = pd.read_csv(join(path, "AD_group.csv")) + CN_cases = pd.read_csv(join(path, "CN_group.csv")) + MCI_cases = pd.read_csv(join(path, "MCI_group.csv")) + + # Check that no subjects appear twice in either group + assert len(set(AD_cases.Subject)) == 300 + assert len(set(CN_cases.Subject)) == 300 + assert len(set(MCI_cases.Subject)) == 300 + + # Check that no subjects appear in both groups + assert len(np.intersect1d(AD_cases.Subject, CN_cases.Subject)) == 0 + assert len(np.intersect1d(AD_cases.Subject, MCI_cases.Subject)) == 0 + assert len(np.intersect1d(CN_cases.Subject, MCI_cases.Subject)) == 0 + + train_AD, test_AD = train_test_split(AD_cases, random_state=418920) + train_CN, test_CN = train_test_split(CN_cases, random_state=537289) + train_MCI, test_MCI = train_test_split(MCI_cases, random_state=958013) + + all_train = pd.concat([train_AD, train_CN, train_MCI]) + all_test = pd.concat([test_AD, test_CN, test_MCI]) + + # Populate the training folders + for _, row in tqdm(all_train.iterrows(), total=len(all_train)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found unexpected labels: '", label, "'") + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + seg_path = join(path, "MNI/Seg_Affine_SkullStripped", subject, "seg_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTr}/{prefix}_{subject}_000.nii.gz" + seg_save_path = f"{target_imagesTr}/{prefix}_{subject}_001.nii.gz" + label_save_path = f"{target_labelsTr}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + shutil.copy2(image_path, image_save_path) + shutil.copy2(seg_path, seg_save_path) + + # Populate the test folders + for _, row 
in tqdm(all_test.iterrows(), total=len(all_test)): + subject = row["Subject"] + label = row["Group"] + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found unexpected labels: ", label) + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + seg_path = join(path, "MNI/Seg_Affine_SkullStripped", subject, "seg_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTs}/{prefix}_{subject}_000.nii.gz" + seg_save_path = f"{target_imagesTs}/{prefix}_{subject}_001.nii.gz" + label_save_path = f"{target_labelsTs}/{prefix}_{subject}.txt" + + np.savetxt(label_save_path, np.array([label]), fmt="%s") + shutil.copy2(image_path, image_save_path) + shutil.copy2(seg_path, seg_save_path) + + generate_dataset_json( + join(target_base, "dataset.json"), + target_imagesTr, + target_imagesTs, + ("T1", "Segmentation"), + labels={0: "CN", 1: "AD", 2: "MCI"}, + dataset_name=task_name, + license="CC-BY 4.0", + dataset_description="ADNI", + dataset_reference="", + ) + + +if __name__ == "__main__": + convert() diff --git a/yucca/pipeline/task_conversion/Task508_ADNI900_AFFSS_MRICOV.py b/yucca/pipeline/task_conversion/Task508_ADNI900_AFFSS_MRICOV.py new file mode 100644 index 00000000..ba78c196 --- /dev/null +++ b/yucca/pipeline/task_conversion/Task508_ADNI900_AFFSS_MRICOV.py @@ -0,0 +1,156 @@ +# Dataset containing 350 AD, 350 MCI and 350 controls used to test Classification pipeline. +# In this version we use the MRI + covariates (age and sex) and the labels.
+import shutil +import numpy as np +import pandas as pd +import nibabel as nib +import gzip +from tqdm import tqdm +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p as ensure_dir_exists, isfile +from yucca.paths import get_raw_data_path +from yucca.pipeline.task_conversion.utils import generate_dataset_json +from sklearn.model_selection import train_test_split +from yucca.functional.utils.nib_utils import get_nib_orientation, reorient_nib_image + + +def convert(path: str = "/home/zcr545/data/data/projects/PsyBrainPrediction", subdir: str = "ADNI"): + path = join(path, subdir) + + # Train and Test images are in the same folder + images_dir = join(path, "Images") + + # OUTPUT DATA + # Target names + task_name = "Task508_ADNI900_AFFSS_MRICOV" + prefix = "ADNI900_AFFSS_MRICOV" + + # Target paths + target_base = join(get_raw_data_path(), task_name) + + target_imagesTr = join(target_base, "imagesTr") + target_covariatesTr = join(target_base, "covariatesTr") + target_labelsTr = join(target_base, "labelsTr") + + target_imagesTs = join(target_base, "imagesTs") + target_covariatesTs = join(target_base, "covariatesTs") + target_labelsTs = join(target_base, "labelsTs") + + ensure_dir_exists(target_imagesTr) + ensure_dir_exists(target_covariatesTr) + ensure_dir_exists(target_labelsTs) + ensure_dir_exists(target_imagesTs) + ensure_dir_exists(target_covariatesTs) + ensure_dir_exists(target_labelsTr) + + # collect labels + AD_cases = pd.read_csv(join(path, "AD_group.csv")) + CN_cases = pd.read_csv(join(path, "CN_group.csv")) + MCI_cases = pd.read_csv(join(path, "MCI_group.csv")) + + # Check that no subjects appear twice in either group + assert len(set(AD_cases.Subject)) == 350 + assert len(set(CN_cases.Subject)) == 350 + assert len(set(MCI_cases.Subject)) == 350 + + # Check that no subjects appear in both groups + assert len(np.intersect1d(AD_cases.Subject, CN_cases.Subject)) == 0 + assert len(np.intersect1d(AD_cases.Subject, 
MCI_cases.Subject)) == 0 + assert len(np.intersect1d(CN_cases.Subject, MCI_cases.Subject)) == 0 + + train_AD, test_AD = train_test_split(AD_cases, random_state=418920) + train_CN, test_CN = train_test_split(CN_cases, random_state=537289) + train_MCI, test_MCI = train_test_split(MCI_cases, random_state=958013) + + all_train = pd.concat([train_AD, train_CN, train_MCI]) + all_test = pd.concat([test_AD, test_CN, test_MCI]) + + # Populate the training folders + for _, row in tqdm(all_train.iterrows(), total=len(all_train)): + subject = row["Subject"] + label = row["Group"] + age = row["Age"] + sex = row["Sex"] + + # Fix the labels + age = float(age / 122) # 122 = oldest human ever will have value 1.0 + + if sex == "F": + sex = 0 + elif sex == "M": + sex = 1 + else: + print("Found unexpected sex: '", sex, "'") + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found unexpected labels: '", label, "'") + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTr}/{prefix}_{subject}_000.nii.gz" + cov_save_path = f"{target_covariatesTr}/{prefix}_{subject}_COV.txt" + label_save_path = f"{target_labelsTr}/{prefix}_{subject}.txt" + + np.savetxt(cov_save_path, np.array([age, sex]), fmt="%s") + np.savetxt(label_save_path, np.array([label]), fmt="%s") + if not isfile(image_save_path): + shutil.copy2(image_path, image_save_path) + + # Populate the test folders + for _, row in tqdm(all_test.iterrows(), total=len(all_test)): + subject = row["Subject"] + label = row["Group"] + age = row["Age"] + sex = row["Sex"] + + # Fix the labels + age = float(age / 122) # 122 = oldest human ever will have value 1.0 + + if sex == "F": + sex = 0 + elif sex == "M": + sex = 1 + else: + print("Found unexpected sex: '", sex, "'") + + if label == "CN": + label = 0 + elif label == "AD": + label = 1 + elif label == "MCI": + label = 2 + else: + print("Found 
unexpected labels: ", label) + + image_path = join(path, "MNI/Data_Affine_SkullStripped", subject, "T1_mni_affine.nii.gz") + + image_save_path = f"{target_imagesTs}/{prefix}_{subject}_000.nii.gz" + cov_save_path = f"{target_covariatesTs}/{prefix}_{subject}_COV.txt" + label_save_path = f"{target_labelsTs}/{prefix}_{subject}.txt" + + np.savetxt(cov_save_path, np.array([age, sex]), fmt="%f") + np.savetxt(label_save_path, np.array([label]), fmt="%f") + if not isfile(image_save_path): + shutil.copy2(image_path, image_save_path) + + generate_dataset_json( + join(target_base, "dataset.json"), + target_imagesTr, + target_imagesTs, + ("T1",), + labels={0: "CN", 1: "AD", 2: "MCI"}, + im_ext="nii.gz", + dataset_name=task_name, + license="CC-BY 4.0", + dataset_description="ADNI", + dataset_reference="", + ) + + +if __name__ == "__main__": + convert() diff --git a/yucca/pipeline/task_conversion/utils.py b/yucca/pipeline/task_conversion/utils.py index c23d46aa..2f4175ab 100644 --- a/yucca/pipeline/task_conversion/utils.py +++ b/yucca/pipeline/task_conversion/utils.py @@ -65,6 +65,7 @@ def generate_dataset_json( labels: dict, dataset_name: str, label_hierarchy: dict = {}, + im_ext: str = None, regions: dict = {}, tasks: list = [], license: str = "hands off!", @@ -104,7 +105,10 @@ def generate_dataset_json( :return: """ first_file = files_in_dir(imagesTr_dir)[0] - im_ext = os.path.split(first_file)[-1].split(os.extsep, 1)[-1] + + if im_ext is None: + im_ext = os.path.split(first_file)[-1].split(os.extsep, 1)[-1] + train_identifiers = get_identifiers_from_splitted_files(imagesTr_dir, im_ext, tasks) if imagesTs_dir is not None: