From f44c54f9664695c66d7cc3f4d4ad4ac6fc42b32d Mon Sep 17 00:00:00 2001
From: Nathan Molinier
Date: Wed, 11 Oct 2023 17:48:51 -0400
Subject: [PATCH 01/20] Add script to generate a config file

---
 scripts/init_data_config.py | 86 +++++++++++++++++++++++++++++++++++++
 1 file changed, 86 insertions(+)
 create mode 100644 scripts/init_data_config.py

diff --git a/scripts/init_data_config.py b/scripts/init_data_config.py
new file mode 100644
index 0000000..43013f8
--- /dev/null
+++ b/scripts/init_data_config.py
@@ -0,0 +1,86 @@
+import os
+import argparse
+import random
+import json
+import itertools
+
+from utils import CONTRAST, get_img_path_from_label_path, fetch_contrast
+
+
+CONTRAST_LOOKUP = {tuple(sorted(value)): key for key, value in CONTRAST.items()}
+
+
+# Determine specified contrasts
+def init_data_config(args):
+    """
+    Create a JSON configuration file from a TXT file where image paths are specified
+    """
+    if (args.split_validation + args.split_test) >= 1:
+        raise ValueError("The sum of the validation and testing split ratios must be lower than 1")
+
+    # Get input paths, could be label files or image files,
+    # and make sure they all exist.
+    file_paths = [os.path.abspath(path.replace('\n', '')) for path in open(args.txt)]
+    if args.type == 'LABEL':
+        label_paths = file_paths
+        img_paths = [get_img_path_from_label_path(lp) for lp in label_paths]
+        file_paths = label_paths + img_paths
+    elif args.type == 'IMAGE':
+        img_paths = file_paths
+    else:
+        raise ValueError(f"invalid args.type: {args.type}")
+    missing_paths = [
+        path for path in file_paths
+        if not os.path.isfile(path)
+    ]
+    if missing_paths:
+        raise ValueError("missing files:\n" + '\n'.join(missing_paths))
+
+    # Look up the right code for the set of contrasts present
+    contrasts = CONTRAST_LOOKUP[tuple(sorted(set(map(fetch_contrast, img_paths))))]
+
+    config = {
+        'TYPE': args.type,
+        'CONTRASTS': contrasts,
+    }
+
+    # Split into training, validation, and testing sets
+    split_ratio = (1 - (args.split_validation + args.split_test), args.split_validation, args.split_test) # TRAIN, VALIDATION, and TEST
+    config_paths = label_paths if args.type == 'LABEL' else img_paths
+    random.shuffle(config_paths)
+    splits = [0] + [
+        int(len(config_paths) * ratio)
+        for ratio in itertools.accumulate(split_ratio)
+    ]
+    for key, (begin, end) in zip(
+        ['TRAINING', 'VALIDATION', 'TESTING'],
+        pairwise(splits),
+    ):
+        config[key] = config_paths[begin:end]
+
+    # Save the config
+    config_path = args.txt.replace('.txt', '') + '.json'
+    json.dump(config, open(config_path, 'w'), indent=4)
+
+def pairwise(iterable):
+    # pairwise('ABCDEFG') --> AB BC CD DE EF FG
+    # based on https://docs.python.org/3.11/library/itertools.html
+    a, b = itertools.tee(iterable)
+    next(b, None)
+    return zip(a, b)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Create config JSON from a TXT file which contains a list of paths')
+
+    ## Parameters
+    parser.add_argument('--txt', required=True,
+                        help='Path to TXT file that contains only image or label paths. (Required)')
+    parser.add_argument('--type', required=True, choices=('LABEL', 'IMAGE'),
+                        help='Type of paths specified. Choices "LABEL" or "IMAGE". (Required)')
+    parser.add_argument('--split-validation', type=float, default=0.1,
+                        help='Split ratio for validation. Default=0.1')
+    parser.add_argument('--split-test', type=float, default=0.1,
+                        help='Split ratio for testing. Default=0.1')
+
+    init_data_config(parser.parse_args())
From 643f4539caf3de8b817844de2fb91e7c3c07aeae Mon Sep 17 00:00:00 2001
From: Nathan Molinier
Date: Wed, 11 Oct 2023 17:50:19 -0400
Subject: [PATCH 02/20] Add script to analyze a config file

---
 scripts/analyze_data.py | 63 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)
 create mode 100644 scripts/analyze_data.py

diff --git a/scripts/analyze_data.py b/scripts/analyze_data.py
new file mode 100644
index 0000000..9e21c08
--- /dev/null
+++ b/scripts/analyze_data.py
@@ -0,0 +1,63 @@
+import os
+import argparse
+import json
+
+from utils import get_img_path_from_mask_path, edit_metric_dict, save_graphs
+
+def run_analysis(args):
+    """
+    Run analysis on a config file
+    """
+
+    disc_label_suffix = '_labels-disc-manual'
+
+    # Check if labels are specified
+    if config_data['TYPE'] != 'LABEL':
+        raise ValueError("Pease specify l")
+
+    # Read json file and create a dictionary
+    with open(args.config, "r") as file:
+        config_data = json.load(file)
+
+    # Check analysis split
+    if args.split == 'ALL':
+        data_split = ['TRAINING', 'VALIDATION', 'TESTING']
+    elif args.split in ['TRAINING', 'VALIDATION', 'TESTING']:
+        data_split = [args.split]
+    else:
+        raise ValueError(f"Invalid args.split: {args.split}")
+
+    # Initialize metrics dictionary
+    metrics_dict = dict()
+
+    # Extract information from the data
+    for split in data_split:
+        metrics_dict[split] = dict()
+        if config_data[split]:
+            for path in config_data[split]:
+                img_path = get_img_path_from_mask_path(path)
+                mask_path = path
+
+                # Extract data
+                metrics_dict[split] = edit_metric_dict(metrics_dict[split], img_path, mask_path, disc_label_suffix=disc_label_suffix)
+
+    # Plot data informations
+    save_graphs(output_folder='results', metrics_dict=metrics_dict)
+
+
+
+
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Analyse config file')
+
+    ## Parameters
+    parser.add_argument('--config', required=True,
+                        help='Path to JSON config file that contains all the training splits (Required)')
+    parser.add_argument('--split', default='ALL', choices=('TRAINING', 'VALIDATION', 'TESTING', 'ALL'),
+                        help='Split of the data that will be analysed (default="ALL")')
+
+    # Start analysis
+    run_analysis(parser.parse_args())
\ No newline at end of file

From db5461dbe4a20dd288bb37e37cc8003114718a4a Mon Sep 17 00:00:00 2001
From: Nathan Molinier
Date: Wed, 11 Oct 2023 17:50:40 -0400
Subject: [PATCH 03/20] Add utils functions

---
 scripts/utils.py | 246 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 246 insertions(+)
 create mode 100644 scripts/utils.py

diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000..4b6f964
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,246 @@
+import os
+import re
+from pathlib import Path
+import pandas as pd
+import seaborn as sns
+import numpy as np
+
+from utils import fetch_contrast
+from image import Image
+
+## Global variables
+CONTRAST = {'t1': ['T1w'],
+            't2': ['T2w'],
+            't2s':['T2star'],
+            't1_t2': ['T1w', 'T2w']}
+
+## Functions
+def get_img_path_from_mask_path(str_path):
+    """
+    This function does 2 things: ⚠️ Files need to be stored in a BIDS compliant dataset
+    - Step 1: Remove label suffix (e.g. "_labels-disc-manual"). The suffix is always between the MRI contrast and the file extension.
+    - Step 2: Remove derivatives path (e.g. derivatives/labels/). The first folder is always called derivatives but the second may vary (e.g. 
labels_soft) + + :param path: absolute path to the label img. Example: //derivatives/labels/sub-amuALT/anat/sub-amuALT_T1w_labels-disc-manual.nii.gz + :return: img path. Example: //sub-amuALT/anat/sub-amuALT_T1w.nii.gz + Based on https://github.com/spinalcordtoolbox/disc-labeling-hourglass + """ + # Load path + path = Path(str_path) + + # Extract file extension + ext = ''.join(path.suffixes) + + # Get img name + img_name = '_'.join(path.name.split('_')[:-1]) + ext + + # Create a list of the directories + dir_list = str(path.parent).split('/') + + # Remove "derivatives" and "labels" folders + derivatives_idx = dir_list.index('derivatives') + dir_path = '/'.join(dir_list[0:derivatives_idx] + dir_list[derivatives_idx+2:]) + + # Recreate img path + img_path = os.path.join(dir_path, img_name) + + return img_path + +## +def get_mask_path_from_img_path(img_path, suffix='_seg', derivatives_path='/derivatives/labels'): + """ + This function returns the mask path from an image path. Images need to be stored in a BIDS compliant dataset. + + :param img_path: String path to niftii image + :param suffix: Mask suffix + :param derivatives_path: Relative path to derivatives folder where labels are stored (e.i. '/derivatives/labels') + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + """ + # Extract information from path + subjectID, sessionID, filename, contrast, echoID = fetch_subject_and_session(img_path) + + # Extract file extension + path_obj = Path(img_path) + ext = ''.join(path_obj.suffixes) + + # Create mask name + mask_name = path_obj.name.split('.')[0] + suffix + ext + + # Split path using "/" (TODO: check if it works for windows users) + path_list = img_path.split('/') + + # Extract subject folder index + sub_folder_idx = path_list.index(subjectID) + + # Reconstruct mask_path + mask_path = os.path.join('/'.join(path_list[:sub_folder_idx]), derivatives_path, path_list[sub_folder_idx:-1], mask_name) + return mask_path + +## +def change_mask_suffix(mask_path, new_suffix='_seg'): + """ + This function replace the current suffix with a new suffix suffix. If path is specified, make sure the dataset is BIDS compliant. 
+ + :param mask_path: Input mask filepath or filename + :param new_suffix: New mask suffix + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + """ + + # Extract file extension + ext = ''.join(Path(mask_path).suffixes) + + # Change mask path + new_mask_path = '_'.join(mask_path.split('_')[:-1]) + new_suffix + ext + return new_mask_path + +## +def fetch_subject_and_session(filename_path): + """ + Get subject ID, session ID and filename from the input BIDS-compatible filename or file path + The function works both on absolute file path as well as filename + :param filename_path: input nifti filename (e.g., sub-001_ses-01_T1w.nii.gz) or file path + (e.g., /home/user/MRI/bids/derivatives/labels/sub-001/ses-01/anat/sub-001_ses-01_T1w.nii.gz + :return: subjectID: subject ID (e.g., sub-001) + :return: sessionID: session ID (e.g., ses-01) + :return: filename: nii filename (e.g., sub-001_ses-01_T1w.nii.gz) + :return: contrast: MRI modality (dwi or anat) + :return: echoID: echo ID (e.g., echo-1) + :return: acquisition: acquisition (e.g., acq_sag) + Based on https://github.com/spinalcordtoolbox/manual-correction + """ + + _, filename = os.path.split(filename_path) # Get just the filename (i.e., remove the path) + subject = re.search('sub-(.*?)[_/]', filename_path) # [_/] means either underscore or slash + subjectID = subject.group(0)[:-1] if subject else "" # [:-1] removes the last underscore or slash + + session = re.search('ses-(.*?)[_/]', filename_path) # [_/] means either underscore or slash + sessionID = session.group(0)[:-1] if session else "" # [:-1] removes the last underscore or slash + + echo = re.search('echo-(.*?)[_]', filename_path) # [_/] means either underscore or slash + echoID = echo.group(0)[:-1] if echo else "" # [:-1] removes the last underscore or slash + + acq = re.search('acq-(.*?)[_]', filename_path) # [_/] means either underscore or slash + acquisition = acq.group(0)[:-1] if acq else "" # [:-1] removes the last underscore or slash + # REGEX explanation + # . - match any character (except newline) + # *? - match the previous element as few times as possible (zero or more times) + + contrast = 'dwi' if 'dwi' in filename_path else 'anat' # Return contrast (dwi or anat) + + return subjectID, sessionID, filename, contrast, echoID, acquisition + + +def fetch_contrast(filename_path): + ''' + Extract MRI contrast from a BIDS-compatible filename/filepath + The function handles images only. + :param filename_path: image file path or file name. (e.g sub-001_ses-01_T1w.nii.gz) + Copied from https://github.com/spinalcordtoolbox/disc-labeling-hourglass + ''' + return filename_path.rstrip(''.join(Path(filename_path).suffixes)).split('_')[-1] + + +def edit_metric_dict(metrics_dict, img_path, mask_path, disc_label_suffix='_labels-disc-manual'): + ''' + This function extracts information and metadata from an image and its mask. Values are then + gathered inside a dictionary. 
+ + :param metrics_dict: dictionary where information will be gathered + :param img_path: niftii image path + :param discs_mask_path: corresponding niftii discs mask path + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + #-----------------------------------------------------------------------# + #----------------------- Extracting metadata ---------------------------# + #-----------------------------------------------------------------------# + + # Extract field of view information thanks to discs labels + if '_labels-disc' in mask_path: + discs_mask_path = mask_path + else: + discs_mask_path = change_mask_suffix(mask_path, new_suffix=disc_label_suffix) + + if os.path.exists(discs_mask_path): + discs_mask = Image(discs_mask_path) + disc_list = [list(coord)[-1] for coord in discs_mask.getNonZeroCoordinates(sorting='value')] + else: + disc_list = [] + + # Extract original image orientation + img = Image(img_path) + orientation = img.get_orientation + + # Extract image dimensions and resolutions + img_RSP = img.change_orientation("RSP") + nx, ny, nz, nt, px, py, pz, pt = img_RSP.get_dimension + + # Check for shape mismatch between mask and image + if img.data.shape != Image(mask_path).data.shape: + shape_mismatch = True + else: + shape_mismatch = False + + # Extract MRI contrast + contrast = fetch_contrast(img_path) + + #-----------------------------------------------------------------------# + #--------------------- Adding metadata to dict -------------------------# + #-----------------------------------------------------------------------# + list_of_metrics = [img_path, orientation, contrast, disc_list, shape_mismatch, nx, ny, nz, nt, px, py, pz, pt] + list_of_keys = ['img_path', 'orientation', 'contrast', 'disc_list', 'shape_mismatch', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] + for key, metric in zip(list_of_keys, list_of_metrics): + if type(metric) != list(): + metric = [metric] + if key not in metrics_dict.keys(): + metrics_dict[key] = metric + else: + metrics_dict[key] += metric + + return metrics_dict + + +def save_violin(splits, values, output_path, y_axis): + ''' + Create a violin plot + :param splits: String list of the split name + :param values: Values associated with the split + :param output_path: Path to output folder where figures will be stored + :param y_axis: y-axis name + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + + # Set position of bar on X axis + result_dict = {} + for i, split in enumerate(splits): + result_dict[split]=values[i] + result_df = pd.DataFrame(data=result_dict) + + # Make the plot + plot = sns.violinplot(data=result_df) + plot.set(xlabel='split', ylabel=y_axis) + plot.set(title=y_axis) + + # Save plot + plot.figure.savefig(output_path) + + +def save_graphs(output_folder, metrics_dict): + ''' + Plot and save metrics into an output folder + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + # Extract subjects and metrics + splits = np.array(list(metrics_dict.keys())) + metrics_names = list(metrics_dict[splits[0]].keys()) + + # Use violin plots + for metric in ['disc_list', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt']: + out_path = os.path.join(output_folder, f'{metric}.png') + save_violin(splits=splits, values=[metrics_dict[split][metric] for split in splits], output_path=out_path, y_axis=metric) + + # Use bar graphs + + From 6a5fa59df2f31556bed372e1d1683553ca32a84c Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 30 Oct 2023 17:54:39 -0400 
Subject: [PATCH 04/20] analyse data implementation --- scripts/analyze_data.py | 76 +++++++++++------- scripts/utils.py | 167 +++++++++++++++++++++++++++++++--------- 2 files changed, 180 insertions(+), 63 deletions(-) diff --git a/scripts/analyze_data.py b/scripts/analyze_data.py index 9e21c08..1cee9f2 100644 --- a/scripts/analyze_data.py +++ b/scripts/analyze_data.py @@ -1,8 +1,10 @@ import os import argparse import json +from bids import BIDSLayout +import glob -from utils import get_img_path_from_mask_path, edit_metric_dict, save_graphs +from utils import get_img_path_from_mask_path, edit_metric_dict, save_graphs, change_mask_suffix def run_analysis(args): """ @@ -10,39 +12,57 @@ def run_analysis(args): """ disc_label_suffix = '_labels-disc-manual' + short_suffix_label = '_label' - # Check if labels are specified - if config_data['TYPE'] != 'LABEL': - raise ValueError("Pease specify l") + if args.config: + data_type = 'split' + # Read json file and create a dictionary + with open(args.config, "r") as file: + config_data = json.load(file) - # Read json file and create a dictionary - with open(args.config, "r") as file: - config_data = json.load(file) - - # Check analysis split - if args.split == 'ALL': - data_split = ['TRAINING', 'VALIDATION', 'TESTING'] - elif args.split in ['TRAINING', 'VALIDATION', 'TESTING']: - data_split = [args.split] + # Check if labels are specified + if config_data['TYPE'] != 'LABEL': + raise ValueError("Pease specify LABEL paths in config") + + elif args.paths_to_bids: + data_type = 'dataset' + # layout = BIDSLayout(args.paths_to_bids, derivatives=["derivatives/"]) + # tasks = layout.get_tasks() + # layout.get(scope='derivatives', return_type='file') + config_data = {} + for path_bids in args.paths_to_bids: + config_data[os.path.basename(path_bids)] = glob.glob(path_bids + "/**/*" + short_suffix_label + "*.nii.gz", recursive=True) else: - raise ValueError(f"Invalid args.split: {args.split}") + raise ValueError(f"Need to specify either args.paths_to_bids or args.config !") + - # Initialize metrics dictionary + # Initialize metrics dictionary`` metrics_dict = dict() + missing_data = [] # Extract information from the data - for split in data_split: - metrics_dict[split] = dict() - if config_data[split]: - for path in config_data[split]: - img_path = get_img_path_from_mask_path(path) - mask_path = path - - # Extract data - metrics_dict[split] = edit_metric_dict(metrics_dict[split], img_path, mask_path, disc_label_suffix=disc_label_suffix) + for key in config_data.keys(): + metrics_dict[key] = dict() + for path in config_data[key]: + img_path = get_img_path_from_mask_path(path) + mask_path = path + + # Extract field of view information thanks to discs labels + if short_suffix_label in mask_path: + discs_mask_path = mask_path + else: + discs_mask_path = change_mask_suffix(mask_path, new_suffix=disc_label_suffix) + + # Extract data + if os.path.exists(img_path) and os.path.exists(mask_path) and os.path.exists(discs_mask_path): + metrics_dict[key] = edit_metric_dict(metrics_dict[key], img_path, mask_path, discs_mask_path) + else: + missing_data.append(img_path) + + print("missing files:\n" + '\n'.join(missing_data)) # Plot data informations - save_graphs(output_folder='results', metrics_dict=metrics_dict) + save_graphs(output_folder='results', metrics_dict=metrics_dict, data_type=data_type) @@ -54,8 +74,10 @@ def run_analysis(args): parser = argparse.ArgumentParser(description='Analyse config file') ## Parameters - parser.add_argument('--config', required=True, - 
help='Path to JSON config file that contains all the training splits (Required)') + parser.add_argument('--paths-to-bids', default='', nargs='+', + help='Paths to BIDS compliant datasets') + parser.add_argument('--config', default='', + help='Path to JSON config file that contains all the training splits') parser.add_argument('--split', default='ALL', choices=('TRAINING', 'VALIDATION', 'TESTING', 'ALL'), help='Split of the data that will be analysed (default="ALL")') diff --git a/scripts/utils.py b/scripts/utils.py index 4b6f964..8a04628 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -3,9 +3,9 @@ from pathlib import Path import pandas as pd import seaborn as sns +import matplotlib.pyplot as plt import numpy as np -from utils import fetch_contrast from image import Image ## Global variables @@ -140,7 +140,7 @@ def fetch_contrast(filename_path): return filename_path.rstrip(''.join(Path(filename_path).suffixes)).split('_')[-1] -def edit_metric_dict(metrics_dict, img_path, mask_path, disc_label_suffix='_labels-disc-manual'): +def edit_metric_dict(metrics_dict, img_path, mask_path, discs_mask_path): ''' This function extracts information and metadata from an image and its mask. Values are then gathered inside a dictionary. @@ -155,25 +155,19 @@ def edit_metric_dict(metrics_dict, img_path, mask_path, disc_label_suffix='_labe #----------------------- Extracting metadata ---------------------------# #-----------------------------------------------------------------------# - # Extract field of view information thanks to discs labels - if '_labels-disc' in mask_path: - discs_mask_path = mask_path - else: - discs_mask_path = change_mask_suffix(mask_path, new_suffix=disc_label_suffix) - - if os.path.exists(discs_mask_path): + if os.path.exists(discs_mask_path): # TODO: deal with datasets with no discs labels discs_mask = Image(discs_mask_path) - disc_list = [list(coord)[-1] for coord in discs_mask.getNonZeroCoordinates(sorting='value')] + discs_labels = [list(coord)[-1] for coord in discs_mask.getNonZeroCoordinates(sorting='value')] else: - disc_list = [] + discs_labels = [] # Extract original image orientation img = Image(img_path) - orientation = img.get_orientation + orientation = img.orientation # Extract image dimensions and resolutions - img_RSP = img.change_orientation("RSP") - nx, ny, nz, nt, px, py, pz, pt = img_RSP.get_dimension + img_RPI = img.change_orientation("RPI") + nx, ny, nz, nt, px, py, pz, pt = img_RPI.dim # Check for shape mismatch between mask and image if img.data.shape != Image(mask_path).data.shape: @@ -187,10 +181,10 @@ def edit_metric_dict(metrics_dict, img_path, mask_path, disc_label_suffix='_labe #-----------------------------------------------------------------------# #--------------------- Adding metadata to dict -------------------------# #-----------------------------------------------------------------------# - list_of_metrics = [img_path, orientation, contrast, disc_list, shape_mismatch, nx, ny, nz, nt, px, py, pz, pt] - list_of_keys = ['img_path', 'orientation', 'contrast', 'disc_list', 'shape_mismatch', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] + list_of_metrics = [img_path, orientation, contrast, discs_labels, shape_mismatch, nx, ny, nz, nt, px, py, pz, pt] + list_of_keys = ['img_path', 'orientation', 'contrast', 'discs_labels', 'shape_mismatch', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] for key, metric in zip(list_of_keys, list_of_metrics): - if type(metric) != list(): + if not isinstance(metric, list): metric = [metric] if key not in 
metrics_dict.keys(): metrics_dict[key] = metric @@ -200,47 +194,148 @@ def edit_metric_dict(metrics_dict, img_path, mask_path, disc_label_suffix='_labe return metrics_dict -def save_violin(splits, values, output_path, y_axis): +def save_violin(names, values, output_path, x_axis, y_axis): ''' Create a violin plot - :param splits: String list of the split name - :param values: Values associated with the split - :param output_path: Path to output folder where figures will be stored + :param names: String list of the names + :param values: Values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name :param y_axis: y-axis name Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark ''' # Set position of bar on X axis - result_dict = {} - for i, split in enumerate(splits): - result_dict[split]=values[i] + result_dict = {'names':[], 'values':[]} + for i, name in enumerate(names): + result_dict['values'] += values[i] + for j in range(len(values[i])): + result_dict['names'] += [name] + + result_df = pd.DataFrame(data=result_dict) + + # Make the plot + plt.figure() + sns.violinplot(x="names", y="values", data=result_df) + plt.xlabel(x_axis, fontsize = 15) + plt.ylabel(y_axis, fontsize = 15) + plt.title(y_axis, fontsize = 20) + + # Save plot + plt.savefig(output_path) + + +def save_hist(names, values, output_path, x_axis, y_axis): + ''' + Create a histogram plot + :param names: String list of the names + :param values: Values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name + :param y_axis: y-axis name + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + + # Set position of bar on X axis + result_dict = {'names':[], 'values':[]} + for i, name in enumerate(names): + result_dict['values'] += values[i] + for j in range(len(values[i])): + result_dict['names'] += [name] + result_df = pd.DataFrame(data=result_dict) - # Make the plot - plot = sns.violinplot(data=result_df) - plot.set(xlabel='split', ylabel=y_axis) - plot.set(title=y_axis) + # Make the plot + plt.figure() + sns.histplot(data=result_df, x="values", hue="names", multiple="dodge", binwidth=1/len(names)) + plt.xlabel(x_axis, fontsize = 15) + plt.xticks(np.arange(1, np.max(result_dict['values'])+1)) + plt.ylabel(x_axis, fontsize = 15) + plt.title(y_axis, fontsize = 20) # Save plot - plot.figure.savefig(output_path) + plt.savefig(output_path) -def save_graphs(output_folder, metrics_dict): +def save_pie(names, values, output_path, x_axis, y_axis): + ''' + Create a pie chart plot + :param names: String list of the names + :param values: Values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name + :param y_axis: y-axis name + + Based on https://www.geeksforgeeks.org/how-to-create-a-pie-chart-in-seaborn/ + ''' + # Set position of bar on X axis + result_dict = {} + for i, name in enumerate(names): + result_dict[name] = {} + for val in values[i]: + if val not in result_dict[name].keys(): + result_dict[name][val] = 1 + else: + result_dict[name][val] += 1 + + # define Seaborn color palette to use + palette_color = sns.color_palette('bright') + + def autopct_format(values): + ''' + Based on https://stackoverflow.com/questions/53782591/how-to-display-actual-values-instead-of-percentages-on-my-pie-chart-using-matplo + ''' + def my_format(pct): + total = sum(values) + val = int(round(pct*total)/100) + return '{v:d}'.format(v=val) + return my_format + + # Make the plot + if 
len(names) == 1: + fig = plt.figure() + plt.pie(result_dict[names[0]].values(), labels=result_dict[names[0]].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[0]].values())) + plt.title(y_axis, fontsize = 20) + plt.xlabel(x_axis, fontsize = 15) + plt.ylabel(y_axis, fontsize = 15) + else: + fig, axs = plt.subplots(1, len(names), figsize=(3*len(names),5)) + fig.suptitle(y_axis) + + for j, name in enumerate(result_dict.keys()): + axs[j].pie(result_dict[name].values(), labels=result_dict[name].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[0]].values())) + axs[j].set_title(name) + + for ax, name in zip(axs.flat, names): + ax.set(xlabel=name, ylabel=y_axis) + + # Save plot + plt.savefig(output_path) + + +def save_graphs(output_folder, metrics_dict, data_type='split'): ''' Plot and save metrics into an output folder Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark ''' # Extract subjects and metrics - splits = np.array(list(metrics_dict.keys())) - metrics_names = list(metrics_dict[splits[0]].keys()) + data_name = np.array(list(metrics_dict.keys())) + metrics_names = list(metrics_dict[data_name[0]].keys()) # Use violin plots - for metric in ['disc_list', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt']: + for metric in ['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt']: out_path = os.path.join(output_folder, f'{metric}.png') - save_violin(splits=splits, values=[metrics_dict[split][metric] for split in splits], output_path=out_path, y_axis=metric) - - # Use bar graphs + save_violin(names=data_name, values=[metrics_dict[name][metric] for name in data_name], output_path=out_path, x_axis=data_type, y_axis=metric) + # Use bar pie chart + for metric in ['orientation', 'contrast']: + out_path = os.path.join(output_folder, f'{metric}.png') + save_pie(names=data_name, values=[metrics_dict[name][metric] for name in data_name], output_path=out_path, x_axis=data_type, y_axis=metric) + # Use bar graphs + for metric in ['discs_labels']: + out_path = os.path.join(output_folder, f'{metric}.png') + save_hist(names=data_name, values=[metrics_dict[name][metric] for name in data_name], output_path=out_path, x_axis=metric, y_axis='Count') From 110dbb7fd65a7d7ac7296ac251dbf2b3728350f3 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Wed, 8 Nov 2023 21:30:37 -0500 Subject: [PATCH 05/20] improve analyze_data --- scripts/analyze_data.py | 174 +++++++++++++++++++++++++++++++--------- 1 file changed, 134 insertions(+), 40 deletions(-) diff --git a/scripts/analyze_data.py b/scripts/analyze_data.py index 1cee9f2..eda8dfc 100644 --- a/scripts/analyze_data.py +++ b/scripts/analyze_data.py @@ -1,68 +1,158 @@ import os import argparse import json -from bids import BIDSLayout import glob +from progress.bar import Bar +import csv -from utils import get_img_path_from_mask_path, edit_metric_dict, save_graphs, change_mask_suffix +from utils import get_img_path_from_mask_path, get_mask_path_from_img_path, edit_metric_dict, save_graphs, change_mask_suffix, get_deriv_sub_from_img_path, str_to_float_list, str_to_str_list, mergedict def run_analysis(args): """ Run analysis on a config file """ - disc_label_suffix = '_labels-disc-manual' - short_suffix_label = '_label' + short_suffix_disc = '_label' + short_suffix_seg = '_seg' + derivatives_folder = 'derivatives' + output_folder = 'results' + + if not os.path.exists(output_folder): + os.makedirs(output_folder) if args.config: - data_type = 'split' + data_form = 'split' # Read json file and create a dictionary with 
open(args.config, "r") as file: config_data = json.load(file) - - # Check if labels are specified - if config_data['TYPE'] != 'LABEL': - raise ValueError("Pease specify LABEL paths in config") + + if config_data['TYPE'] == 'LABEL': + isImage = False + elif config_data['TYPE'] == 'IMAGE': + isImage = True + else: + raise ValueError(f'config with unknown TYPE {config_data['TYPE']}') + + # Remove keys that are not lists of paths + keys = list(config_data.keys()) + for key in keys: + if key not in ['TRAINING', 'VALIDATION', 'TESTING']: + del config_data[key] elif args.paths_to_bids: - data_type = 'dataset' - # layout = BIDSLayout(args.paths_to_bids, derivatives=["derivatives/"]) - # tasks = layout.get_tasks() - # layout.get(scope='derivatives', return_type='file') + data_form = 'dataset' config_data = {} for path_bids in args.paths_to_bids: - config_data[os.path.basename(path_bids)] = glob.glob(path_bids + "/**/*" + short_suffix_label + "*.nii.gz", recursive=True) + files = glob.glob(path_bids + "/**/" + "*.nii.gz", recursive=True) # Get all niftii files + config_data[os.path.basename(os.path.normpath(path_bids))] = [f for f in files if derivatives_folder not in f] # Remove masks from derivatives folder + isImage = True + + elif args.paths_to_csv: + data_form = 'dataset' + config_data = {} else: - raise ValueError(f"Need to specify either args.paths_to_bids or args.config !") + raise ValueError(f"Need to specify either args.paths_to_bids, args.config or args.paths_to_csv !") - - # Initialize metrics dictionary`` + # Initialize metrics dictionary metrics_dict = dict() - missing_data = [] - # Extract information from the data - for key in config_data.keys(): - metrics_dict[key] = dict() - for path in config_data[key]: - img_path = get_img_path_from_mask_path(path) - mask_path = path - - # Extract field of view information thanks to discs labels - if short_suffix_label in mask_path: - discs_mask_path = mask_path - else: - discs_mask_path = change_mask_suffix(mask_path, new_suffix=disc_label_suffix) - - # Extract data - if os.path.exists(img_path) and os.path.exists(mask_path) and os.path.exists(discs_mask_path): - metrics_dict[key] = edit_metric_dict(metrics_dict[key], img_path, mask_path, discs_mask_path) - else: - missing_data.append(img_path) - - print("missing files:\n" + '\n'.join(missing_data)) + if args.paths_to_csv: + for path_csv in args.paths_to_csv: + dataset_name = os.path.basename(path_csv).split('_')[-1].split('.csv')[0] + metrics_dict[dataset_name] = {} + with open(path_csv) as csv_file: + reader = csv.reader(csv_file) + for k, v in dict(reader).items(): + metric = k.split('_') + if len(metric) == 2: + metric_name, metric_value = metric + if metric_name not in metrics_dict[dataset_name].keys(): + metrics_dict[dataset_name][metric_name] = {metric_value:int(v)} + else: + metrics_dict[dataset_name][metric_name][metric_value] = int(v) + else: + if k.startswith('mismatch'): + metrics_dict[dataset_name][k] = int(v) + else: + metrics_dict[dataset_name][k] = str_to_str_list(v) + + # Initialize data finguerprint + fprint_dict = dict() + + if config_data.keys(): + missing_data = [] + # Extract information from the data + for key in config_data.keys(): + metrics_dict[key] = dict() + fprint_dict[key] = dict() + + # Init progression bar + bar = Bar(f'Analyze data {key} ', max=len(config_data[key])) + + for path in config_data[key]: + if isImage: + img_path = path # str + deriv_sub_folders = get_deriv_sub_from_img_path(img_path=img_path, derivatives_folder=derivatives_folder) # list of str + 
seg_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_seg, deriv_sub_folders=deriv_sub_folders) # list of str + discs_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_disc, deriv_sub_folders=deriv_sub_folders, counterexample=['compression', 'SC_mask']) # list of str + else: + img_path = get_img_path_from_mask_path(path, derivatives_folder=derivatives_folder) + deriv_sub_folders = [os.path.dirname(path)] + # Extract field of view information thanks to discs labels + if short_suffix_disc in path: + discs_paths = [path] + seg_paths = [change_mask_suffix(discs_paths, short_suffix=short_suffix_seg)] + elif short_suffix_seg in path: + seg_paths = [path] + discs_paths = [change_mask_suffix(seg_paths, short_suffix=short_suffix_disc)] + else: + seg_paths = [change_mask_suffix(path, short_suffix=short_suffix_seg)] + discs_paths = [change_mask_suffix(path, short_suffix=short_suffix_disc)] + + # Extract data + if os.path.exists(img_path): + metrics_dict[key], fprint_dict[key] = edit_metric_dict(metrics_dict[key], fprint_dict[key], img_path, seg_paths, discs_paths, deriv_sub_folders) + else: + missing_data.append(img_path) + + # Plot progress + bar.suffix = f'{config_data[key].index(path)+1}/{len(config_data[key])}' + bar.next() + bar.finish() + + # Store csv with computed metrics + if args.create_csv: + # Based on https://stackoverflow.com/questions/8685809/writing-a-dictionary-to-a-csv-file-with-one-line-for-every-key-value + out_csv_folder = os.path.join(output_folder, 'files') + if not os.path.exists(out_csv_folder): + os.makedirs(out_csv_folder) + csv_path_sum = os.path.join(out_csv_folder, f'computed_metrics_{key}.csv') + with open(csv_path_sum, 'w') as csv_file: + writer = csv.writer(csv_file) + for metric_name, metric in sorted(metrics_dict[key].items()): + if isinstance(metric,dict): + for metric_value, count in sorted(metric.items()): + k = f'{metric_name}_{metric_value}' + writer.writerow([k, count]) + else: + writer.writerow([metric_name, metric]) + + # Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + csv_path_fprint = os.path.join(out_csv_folder, f'fprint_{key}.csv') + sub_list = [sub for sub in fprint_dict[key].keys() if sub.startswith('sub')] + fields = ['subject'] + [k for k in fprint_dict[key][sub_list[0]].keys()] + with open(csv_path_fprint, 'w') as f: + w = csv.DictWriter(f, fields) + w.writeheader() + for k, v in fprint_dict[key].items(): + w.writerow(mergedict({'subject': k},v)) + + + if missing_data: + print("missing files:\n" + '\n'.join(missing_data)) # Plot data informations - save_graphs(output_folder='results', metrics_dict=metrics_dict, data_type=data_type) + save_graphs(output_folder=output_folder, metrics_dict=metrics_dict, data_form=data_form) @@ -75,11 +165,15 @@ def run_analysis(args): ## Parameters parser.add_argument('--paths-to-bids', default='', nargs='+', - help='Paths to BIDS compliant datasets') + help='Paths to BIDS compliant datasets (You can add multiple paths using spaces)') parser.add_argument('--config', default='', help='Path to JSON config file that contains all the training splits') + parser.add_argument('--paths-to-csv', default='', nargs='+', + help='Paths to csv files with already computed metrics (You can add multiple paths using spaces)') parser.add_argument('--split', default='ALL', choices=('TRAINING', 'VALIDATION', 'TESTING', 'ALL'), help='Split of the data that will be analysed (default="ALL")') + parser.add_argument('--create-csv', default=True, + help='Store computed metrics 
using a csv file in results/files (default=True)') # Start analysis run_analysis(parser.parse_args()) \ No newline at end of file From 75503ead5a87070d025327c3091807b39a9c6374 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Wed, 8 Nov 2023 21:30:57 -0500 Subject: [PATCH 06/20] Add utils functions --- scripts/utils.py | 319 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 255 insertions(+), 64 deletions(-) diff --git a/scripts/utils.py b/scripts/utils.py index 8a04628..879639b 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -5,6 +5,7 @@ import seaborn as sns import matplotlib.pyplot as plt import numpy as np +import glob from image import Image @@ -15,7 +16,7 @@ 't1_t2': ['T1w', 'T2w']} ## Functions -def get_img_path_from_mask_path(str_path): +def get_img_path_from_mask_path(str_path, derivatives_folder='derivatives'): """ This function does 2 things: ⚠️ Files need to be stored in a BIDS compliant dataset - Step 1: Remove label suffix (e.g. "_labels-disc-manual"). The suffix is always between the MRI contrast and the file extension. @@ -38,7 +39,7 @@ def get_img_path_from_mask_path(str_path): dir_list = str(path.parent).split('/') # Remove "derivatives" and "labels" folders - derivatives_idx = dir_list.index('derivatives') + derivatives_idx = dir_list.index(derivatives_folder) dir_path = '/'.join(dir_list[0:derivatives_idx] + dir_list[derivatives_idx+2:]) # Recreate img path @@ -47,37 +48,60 @@ def get_img_path_from_mask_path(str_path): return img_path ## -def get_mask_path_from_img_path(img_path, suffix='_seg', derivatives_path='/derivatives/labels'): +def get_mask_path_from_img_path(img_path, deriv_sub_folders, short_suffix='_seg', ext='.nii.gz', counterexample=[]): """ - This function returns the mask path from an image path. Images need to be stored in a BIDS compliant dataset. + This function returns the mask path from an image path or an empty string if the path does not exists. Images need to be stored in a BIDS compliant dataset. :param img_path: String path to niftii image :param suffix: Mask suffix - :param derivatives_path: Relative path to derivatives folder where labels are stored (e.i. '/derivatives/labels') + :param ext: File extension + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark """ # Extract information from path - subjectID, sessionID, filename, contrast, echoID = fetch_subject_and_session(img_path) + subjectID, sessionID, filename, contrast, echoID, acq = fetch_subject_and_session(img_path) + + # Find corresponding mask + mask_path = [] + for deriv_path in deriv_sub_folders: + if counterexample: # Deal with counter examples + paths = [] + for path in glob.glob(deriv_path + filename.split(ext)[0] + short_suffix + "*" + ext): + iswrong = False + for c in counterexample: + if c in path: + iswrong = True + if not iswrong: + paths.append(path) + else: + paths = glob.glob(deriv_path + filename.split(ext)[0] + short_suffix + "*" + ext) - # Extract file extension - path_obj = Path(img_path) - ext = ''.join(path_obj.suffixes) + if len(paths) > 1: + print(f'Image {img_path} has multiple masks\n: {'\n'.join(paths)}') + elif len(paths) == 1: + mask_path.append(paths[0]) + return mask_path - # Create mask name - mask_name = path_obj.name.split('.')[0] + suffix + ext +def get_deriv_sub_from_img_path(img_path, derivatives_folder='derivatives'): + """ + This function returns the derivatives path of the subject from an image path or an empty string if the path does not exists. Images need to be stored in a BIDS compliant dataset. 
- # Split path using "/" (TODO: check if it works for windows users) - path_list = img_path.split('/') + :param img_path: String path to niftii image + :param derivatives_folder: List of derivatives paths + :param ext: File extension + """ + # Extract information from path + subjectID, sessionID, filename, contrast, echoID, acq = fetch_subject_and_session(img_path) + path_bids, path_sub_folder = img_path.split(subjectID)[0:-1] + path_sub_folder = subjectID + path_sub_folder - # Extract subject folder index - sub_folder_idx = path_list.index(subjectID) + # Find corresponding mask + deriv_sub_folder = glob.glob(path_bids + "**/" + derivatives_folder + "/**/" + path_sub_folder, recursive=True) - # Reconstruct mask_path - mask_path = os.path.join('/'.join(path_list[:sub_folder_idx]), derivatives_path, path_list[sub_folder_idx:-1], mask_name) - return mask_path + return deriv_sub_folder ## -def change_mask_suffix(mask_path, new_suffix='_seg'): +def change_mask_suffix(mask_path, short_suffix='_seg', ext='.nii.gz'): """ This function replace the current suffix with a new suffix suffix. If path is specified, make sure the dataset is BIDS compliant. @@ -85,14 +109,33 @@ def change_mask_suffix(mask_path, new_suffix='_seg'): :param new_suffix: New mask suffix Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark """ + # Extract information from path + subjectID, sessionID, filename, contrast, echoID, acq = fetch_subject_and_session(mask_path) + path_deriv_sub = mask_path.split(filename)[0] - # Extract file extension - ext = ''.join(Path(mask_path).suffixes) + # Find corresponding new_mask + new_mask_path = glob.glob(path_deriv_sub + '_'.join(filename.split('_')[:-1]) + short_suffix + "*" + ext) + + if len(new_mask_path) > 1: + print(f'Multiple {short_suffix} masks for subject {subjectID} \n: {'\n'.join(new_mask_path)}') + mask_path = '' + elif len(new_mask_path) == 1: + new_mask_path = new_mask_path[0] + else: # mask does not exist + new_mask_path = '' - # Change mask path - new_mask_path = '_'.join(mask_path.split('_')[:-1]) + new_suffix + ext return new_mask_path + +def list_suffixes(folder_path, ext='.nii.gz'): + """ + This function return all the labels suffixes. If path is specified, make sure the dataset is BIDS compliant. + + :param folder_path: Path to folder where labels are stored. + """ + files = [file for file in os.listdir(folder_path) if file.endswith(ext)] + suffixes = ['_'+file.split('_')[-1].split(ext)[0] for file in files] + return suffixes ## def fetch_subject_and_session(filename_path): """ @@ -139,59 +182,185 @@ def fetch_contrast(filename_path): ''' return filename_path.rstrip(''.join(Path(filename_path).suffixes)).split('_')[-1] +def str_to_str_list(string): + string = string[1:-1] # remove brackets + return [s[1:-1] for s in string.split(', ')] + +def str_to_float_list(string): + string = string[1:-1] # remove brackets + return [float(s) for s in string.split(', ')] -def edit_metric_dict(metrics_dict, img_path, mask_path, discs_mask_path): + +def edit_metric_dict(metrics_dict, fprint_dict, img_path, seg_paths, discs_paths, deriv_sub_folders): ''' This function extracts information and metadata from an image and its mask. Values are then gathered inside a dictionary. 
- :param metrics_dict: dictionary where information will be gathered + :param metrics_dict: dictionary containing summary metadata + :param fprint_dict: dictionary containing all the informations :param img_path: niftii image path - :param discs_mask_path: corresponding niftii discs mask path + :param seg_path: corresponding niftii spinal cord segmentation path + :param discs_path: corresponding niftii discs mask path Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark ''' #-----------------------------------------------------------------------# #----------------------- Extracting metadata ---------------------------# #-----------------------------------------------------------------------# - - if os.path.exists(discs_mask_path): # TODO: deal with datasets with no discs labels - discs_mask = Image(discs_mask_path) - discs_labels = [list(coord)[-1] for coord in discs_mask.getNonZeroCoordinates(sorting='value')] - else: - discs_labels = [] - # Extract original image orientation img = Image(img_path) orientation = img.orientation + # Extract information from path + subjectID, sessionID, filename, c, echoID, acq = fetch_subject_and_session(img_path) + # Extract image dimensions and resolutions img_RPI = img.change_orientation("RPI") nx, ny, nz, nt, px, py, pz, pt = img_RPI.dim - - # Check for shape mismatch between mask and image - if img.data.shape != Image(mask_path).data.shape: - shape_mismatch = True - else: - shape_mismatch = False + + # Extract discs + check for shape mismatch between discs labels and image + discs_labels = [] + count_discs = 0 + if discs_paths: + for path in discs_paths: + discs_mask = Image(path).change_orientation("RPI") + discs_labels += [list(coord)[-1] for coord in discs_mask.getNonZeroCoordinates(sorting='value')] + if img_RPI.data.shape != discs_mask.data.shape: + count_discs += 1 + + # Check for shape mismatch between segmentation and image + count_seg = 0 + if seg_paths: + for path in seg_paths: + if img_RPI.data.shape != Image(path).change_orientation("RPI").data.shape: + count_seg += 1 + + # Compute image size + X, Y, Z = nx*px, ny*py, nz*pz # Extract MRI contrast contrast = fetch_contrast(img_path) - #-----------------------------------------------------------------------# - #--------------------- Adding metadata to dict -------------------------# - #-----------------------------------------------------------------------# - list_of_metrics = [img_path, orientation, contrast, discs_labels, shape_mismatch, nx, ny, nz, nt, px, py, pz, pt] - list_of_keys = ['img_path', 'orientation', 'contrast', 'discs_labels', 'shape_mismatch', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] + # Extract suffixes + suffixes = [] + for path in deriv_sub_folders: + for suf in list_suffixes(path): + if not suf in suffixes: + suffixes.append(suf) + + # Extract derivatives folder + der_folders = [] + for path in deriv_sub_folders: + der_folders.append(os.path.basename(os.path.dirname(path.split(subjectID)[0]))) + + #-------------------------------------------------------------------------------# + #--------------------- Adding metadata to summary dict -------------------------# + #-------------------------------------------------------------------------------# + list_of_metrics = [orientation, contrast, X, Y, Z, nx, ny, nz, nt, px, py, pz, pt] + list_of_keys = ['orientation', 'contrast', 'X', 'Y', 'Z', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] for key, metric in zip(list_of_keys, list_of_metrics): - if not isinstance(metric, list): - metric = [metric] + if not 
isinstance(metric,str): + metric = str(metric) if key not in metrics_dict.keys(): - metrics_dict[key] = metric + metrics_dict[key] = {metric:1} else: - metrics_dict[key] += metric + if metric not in metrics_dict[key].keys(): + metrics_dict[key][metric] = 1 + else: + metrics_dict[key][metric] += 1 + + # Add count shape mismatch + key_mis_seg = 'mismatch-seg' + if key_mis_seg not in metrics_dict.keys(): + metrics_dict[key_mis_seg] = count_seg + else: + metrics_dict[key_mis_seg] += count_seg + + key_mis_disc = 'mismatch-disc' + if key_mis_disc not in metrics_dict.keys(): + metrics_dict[key_mis_disc] = count_discs + else: + metrics_dict[key_mis_disc] += count_discs + + # Add discs labels + key_discs = 'discs-labels' + if discs_labels: + if key_discs not in metrics_dict.keys(): + metrics_dict[key_discs] = {} + for disc in discs_labels: + disc = str(disc) + if disc not in metrics_dict[key_discs].keys(): + metrics_dict[key_discs][disc] = 1 + else: + metrics_dict[key_discs][disc] += 1 + + # Add suffixes + suf_key = 'suffixes' + if suf_key not in metrics_dict.keys(): + metrics_dict[suf_key] = suffixes + else: + for suf in suffixes: + if not suf in metrics_dict[suf_key]: + metrics_dict[suf_key].append(suf) + + # Add derivatives folders + der_key = 'derivatives' + if der_key not in metrics_dict.keys(): + metrics_dict[der_key] = der_folders + else: + for der in der_folders: + if not der in metrics_dict[der_key]: + metrics_dict[der_key].append(der) + + #--------------------------------------------------------------------------------# + #--------------------- Storing metadata to exhaustive dict -------------------------# + #--------------------------------------------------------------------------------# + fprint_dict[filename] = {} + + # Add contrast + fprint_dict[filename]['contrast'] = contrast + + # Add orientation + fprint_dict[filename]['img_orientation'] = orientation + + # Add info SC segmentations + if seg_paths: + fprint_dict[filename]['seg-sc'] = True + suf_seg = ['_' + path.split('_')[-1].split('.')[0] for path in seg_paths] + fprint_dict[filename]['seg-suffix'] = '/'.join(suf_seg) + fprint_dict[filename]['seg-mismatch'] = count_seg + else: + fprint_dict[filename]['seg-sc'] = False + fprint_dict[filename]['seg-suffix'] = '' + fprint_dict[filename]['seg-mismatch'] = count_seg - return metrics_dict + # Add info discs labels + if discs_paths: + fprint_dict[filename]['discs-label'] = True + suf_discs = ['_' + path.split('_')[-1].split('.')[0] for path in discs_paths] + fprint_dict[filename]['discs-suffix'] = '/'.join(suf_discs) + fprint_dict[filename]['discs-mismatch'] = count_discs + else: + fprint_dict[filename]['discs-label'] = False + fprint_dict[filename]['discs-suffix'] = '' + fprint_dict[filename]['discs-mismatch'] = count_discs + + # Add discs labels + key_discs = 'discs-labels' + label_list = np.arange(1,27).tolist() + [49, 50, 60] + for num_label in label_list: + if num_label in discs_labels: + fprint_dict[filename][f'label_{str(num_label)}'] = True + else: + fprint_dict[filename][f'label_{str(num_label)}'] = False + + # Add dim and resolutions + list_of_metrics = [X, Y, Z, nx, ny, nz, nt, px, py, pz, pt] + list_of_keys = ['X', 'Y', 'Z', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] + for key, metric in zip(list_of_keys, list_of_metrics): + fprint_dict[filename][key] = metric + + return metrics_dict, fprint_dict def save_violin(names, values, output_path, x_axis, y_axis): @@ -248,8 +417,10 @@ def save_hist(names, values, output_path, x_axis, y_axis): result_df = 
pd.DataFrame(data=result_dict) # Make the plot - plt.figure() - sns.histplot(data=result_df, x="values", hue="names", multiple="dodge", binwidth=1/len(names)) + binwidth= 1/(1*len(names)) if len(names) > 1 else 1/3 + shrink = 1 if len(names) > 1 else 0.7 + plt.figure(figsize=(np.max(result_dict['values']), 8)) + sns.histplot(data=result_df, x="values", hue="names", multiple="dodge", binwidth=binwidth, shrink=shrink) plt.xlabel(x_axis, fontsize = 15) plt.xticks(np.arange(1, np.max(result_dict['values'])+1)) plt.ylabel(x_axis, fontsize = 15) @@ -298,24 +469,40 @@ def my_format(pct): fig = plt.figure() plt.pie(result_dict[names[0]].values(), labels=result_dict[names[0]].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[0]].values())) plt.title(y_axis, fontsize = 20) - plt.xlabel(x_axis, fontsize = 15) + plt.xlabel(names[0], fontsize = 15) plt.ylabel(y_axis, fontsize = 15) else: fig, axs = plt.subplots(1, len(names), figsize=(3*len(names),5)) - fig.suptitle(y_axis) + fig.suptitle(y_axis, fontsize = 8*len(names)) for j, name in enumerate(result_dict.keys()): - axs[j].pie(result_dict[name].values(), labels=result_dict[name].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[0]].values())) + axs[j].pie(result_dict[name].values(), labels=result_dict[name].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[j]].values())) axs[j].set_title(name) - - for ax, name in zip(axs.flat, names): - ax.set(xlabel=name, ylabel=y_axis) + + axs[0].set(ylabel=y_axis) # Save plot plt.savefig(output_path) +def convert_dict_to_float_list(dic): + """ + This function converts dictionary with {str(value):int(nb_occurence)} to a list [float(value)]*nb_occurence + """ + out_list = [] + for value, count in dic.items(): + out_list += [float(value)]*count + return out_list -def save_graphs(output_folder, metrics_dict, data_type='split'): +def convert_dict_to_list(dic): + """ + This function converts dictionary with {str(value):int(nb_occurence)} to a list [str(value)]*nb_occurence + """ + out_list = [] + for value, count in dic.items(): + out_list += [value]*count + return out_list + +def save_graphs(output_folder, metrics_dict, data_form='split'): ''' Plot and save metrics into an output folder @@ -323,19 +510,23 @@ def save_graphs(output_folder, metrics_dict, data_type='split'): ''' # Extract subjects and metrics data_name = np.array(list(metrics_dict.keys())) - metrics_names = list(metrics_dict[data_name[0]].keys()) # Use violin plots - for metric in ['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt']: + for metric, unit in zip(['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt', 'X', 'Y', 'Z'], ['pixel', 'pixel', 'pixel', '', 'mm/pixel', 'mm/pixel', 'mm/pixel', '', 'mm', 'mm', 'mm']): out_path = os.path.join(output_folder, f'{metric}.png') - save_violin(names=data_name, values=[metrics_dict[name][metric] for name in data_name], output_path=out_path, x_axis=data_type, y_axis=metric) + metric_name = metric + ' ' + f'({unit})' + save_violin(names=data_name, values=[convert_dict_to_float_list(metrics_dict[name][metric]) for name in data_name], output_path=out_path, x_axis=data_form, y_axis=metric_name) # Use bar pie chart for metric in ['orientation', 'contrast']: out_path = os.path.join(output_folder, f'{metric}.png') - save_pie(names=data_name, values=[metrics_dict[name][metric] for name in data_name], output_path=out_path, x_axis=data_type, y_axis=metric) + save_pie(names=data_name, values=[convert_dict_to_list(metrics_dict[name][metric]) for name in data_name], 
output_path=out_path, x_axis=data_form, y_axis=metric) # Use bar graphs - for metric in ['discs_labels']: + for metric in ['discs-labels']: out_path = os.path.join(output_folder, f'{metric}.png') - save_hist(names=data_name, values=[metrics_dict[name][metric] for name in data_name], output_path=out_path, x_axis=metric, y_axis='Count') + save_hist(names=data_name, values=[convert_dict_to_float_list(metrics_dict[name][metric]) for name in data_name], output_path=out_path, x_axis=metric, y_axis='Count') + +def mergedict(a,b): + a.update(b) + return a \ No newline at end of file From d7f4917753473446c846c18e0784a4a47e7e5133 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Fri, 10 Nov 2023 18:19:59 -0500 Subject: [PATCH 07/20] add counterexample --- scripts/analyze_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/analyze_data.py b/scripts/analyze_data.py index eda8dfc..b93bef9 100644 --- a/scripts/analyze_data.py +++ b/scripts/analyze_data.py @@ -94,7 +94,7 @@ def run_analysis(args): img_path = path # str deriv_sub_folders = get_deriv_sub_from_img_path(img_path=img_path, derivatives_folder=derivatives_folder) # list of str seg_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_seg, deriv_sub_folders=deriv_sub_folders) # list of str - discs_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_disc, deriv_sub_folders=deriv_sub_folders, counterexample=['compression', 'SC_mask']) # list of str + discs_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_disc, deriv_sub_folders=deriv_sub_folders, counterexample=['compression', 'SC_mask', 'seg']) # list of str else: img_path = get_img_path_from_mask_path(path, derivatives_folder=derivatives_folder) deriv_sub_folders = [os.path.dirname(path)] From c60b973944771c67f57927dde555da6b506028e3 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Fri, 10 Nov 2023 18:21:10 -0500 Subject: [PATCH 08/20] improve robustness --- scripts/utils.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/scripts/utils.py b/scripts/utils.py index 879639b..520f1de 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -74,7 +74,7 @@ def get_mask_path_from_img_path(img_path, deriv_sub_folders, short_suffix='_seg' if not iswrong: paths.append(path) else: - paths = glob.glob(deriv_path + filename.split(ext)[0] + short_suffix + "*" + ext) + paths = glob.glob(deriv_path + filename.split(ext)[0] + "*" + short_suffix + "*" + ext) if len(paths) > 1: print(f'Image {img_path} has multiple masks\n: {'\n'.join(paths)}') @@ -127,14 +127,25 @@ def change_mask_suffix(mask_path, short_suffix='_seg', ext='.nii.gz'): return new_mask_path -def list_suffixes(folder_path, ext='.nii.gz'): +def list_der_suffixes(folder_path, ext='.nii.gz'): """ This function return all the labels suffixes. If path is specified, make sure the dataset is BIDS compliant. :param folder_path: Path to folder where labels are stored. 
""" + folder_path = os.path.normpath(folder_path) files = [file for file in os.listdir(folder_path) if file.endswith(ext)] - suffixes = ['_'+file.split('_')[-1].split(ext)[0] for file in files] + suffixes = [] + for file in files: + subjectID, sessionID, filename, contrast, echoID, acquisition = fetch_subject_and_session(file) + split_file = file.split(ext)[0].split('_') + skip_idx = 0 + for sp in [subjectID, sessionID, echoID, acquisition]: + if sp: + skip_idx = skip_idx + 1 + suffix = '_' + '_'.join(split_file[skip_idx+1:]) # +1 to skip contrast + if not suffix =='_': + suffixes.append(suffix) return suffixes ## def fetch_subject_and_session(filename_path): @@ -175,7 +186,7 @@ def fetch_subject_and_session(filename_path): def fetch_contrast(filename_path): ''' - Extract MRI contrast from a BIDS-compatible filename/filepath + Extract MRI contrast from a BIDS-compatible IMAGE filename/filepath The function handles images only. :param filename_path: image file path or file name. (e.g sub-001_ses-01_T1w.nii.gz) Copied from https://github.com/spinalcordtoolbox/disc-labeling-hourglass @@ -238,13 +249,13 @@ def edit_metric_dict(metrics_dict, fprint_dict, img_path, seg_paths, discs_paths # Compute image size X, Y, Z = nx*px, ny*py, nz*pz - # Extract MRI contrast + # Extract MRI contrast from image only contrast = fetch_contrast(img_path) # Extract suffixes suffixes = [] for path in deriv_sub_folders: - for suf in list_suffixes(path): + for suf in list_der_suffixes(path): if not suf in suffixes: suffixes.append(suf) @@ -326,7 +337,7 @@ def edit_metric_dict(metrics_dict, fprint_dict, img_path, seg_paths, discs_paths # Add info SC segmentations if seg_paths: fprint_dict[filename]['seg-sc'] = True - suf_seg = ['_' + path.split('_')[-1].split('.')[0] for path in seg_paths] + suf_seg = [path.split(contrast)[-1].split('.')[0] for path in seg_paths] fprint_dict[filename]['seg-suffix'] = '/'.join(suf_seg) fprint_dict[filename]['seg-mismatch'] = count_seg else: @@ -337,7 +348,7 @@ def edit_metric_dict(metrics_dict, fprint_dict, img_path, seg_paths, discs_paths # Add info discs labels if discs_paths: fprint_dict[filename]['discs-label'] = True - suf_discs = ['_' + path.split('_')[-1].split('.')[0] for path in discs_paths] + suf_discs = [path.split(contrast)[-1].split('.')[0] for path in discs_paths] fprint_dict[filename]['discs-suffix'] = '/'.join(suf_discs) fprint_dict[filename]['discs-mismatch'] = count_discs else: From 39589e2fb3ab096ed2417d775542e339499001d0 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 13 Nov 2023 19:53:29 -0500 Subject: [PATCH 09/20] add counterexamples --- scripts/analyze_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/analyze_data.py b/scripts/analyze_data.py index b93bef9..6a778f0 100644 --- a/scripts/analyze_data.py +++ b/scripts/analyze_data.py @@ -93,8 +93,8 @@ def run_analysis(args): if isImage: img_path = path # str deriv_sub_folders = get_deriv_sub_from_img_path(img_path=img_path, derivatives_folder=derivatives_folder) # list of str - seg_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_seg, deriv_sub_folders=deriv_sub_folders) # list of str - discs_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_disc, deriv_sub_folders=deriv_sub_folders, counterexample=['compression', 'SC_mask', 'seg']) # list of str + seg_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_seg, deriv_sub_folders=deriv_sub_folders, counterexample=['lesion', 'GM', 'WM']) # list of str + 
discs_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_disc, deriv_sub_folders=deriv_sub_folders, counterexample=['compression', 'SC_mask', 'seg', 'lesion', 'GM', 'WM']) # list of str else: img_path = get_img_path_from_mask_path(path, derivatives_folder=derivatives_folder) deriv_sub_folders = [os.path.dirname(path)] From c9bd95b885702aae712716ecc4c12d98aa296755 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 13 Nov 2023 19:54:44 -0500 Subject: [PATCH 10/20] Add group plot instead of individual plots --- scripts/utils.py | 52 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/scripts/utils.py b/scripts/utils.py index 520f1de..09d532c 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -406,6 +406,40 @@ def save_violin(names, values, output_path, x_axis, y_axis): plt.savefig(output_path) +def save_group_violins(name, values, output_path, x_axis, y_axis): + ''' + Create a violin plot + :param name: Dataset name + :param values: List of metrics containing lists of values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name + :param y_axis: List of y-axis name corresponding to each metrics + ''' + + # Create plot + fig, axs = plt.subplots(3, len(values)//3 + 1, figsize=(1.8*len(values),11)) + + fig.suptitle(f'{x_axis} : {name}', fontsize = 30) + + for idx_line, val in enumerate(values): + # Set position of bar on X axis + result_dict = {} + result_dict['values'] = val + result_dict['metrics'] = [y_axis[idx_line]]*len(val) + + result_df = pd.DataFrame(data=result_dict) + + # Make the plot + sns.violinplot(ax=axs[idx_line//4, idx_line%4], x="metrics", y="values", data=result_df) + axs[idx_line//4, idx_line%4].set(xticklabels=[]) + axs[idx_line//4, idx_line%4].set_ylabel("") + axs[idx_line//4, idx_line%4].set_xlabel("") + axs[idx_line//4, idx_line%4].set_title(y_axis[idx_line], fontsize=20) + + # Save plot + plt.savefig(output_path) + + def save_hist(names, values, output_path, x_axis, y_axis): ''' Create a histogram plot @@ -523,10 +557,20 @@ def save_graphs(output_folder, metrics_dict, data_form='split'): data_name = np.array(list(metrics_dict.keys())) # Use violin plots - for metric, unit in zip(['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt', 'X', 'Y', 'Z'], ['pixel', 'pixel', 'pixel', '', 'mm/pixel', 'mm/pixel', 'mm/pixel', '', 'mm', 'mm', 'mm']): - out_path = os.path.join(output_folder, f'{metric}.png') - metric_name = metric + ' ' + f'({unit})' - save_violin(names=data_name, values=[convert_dict_to_float_list(metrics_dict[name][metric]) for name in data_name], output_path=out_path, x_axis=data_form, y_axis=metric_name) + # for metric, unit in zip(['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt', 'X', 'Y', 'Z'], ['pixel', 'pixel', 'pixel', '', 'mm/pixel', 'mm/pixel', 'mm/pixel', '', 'mm', 'mm', 'mm']): + # out_path = os.path.join(output_folder, f'{metric}.png') + # metric_name = metric + ' ' + f'({unit})' + # save_violin(names=data_name, values=[convert_dict_to_float_list(metrics_dict[name][metric]) for name in data_name], output_path=out_path, x_axis=data_form, y_axis=metric_name) + + # Save violin plot in one fig + for name in data_name: + tot_values = [] + tot_names = [] + for metric, unit in zip(['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt', 'X', 'Y', 'Z'], ['pixel', 'pixel', 'pixel', '', 'mm/pixel', 'mm/pixel', 'mm/pixel', '', 'mm', 'mm', 'mm']): + tot_values.append(convert_dict_to_float_list(metrics_dict[name][metric])) + tot_names.append(metric 
+ ' ' + f'({unit})') + out_path = os.path.join(output_folder, f'violin_stats.png') + save_group_violins(name=name, values=tot_values, output_path=out_path, x_axis=data_form, y_axis=tot_names) # Use bar pie chart for metric in ['orientation', 'contrast']: From 4a14ae71f110ca4a115137d13fe1fa41be1d5b61 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 13 Nov 2023 19:55:21 -0500 Subject: [PATCH 11/20] deal with too small values in pie chart --- scripts/utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/scripts/utils.py b/scripts/utils.py index 09d532c..8d64e74 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -495,6 +495,18 @@ def save_pie(names, values, output_path, x_axis, y_axis): result_dict[name][val] = 1 else: result_dict[name][val] += 1 + # Regroup small values + other_count = 0 + other_name_list = [] + for v, count in result_dict[name].items(): + if count <= math.ceil(0.004*len(values[i])): + other_count += count + other_name_list.append(v) + for v in other_name_list: + del result_dict[name][v] + if other_name_list: + result_dict[name]['other'] = other_count + # define Seaborn color palette to use palette_color = sns.color_palette('bright') From 9c9e4673c060e13a680136659372c8d89aa1c06e Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 13 Nov 2023 19:55:34 -0500 Subject: [PATCH 12/20] quick fixes --- scripts/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/utils.py b/scripts/utils.py index 8d64e74..723ad41 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np import glob +import math from image import Image @@ -378,7 +379,7 @@ def save_violin(names, values, output_path, x_axis, y_axis): ''' Create a violin plot :param names: String list of the names - :param values: Values associated with the names + :param values: List of values associated with the names :param output_path: Output path (string) :param x_axis: x-axis name :param y_axis: y-axis name @@ -527,7 +528,7 @@ def my_format(pct): plt.pie(result_dict[names[0]].values(), labels=result_dict[names[0]].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[0]].values())) plt.title(y_axis, fontsize = 20) plt.xlabel(names[0], fontsize = 15) - plt.ylabel(y_axis, fontsize = 15) + #plt.ylabel(y_axis, fontsize = 15) else: fig, axs = plt.subplots(1, len(names), figsize=(3*len(names),5)) fig.suptitle(y_axis, fontsize = 8*len(names)) From a6f20ab0af963e5b71ae65d0560325625403504f Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Fri, 17 Nov 2023 12:52:54 -0500 Subject: [PATCH 13/20] Improve sharing --- scripts/init_data_config.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/scripts/init_data_config.py b/scripts/init_data_config.py index 43013f8..9126c21 100644 --- a/scripts/init_data_config.py +++ b/scripts/init_data_config.py @@ -1,10 +1,13 @@ +# Based on https://github.com/spinalcordtoolbox/disc-labeling-hourglass + import os import argparse import random import json import itertools +import numpy as np -from utils import CONTRAST, get_img_path_from_label_path, fetch_contrast +from utils import CONTRAST, get_img_path_from_mask_path, fetch_contrast CONTRAST_LOOKUP = {tuple(sorted(value)): key for key, value in CONTRAST.items()} @@ -15,7 +18,7 @@ def init_data_config(args): """ Create a JSON configuration file from a TXT file where images paths are specified """ - if (args.split_validation + args.split_test) >= 1: + if 
(args.split_validation + args.split_test) > 1: raise ValueError("The sum of the ratio between testing and validation cannot exceed 1") # Get input paths, could be label files or image files, @@ -23,7 +26,7 @@ def init_data_config(args): file_paths = [os.path.abspath(path.replace('\n', '')) for path in open(args.txt)] if args.type == 'LABEL': label_paths = file_paths - img_paths = [get_img_path_from_label_path(lp) for lp in label_paths] + img_paths = [get_img_path_from_mask_path(lp) for lp in label_paths] file_paths = label_paths + img_paths elif args.type == 'IMAGE': img_paths = file_paths @@ -33,8 +36,18 @@ def init_data_config(args): path for path in file_paths if not os.path.isfile(path) ] + if missing_paths: raise ValueError("missing files:\n" + '\n'.join(missing_paths)) + + # Extract BIDS parent folder path + dataset_parent_path_list = ['/'.join(path.split('/sub')[0].split('/')[:-1]) for path in img_paths] + + # Check if all the BIDS folders are stored inside the same parent repository + if (np.array(dataset_parent_path_list) == dataset_parent_path_list[0]).all(): + dataset_parent_path = dataset_parent_path_list[0] + else: + raise ValueError('Please store all the BIDS datasets inside the same parent folder !') # Look up the right code for the set of contrasts present contrasts = CONTRAST_LOOKUP[tuple(sorted(set(map(fetch_contrast, img_paths))))] @@ -42,11 +55,13 @@ def init_data_config(args): config = { 'TYPE': args.type, 'CONTRASTS': contrasts, + 'DATASETS_PATH': dataset_parent_path } # Split into training, validation, and testing sets split_ratio = (1 - (args.split_validation + args.split_test), args.split_validation, args.split_test) # TRAIN, VALIDATION, and TEST config_paths = label_paths if args.type == 'LABEL' else img_paths + config_paths = [path.split(dataset_parent_path + '/')[-1] for path in config_paths] # Remove DATASETS_PATH random.shuffle(config_paths) splits = [0] + [ int(len(config_paths) * ratio) @@ -83,4 +98,9 @@ def pairwise(iterable): parser.add_argument('--split-test', type=float, default=0.1, help='Split ratio for testing. 
Default=0.1') - init_data_config(parser.parse_args()) + args = parser.parse_args() + + if args.split_test > 0.9: + args.split_validation = 1 - args.split_test + + init_data_config(args) From 2190249ae57903edf089c090c60b0db869475c06 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Wed, 22 Nov 2023 14:56:22 -0500 Subject: [PATCH 14/20] add contrasts to dict --- scripts/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/utils.py b/scripts/utils.py index 723ad41..2367811 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -14,7 +14,12 @@ CONTRAST = {'t1': ['T1w'], 't2': ['T2w'], 't2s':['T2star'], - 't1_t2': ['T1w', 'T2w']} + 't1_t2': ['T1w', 'T2w'], + 'psir': ['PSIR'], + 'stir': ['STIR'], + 'psir_stir': ['PSIR', 'STIR'], + 't1_t2_psir_stir': ['T1w', 'T2w', 'PSIR', 'STIR'] + } ## Functions def get_img_path_from_mask_path(str_path, derivatives_folder='derivatives'): From 273fc47cb3f8a9cbcc12f18ba9a3f7fc7c02bc3c Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 11 Mar 2024 10:37:23 -0400 Subject: [PATCH 15/20] update init_data_config with contrast method --- scripts/init_data_config.py | 30 +++++++++++++++++++++--------- scripts/utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/scripts/init_data_config.py b/scripts/init_data_config.py index 9126c21..66eac79 100644 --- a/scripts/init_data_config.py +++ b/scripts/init_data_config.py @@ -1,4 +1,6 @@ -# Based on https://github.com/spinalcordtoolbox/disc-labeling-hourglass +""" +Script copied from https://github.com/spinalcordtoolbox/disc-labeling-hourglass +""" import os import argparse @@ -7,10 +9,7 @@ import itertools import numpy as np -from utils import CONTRAST, get_img_path_from_mask_path, fetch_contrast - - -CONTRAST_LOOKUP = {tuple(sorted(value)): key for key, value in CONTRAST.items()} +from utils import get_img_path_from_mask_path, get_cont_path_from_other_cont, fetch_contrast, fetch_subject_and_session # Determine specified contrasts @@ -26,10 +25,17 @@ def init_data_config(args): file_paths = [os.path.abspath(path.replace('\n', '')) for path in open(args.txt)] if args.type == 'LABEL': label_paths = file_paths - img_paths = [get_img_path_from_mask_path(lp) for lp in label_paths] + img_paths = [get_img_path_from_label_path(lp) for lp in label_paths] file_paths = label_paths + img_paths elif args.type == 'IMAGE': img_paths = file_paths + elif args.type == 'CONTRAST': + if not args.cont: # If the target contrast is not specified + raise ValueError(f'When using the type CONTRAST, please specify the target contrast using the flag "--cont"') + img_paths = file_paths + new_contrast = args.cont + label_paths = [get_cont_path_from_other_cont(ip) for ip in img_paths] + file_paths = label_paths + img_paths else: raise ValueError(f"invalid args.type: {args.type}") missing_paths = [ @@ -50,7 +56,7 @@ def init_data_config(args): raise ValueError('Please store all the BIDS datasets inside the same parent folder !') # Look up the right code for the set of contrasts present - contrasts = CONTRAST_LOOKUP[tuple(sorted(set(map(fetch_contrast, img_paths))))] + contrasts = "_".join(tuple(sorted(set(map(fetch_contrast, img_paths))))) config = { 'TYPE': args.type, @@ -58,6 +64,10 @@ def init_data_config(args): 'DATASETS_PATH': dataset_parent_path } + # Add target contrast when the type CONTRAST is used + if args.type == 'CONTRAST': + config['TARGET_CONTRAST'] = args.cont + # Split into training, validation, and testing sets split_ratio = (1 - 
(args.split_validation + args.split_test), args.split_validation, args.split_test) # TRAIN, VALIDATION, and TEST config_paths = label_paths if args.type == 'LABEL' else img_paths @@ -91,8 +101,10 @@ def pairwise(iterable): ## Parameters parser.add_argument('--txt', required=True, help='Path to TXT file that contains only image or label paths. (Required)') - parser.add_argument('--type', choices=('LABEL', 'IMAGE'), - help='Type of paths specified. Choices "LABEL" or "IMAGE". (Required)') + parser.add_argument('--type', choices=('LABEL', 'IMAGE', 'CONTRAST'), + help='Type of paths specified. Choices are "LABEL", "IMAGE" or "CONTRAST". (Required)') + parser.add_argument('--cont', type=str, default='', + help='If the type CONTRAST is selected, this variable specifies the wanted contrast for target.') parser.add_argument('--split-validation', type=float, default=0.1, help='Split ratio for validation. Default=0.1') parser.add_argument('--split-test', type=float, default=0.1, diff --git a/scripts/utils.py b/scripts/utils.py index 2367811..48d860a 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -88,6 +88,33 @@ def get_mask_path_from_img_path(img_path, deriv_sub_folders, short_suffix='_seg' mask_path.append(paths[0]) return mask_path + +def get_cont_path_from_other_cont(str_path, cont): + """ + :param str_path: absolute path to the input nifti img. Example: //sub-amuALT/anat/sub-amuALT_T1w.nii.gz + :param cont: contrast of the target output image stored in the same data folder. Example: T2w + :return: path to the output target image. Example: //sub-amuALT/anat/sub-amuALT_T2w.nii.gz + + """ + # Load path + path = Path(str_path) + + # Extract file extension + ext = ''.join(path.suffixes) + + # Remove input contrast from name + path_list = path.name.split('_') + suffixes_pos = [1 if len(part.split('-')) == 1 else 0 for part in path_list] + contrast_idx = suffixes_pos.index(1) # Find suffix + + # New image name + img_name = '_'.join(path_list[:contrast_idx]+[cont]) + ext + + # Recreate img path + img_path = os.path.join(str(path.parent), img_name) + + return img_path + def get_deriv_sub_from_img_path(img_path, derivatives_folder='derivatives'): """ This function returns the derivatives path of the subject from an image path or an empty string if the path does not exists. Images need to be stored in a BIDS compliant dataset. 
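For readers following the new CONTRAST workflow, the get_cont_path_from_other_cont helper added in PATCH 15 simply swaps the BIDS contrast token inside the filename. The snippet below is a minimal, self-contained sketch of that token-swap idea, not the patched function itself; the subject name, folder layout, and target contrast are illustrative only:

    from pathlib import Path
    import os

    def swap_contrast(img_path, target_contrast):
        # The contrast/suffix is assumed to be the only underscore-separated
        # part of a BIDS filename without a 'key-value' dash (e.g. 'T1w').
        path = Path(img_path)
        ext = ''.join(path.suffixes)                       # '.nii.gz'
        parts = path.name[:-len(ext)].split('_')           # ['sub-amuALT', 'T1w']
        contrast_idx = next(i for i, p in enumerate(parts) if '-' not in p)
        parts[contrast_idx] = target_contrast
        return os.path.join(str(path.parent), '_'.join(parts) + ext)

    print(swap_contrast('sub-amuALT/anat/sub-amuALT_T1w.nii.gz', 'T2w'))
    # -> sub-amuALT/anat/sub-amuALT_T2w.nii.gz

A session-scoped filename such as sub-001_ses-01_T1w.nii.gz behaves the same way, because 'ses-01' contains a dash and is skipped when looking for the contrast token.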
From 6d23b64f3357408b7aaf5135df7914c8fda2c34b Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 11 Mar 2024 10:37:48 -0400 Subject: [PATCH 16/20] update with csv --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1ea5fee..d61db89 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__ .idea/ venv/ +.csv \ No newline at end of file From b689e0b7e536909e0e1e12a0b9cc01fdb25c97ec Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 11 Mar 2024 10:38:59 -0400 Subject: [PATCH 17/20] update with output folders --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index d61db89..47f14f7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ __pycache__ .idea/ venv/ -.csv \ No newline at end of file +.csv +.vscode/ +results/ \ No newline at end of file From b093d015e1e15023c510fd02eb263853c5c7acf3 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 11 Mar 2024 14:12:27 -0400 Subject: [PATCH 18/20] update function name --- scripts/init_data_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/init_data_config.py b/scripts/init_data_config.py index 66eac79..1ce8915 100644 --- a/scripts/init_data_config.py +++ b/scripts/init_data_config.py @@ -25,7 +25,7 @@ def init_data_config(args): file_paths = [os.path.abspath(path.replace('\n', '')) for path in open(args.txt)] if args.type == 'LABEL': label_paths = file_paths - img_paths = [get_img_path_from_label_path(lp) for lp in label_paths] + img_paths = [get_img_path_from_mask_path(lp) for lp in label_paths] file_paths = label_paths + img_paths elif args.type == 'IMAGE': img_paths = file_paths From 446750b64489586468a859ae3822288b0d911590 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Mon, 11 Mar 2024 14:44:38 -0400 Subject: [PATCH 19/20] change quotes --- scripts/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/utils.py b/scripts/utils.py index 48d860a..c37dffa 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -83,7 +83,7 @@ def get_mask_path_from_img_path(img_path, deriv_sub_folders, short_suffix='_seg' paths = glob.glob(deriv_path + filename.split(ext)[0] + "*" + short_suffix + "*" + ext) if len(paths) > 1: - print(f'Image {img_path} has multiple masks\n: {'\n'.join(paths)}') + print(f'Image {img_path} has multiple masks\n: {"\n".join(paths)}') elif len(paths) == 1: mask_path.append(paths[0]) return mask_path From 898a94e15aeb9bb11d785cd22dec10a014b4e663 Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Wed, 17 Apr 2024 09:58:55 -0400 Subject: [PATCH 20/20] add comments --- scripts/analyze_data.py | 4 ++++ scripts/init_data_config.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/scripts/analyze_data.py b/scripts/analyze_data.py index 6a778f0..de35480 100644 --- a/scripts/analyze_data.py +++ b/scripts/analyze_data.py @@ -1,3 +1,7 @@ +''' +This script loops on a config file (see init_data_config.py) to calculate metrics (.csv) and generate plots. +''' + import os import argparse import json diff --git a/scripts/init_data_config.py b/scripts/init_data_config.py index 1ce8915..96fe5d4 100644 --- a/scripts/init_data_config.py +++ b/scripts/init_data_config.py @@ -1,4 +1,6 @@ """ +Generate a config file with all the paths to the used files. +See https://github.com/spinalcordtoolbox/disc-labeling-hourglass/issues/25#issuecomment-1695818382 Script copied from https://github.com/spinalcordtoolbox/disc-labeling-hourglass """
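As a closing illustration of the path handling introduced in PATCH 13 ("Improve sharing") and kept through PATCH 15, the snippet below shows how the common BIDS parent folder is derived and how the stored paths become relative to it; the absolute path and subject used here are hypothetical:

    # Sketch of the DATASETS_PATH handling in init_data_config.py;
    # '/home/user/data' and 'sub-001' are made-up examples.
    img_path = '/home/user/data/my-dataset/sub-001/anat/sub-001_T2w.nii.gz'

    # Everything before the first '/sub' is the dataset folder; dropping its
    # last component gives the parent stored under config['DATASETS_PATH'].
    dataset_parent = '/'.join(img_path.split('/sub')[0].split('/')[:-1])
    print(dataset_parent)                                  # /home/user/data

    # Paths written to the TRAINING/VALIDATION/TESTING lists are then relative:
    print(img_path.split(dataset_parent + '/')[-1])
    # -> my-dataset/sub-001/anat/sub-001_T2w.nii.gz

Keeping only relative paths is what makes the generated JSON config shareable across machines, provided all BIDS datasets sit under one parent folder, which is exactly the condition the ValueError in PATCH 13 enforces.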