diff --git a/BrainMaGe/utils/convert_ckpt_to_pt.py b/BrainMaGe/utils/convert_ckpt_to_pt.py index ec79cf3..80a6012 100755 --- a/BrainMaGe/utils/convert_ckpt_to_pt.py +++ b/BrainMaGe/utils/convert_ckpt_to_pt.py @@ -32,14 +32,16 @@ ckpt_file = os.path.abspath(args.input) pt_file = os.path.abspath(args.output) - print("Attempting to load file : ", ckpt_file) - weight_load = torch.load(ckpt_file) - print("Load Successful! Converting file.") + print("Attempting to load file:", ckpt_file) + with open(ckpt_file, "rb") as f: + weight_load = torch.load(f, map_location=torch.device("cpu")) + print("Load successful! Converting file.") new_state_dict = {} - for key in weight_load["state_dict"].keys(): + for key, value in weight_load["state_dict"].items(): new_key = key[6:] - new_state_dict[new_key] = weight_load["state_dict"][key] + new_state_dict[new_key] = value model_state_dict = {"model_state_dict": new_state_dict} print("Conversion successful!") - torch.save(model_state_dict, pt_file) - print("File saved successfully at :", pt_file) + with open(pt_file, "wb") as f: + torch.save(model_state_dict, f) + print("File saved successfully at:", pt_file) diff --git a/BrainMaGe/utils/csv_creator_adv.py b/BrainMaGe/utils/csv_creator_adv.py index 6438d91..b708600 100755 --- a/BrainMaGe/utils/csv_creator_adv.py +++ b/BrainMaGe/utils/csv_creator_adv.py @@ -14,217 +14,193 @@ from bids import BIDSLayout -def rex_o4a_csv(folder_path, to_save, ftype, modalities): - """[CSV generation for OneForAll] - [This function is used to generate a csv for OneForAll mode and creates a - csv] - Arguments: - folder_path {[string]} -- [Takes the folder to see where to look for - the different modaliies] - to_save {[string]} -- [Takes the folder as a string to save the csv] - ftype {[string]} -- [Are you trying to save train, validation or test, - if file type is set to test, it does not look for - ground truths] - modalities {[string]} -- [usually a string which looks like this - : ['t1', 't2', 't1ce']] +def rex_o4a_csv(folder_path, save_folder, file_type, modalities): + """ + CSV generation for OneForAll. + This function generates a csv for OneForAll mode and creates a csv. + + Args: + folder_path (str): The folder to see where to look for the different modalities + save_folder (str): The folder to save the csv + file_type (str): train, validation or test. If file_type is set to test, it does not look for ground truths. + modalities (list of str): The modalities to include in the csv. """ modalities = modalities[1:-1] modalities = re.findall("[^, ']+", modalities) if not modalities: - print( - "Could not find modalities! Are you sure you have put in \ - something in the modalities field?" - ) - sys.exit(0) - if ftype == "test": - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,Image_Path\n") + raise ValueError("Could not find modalities! Are you sure you have put in something in the modalities field?") + + if file_type == "test": + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,Image_Path\n") else: - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,gt_path,Image_path\n") + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,gt_path,Image_path\n") + folders = os.listdir(folder_path) for folder in folders: for modality in modalities: csv_file.write(folder + "_" + modality + ",") - if ftype != "test": - ground_truth = glob.glob( - os.path.join(folder_path, folder, "*mask.nii.gz") - )[0] - csv_file.write(ground_truth) + + if file_type != "test": + ground_truth_path = glob.glob(os.path.join(folder_path, folder, "*mask.nii.gz"))[0] + csv_file.write(ground_truth_path) csv_file.write(",") - img = glob.glob( - os.path.join(folder_path, folder, "*" + modality + ".nii.gz") - )[0] - csv_file.write(img) + + image_path = glob.glob(os.path.join(folder_path, folder, "*" + modality + ".nii.gz"))[0] + csv_file.write(image_path) csv_file.write("\n") - csv_file.close() + + print("CSV file saved successfully at:", csv_path) -def rex_sin_csv(folder_path, to_save, ftype, modalities): - """[CSV generation for Single Modalities] - [This function is used to generate a csv for Single mode and creates a csv] - Arguments: - folder_path {[string]} -- [Takes the folder to see where to look for - the different modaliies] - to_save {[string]} -- [Takes the folder as a string to save the csv] - ftype {[string]} -- [Are you trying to save train, validation or test, - if file type is set to test, it does not look for - ground truths] - modalities {[string]} -- [usually a string which looks like this - : ['t1']] +def rex_sin_csv(folder_path, save_folder, file_type, modalities): + """ + CSV generation for Single Modalities. + This function generates a csv for Single mode and creates a csv. + + Args: + folder_path (str): The folder to see where to look for the different modalities + save_folder (str): The folder to save the csv + file_type (str): train, validation or test. If file_type is set to test, it does not look for ground truths. + modalities (list of str): The modalities to include in the csv. """ modalities = modalities[1:-1] modalities = re.findall("[^, ']+", modalities) if len(modalities) > 1: - print("Found more than one modality, exiting!") - sys.exit(0) + raise ValueError("Found more than one modality, exiting!") if not modalities: - print( - "Could not find modalities! Are you sure you have put in \ - something in the modalities field?" - ) - sys.exit(0) - if ftype == "test": - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,") + raise ValueError("Could not find modalities! Are you sure you have put in something in the modalities field?") + + if file_type == "test": + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,") else: - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,gt_path,") + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,gt_path,") + modality = modalities[0] csv_file.write(modality + "_path\n") folders = os.listdir(folder_path) for folder in folders: csv_file.write(folder) csv_file.write(",") - if ftype != "test": - ground_truth = glob.glob(os.path.join(folder_path, folder, "*mask.nii.gz"))[ - 0 - ] - csv_file.write(ground_truth) + if file_type != "test": + ground_truth_path = glob.glob(os.path.join(folder_path, folder, "*mask.nii.gz"))[0] + csv_file.write(ground_truth_path) csv_file.write(",") - img = glob.glob(os.path.join(folder_path, folder, "*" + modality + ".nii.gz"))[ - 0 - ] - csv_file.write(img) + + image_path = glob.glob(os.path.join(folder_path, folder, "*" + modality + ".nii.gz"))[0] + csv_file.write(image_path) csv_file.write("\n") - csv_file.close() + + print("CSV file saved successfully at:", csv_path) -def rex_mul_csv(folder_path, to_save, ftype, modalities): - """[CSV generation for Multi Modalities] - [This function is used to generate a csv for multi mode and creates a csv] - Arguments: - folder_path {[string]} -- [Takes the folder to see where to look for - the different modaliies] - to_save {[string]} -- [Takes the folder as a string to save the csv] - ftype {[string]} -- [Are you trying to save train, validation or test, - if file type is set to test, it does not look for - ground truths] - modalities {[string]} -- [usually a string which looks like this - : ['t1']] +def rex_mul_csv(folder_path, save_folder, file_type, modalities): + """ + CSV generation for Multi Modalities. + This function generates a csv for multi mode and creates a csv. + + Args: + folder_path (str): The folder to see where to look for the different modalities + save_folder (str): The folder to save the csv + file_type (str): train, validation or test. If file_type is set to test, it does not look for ground truths. + modalities (list of str): The modalities to include in the csv. """ modalities = modalities[1:-1] modalities = re.findall("[^, ']+", modalities) if not modalities: - print( - "Could not find modalities! Are you sure you have put in \ - something in the modalities field?" - ) - sys.exit(0) - if ftype == "test": - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,") + raise ValueError("Could not find modalities! Are you sure you have put in something in the modalities field?") + + if file_type == "test": + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,") else: - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,gt_path,") - for modality in modalities[:-1]: - csv_file.write(modality + "_path,") - modality = modalities[-1] - csv_file.write(modality + "_path\n") + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,gt_path,") + + csv_file.write(",".join([f"{modality}_path" for modality in modalities])) + csv_file.write("\n") + folders = os.listdir(folder_path) for folder in folders: csv_file.write(folder) csv_file.write(",") - if ftype != "test": - ground_truth = glob.glob(os.path.join(folder_path, folder, "*mask.nii.gz"))[ - 0 - ] - csv_file.write(ground_truth) + + if file_type != "test": + ground_truth_path = glob.glob(os.path.join(folder_path, folder, "*mask.nii.gz"))[0] + csv_file.write(ground_truth_path) csv_file.write(",") + for modality in modalities[:-1]: - img = glob.glob( - os.path.join(folder_path, folder, "*" + modality + ".nii.gz") - )[0] - csv_file.write(img) + image_path = glob.glob(os.path.join(folder_path, folder, "*" + modality + ".nii.gz"))[0] + csv_file.write(image_path) csv_file.write(",") - modality = modalities[-1] - img = glob.glob(os.path.join(folder_path, folder, "*" + modality + ".nii.gz"))[ - 0 - ] - csv_file.write(img) + + image_path = glob.glob(os.path.join(folder_path, folder, "*" + modalities[-1] + ".nii.gz"))[0] + csv_file.write(image_path) csv_file.write("\n") - csv_file.close() + + print("CSV file saved successfully at:", csv_path) -def rex_bids_csv(folder_path, to_save, ftype): - """[CSV generation for BIDS datasets] - [This function is used to generate a csv for BIDS datasets] - Arguments: - folder_path {[string]} -- [Takes the folder to see where to look for - the different modaliies] - to_save {[string]} -- [Takes the folder as a string to save the csv] - ftype {[string]} -- [Are you trying to save train, validation or test, - if file type is set to test, it does not look for - ground truths] +def rex_bids_csv(folder_path, save_folder, file_type): + """ + CSV generation for BIDS datasets. + This function generates a csv for BIDS datasets. + + Args: + folder_path (str): The folder to see where to look for the different modalities + save_folder (str): The folder to save the csv + file_type (str): train, validation or test. If file_type is set to test, it does not look for ground truths. """ - if ftype == "test": - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,") + if file_type == "test": + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,") else: - csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") - csv_file.write("ID,gt_path,") - # load BIDS dataset into memory + csv_path = os.path.join(save_folder, file_type + ".csv") + with open(csv_path, "w+") as csv_file: + csv_file.write("ID,gt_path,") + layout = BIDSLayout(folder_path) - bids_df = layout.to_df() - bids_modality_df = { - "t1": bids_df[bids_df["suffix"] == "T1w"], - "t2": bids_df[bids_df["suffix"] == "T2w"], - "flair": bids_df[bids_df["suffix"] == "FLAIR"], - "t1ce": bids_df[bids_df["suffix"] == "T1CE"], - } - # check what modalities the dataset contains - modalities = [] - for modality, df in bids_modality_df.items(): - if not df.empty: - modalities.append(modality) - # write headers for those modalities - for modality in modalities[:-1]: - csv_file.write(modality + "_path,") - modality = modalities[-1] - csv_file.write(modality + "_path\n") - # write image paths for each subject + modalities = ["t1", "t2", "flair", "t1ce"] + modalities = [modality for modality in modalities if layout.get(suffix=f"{modality}w")] + + csv_file.write(",".join([f"{modality}_path" for modality in modalities])) + csv_file.write("\n") + for sub in layout.get_subjects(): csv_file.write(sub) csv_file.write(",") - if ftype != "test": - ground_truth = glob.glob(os.path.join(folder_path, sub, "*mask.nii.gz"))[0] - csv_file.write(ground_truth) + + if file_type != "test": + ground_truth_path = layout.get(subject=sub, suffix="mask")[0].filename + csv_file.write(ground_truth_path) csv_file.write(",") + for modality in modalities[:-1]: - img = bids_modality_df[modality][bids_df["subject"] == sub].path.values - csv_file.write(img[0]) + image_path = layout.get(subject=sub, suffix=f"{modality}w")[0].filename + csv_file.write(image_path) csv_file.write(",") - modality = modalities[-1] - img = bids_modality_df[modality][bids_df["subject"] == sub].path.values - csv_file.write(img[0]) + + image_path = layout.get(subject=sub, suffix=f"{modalities[-1]}w")[0].filename + csv_file.write(image_path) csv_file.write("\n") - csv_file.close() + + print("CSV file saved successfully at:", csv_path) def generate_csv(folder_path, to_save, mode, ftype, modalities): """[Function to generate CSV] - [This function takes a look at the data directory and the modes and - generates a csv] + [This function takes a look at the data directory and the modes and generates a csv] Arguments: folder_path {[strin]} -- [description] to_save {[strin]} -- [description] @@ -232,15 +208,16 @@ def generate_csv(folder_path, to_save, mode, ftype, modalities): ftype {[string]} -- [description] modalities {[string]} -- [description] """ - print("Generating ", ftype, ".csv", sep="") - if mode.lower() == "ma": - rex_o4a_csv(folder_path, to_save, ftype, modalities) - elif mode.lower() == "single": - rex_sin_csv(folder_path, to_save, ftype, modalities) - elif mode.lower() == "multi": - rex_mul_csv(folder_path, to_save, ftype, modalities) - elif mode.lower() == "bids": - rex_bids_csv(folder_path, to_save, ftype) + print(f"Generating {ftype}.csv") + modes = { + "ma": rex_o4a_csv, + "single": rex_sin_csv, + "multi": rex_mul_csv, + "bids": rex_bids_csv + } + func = modes.get(mode.lower()) + if func: + func(folder_path, to_save, ftype, modalities) else: print("Sorry, this mode is not supported") sys.exit(0) diff --git a/BrainMaGe/utils/optimizers.py b/BrainMaGe/utils/optimizers.py index 2bc12e1..bd672a4 100755 --- a/BrainMaGe/utils/optimizers.py +++ b/BrainMaGe/utils/optimizers.py @@ -7,35 +7,31 @@ """ + import torch.optim as optim import sys def fetch_optimizer(optimizer, lr, model): + # Mapping optimizer name to class + optimizer_map = { + "sgd": optim.SGD, + "adam": optim.Adam, + "rms": optim.RMSprop, + "adagrad": optim.Adagrad, + } + # Setting up the optimizer - if optimizer.lower() == "sgd": - optimizer = optim.SGD( - model.parameters(), lr=float(lr), momentum=0.9, nesterov=True - ) - elif optimizer.lower() == "adam": - optimizer = optim.Adam( - model.parameters(), lr=float(lr), betas=(0.9, 0.999), weight_decay=0.00005 - ) - elif optimizer.lower() == "rms": - optimizer = optim.RMSprop( - model.parameters(), lr=float(lr), momentum=0.9, weight_decay=0.00005 - ) - elif optimizer.lower() == "adagrad": - optimizer = optim.Adagrad( - model.parameters(), lr=float(lr), weight_decay=0.00005 - ) - else: - print( - "Sorry, {} is not supported or some sort of spell error. Please\ - choose from the given options!".format( - optimizer - ) - ) - sys.stdout.flush() - sys.exit(0) + optimizer_class = optimizer_map.get(optimizer.lower()) + if optimizer_class is None: + print(f"Sorry, {optimizer} is not supported. Please choose from the given options!") + sys.exit(1) + optimizer = optimizer_class( + model.parameters(), + lr=float(lr), + **({"momentum": 0.9, "nesterov": True} if optimizer.lower() == "sgd" else {}), + **({"betas": (0.9, 0.999), "weight_decay": 0.00005} if optimizer.lower() == "adam" else {}), + **({"momentum": 0.9, "weight_decay": 0.00005} if optimizer.lower() == "rms" else {}), + **({"weight_decay": 0.00005} if optimizer.lower() == "adagrad" else {}), + ) return optimizer diff --git a/BrainMaGe/utils/utils_test.py b/BrainMaGe/utils/utils_test.py index f6f07d9..6dd56c5 100755 --- a/BrainMaGe/utils/utils_test.py +++ b/BrainMaGe/utils/utils_test.py @@ -16,49 +16,24 @@ def pad_image(image): Parameters ---------- image : ndarray - DESCRIPTION. + Input image as a numpy array. Returns ------- - TYPE - padded image and its information + tuple + Padded image and its padding information as a tuple. """ - padded_image = image - pad_x1, pad_x2, pad_y1, pad_y2, pad_z1, pad_z2 = 0, 0, 0, 0, 0, 0 - # Padding on X axes - if image.shape[0] <= 240: - pad_x1 = (240 - image.shape[0]) // 2 - pad_x2 = 240 - image.shape[0] - pad_x1 - padded_image = np.pad( - padded_image, - ((pad_x1, pad_x2), (0, 0), (0, 0)), - mode="constant", - constant_values=0, - ) - - # Padding on Y axes - if image.shape[1] <= 240: - pad_y1 = (240 - image.shape[1]) // 2 - pad_y2 = 240 - image.shape[1] - pad_y1 - padded_image = np.pad( - padded_image, - ((0, 0), (pad_y1, pad_y2), (0, 0)), - mode="constant", - constant_values=0, - ) - - # Padding on Z axes - if image.shape[2] <= 160: - pad_z2 = 160 - image.shape[2] - padded_image = np.pad( - padded_image, - ((0, 0), (0, 0), (pad_z2, 0)), - mode="constant", - constant_values=0, - ) - - return padded_image, ((pad_x1, pad_x2), (pad_y1, pad_y2), (pad_z1, pad_z2)) + pad_info = ((0, 0), (0, 0), (0, 0)) + for i, dim in enumerate(image.shape): + if dim < (240 if i != 2 else 160): + pad_size = (240 if i != 2 else 160) - dim + pad_before = pad_size // 2 + pad_after = pad_size - pad_before + pad_info[i] = (pad_before, pad_after) + padded_image = np.pad( + image, pad_info, mode="constant", constant_values=0) + return padded_image, pad_info def process_image(image): @@ -68,41 +43,32 @@ def process_image(image): Parameters ---------- - image : TYPE - DESCRIPTION. + image : ndarray + Input image as a numpy array. Returns ------- - image : TYPE - DESCRIPTION. + ndarray + Preprocessed image as a numpy array. """ - to_return = image - new_image_temp = image[image >= image.mean()] - p1 = np.percentile(new_image_temp, 2) - p2 = np.percentile(new_image_temp, 95) - to_return[to_return > p2] = p2 - to_return = (to_return - p1) / p2 - return to_return + p1, p2 = np.percentile(image[image >= image.mean()], [2, 95]) + image = np.clip(image, None, p2) + image = (image - p1) / p2 + return image def padder_and_cropper(image, pad_info): (pad_x1, pad_x2), (pad_y1, pad_y2), (pad_z1, pad_z2) = pad_info - if pad_x2 == 0: - pad_x2 = -image.shape[0] - if pad_y2 == 0: - pad_y2 = -image.shape[1] - if pad_z2 == 0: - pad_z2 = -image.shape[2] - image = image[pad_x1:-pad_x2, pad_y1:-pad_y2, pad_z2:] - return image + x_start, x_end = pad_x1, image.shape[0] - pad_x2 if pad_x2 != 0 else None + y_start, y_end = pad_y1, image.shape[1] - pad_y2 if pad_y2 != 0 else None + z_start, z_end = pad_z1, image.shape[2] - pad_z2 if pad_z2 != 0 else None + return image[x_start:x_end, y_start:y_end, z_start:z_end] def unpad_image(image): - image = image[:, :, :155] - return image + return image[:, :, :155] def interpolate_image(image, output_shape): - new_image = resize(image, (output_shape), order=3, mode="edge", cval=0) - return new_image + return resize(image, output_shape, order=3, mode="edge", cval=0)