diff --git a/dataset_conversion/convert_bids_to_nnUNetv2_multi-channel.py b/dataset_conversion/convert_bids_to_nnUNetv2_multi-channel.py new file mode 100644 index 0000000..bd973c8 --- /dev/null +++ b/dataset_conversion/convert_bids_to_nnUNetv2_multi-channel.py @@ -0,0 +1,418 @@ +""" +Convert BIDS-structured datasets (dcm-zurich-lesions, dcm-zurich-lesions-20231115) to the nnUNetv2 dataset MULTI-CHANNEL +format. + +dataset.json: + +```json + "channel_names": { + "0": "acq-ax_T2w", + "1": "SC_seg" + }, + "labels": { + "background": 0, + "lesion": 1 + }, +``` + +Full details about the format can be found here: +https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md + +The script to be used on a single dataset or multiple datasets. + +Note: the script performs RPI reorientation of the images and labels + +Usage example multiple datasets: + python convert_bids_to_nnUNetv2_multi-channel.py + --path-data ~/data/dcm-zurich-lesions ~/data/dcm-zurich-lesions-20231115 + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + +Usage example single dataset: + python convert_bids_to_nnUNetv2_multi-channel.py + --path-data ~/data/dcm-zurich-lesions + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + +Authors: Jan Valosek, Naga Karthik +""" + +import argparse +from pathlib import Path +import json +import os +import re +import shutil +import yaml +from collections import OrderedDict +from loguru import logger +from sklearn.model_selection import train_test_split +from utils import create_multi_channel_label, get_git_branch_and_commit, Image +from tqdm import tqdm + +import nibabel as nib + + +def get_parser(): + # parse command line arguments + parser = argparse.ArgumentParser(description='Convert BIDS-structured dataset to nnUNetV2 MULTI-CHANNEL format.') + parser.add_argument('--path-data', nargs='+', required=True, type=str, + help='Path to BIDS dataset(s) (list).') + 
parser.add_argument('--path-out', help='Path to output directory.', required=True) + parser.add_argument('--dataset-name', '-dname', default='DCMlesionsMultiChannel', type=str, + help='Specify the task name.') + parser.add_argument('--dataset-number', '-dnum', default=602, type=int, + help='Specify the task number, has to be greater than 500 but less than 999. e.g 502') + parser.add_argument('--seed', default=42, type=int, + help='Seed to be used for the random number generator split into training and test sets.') + # argument that accepts a list of floats as train val test splits + parser.add_argument('--split', nargs='+', type=float, default=[0.8, 0.2], + help='Ratios of training (includes validation) and test splits lying between 0-1. Example: ' + '--split 0.8 0.2') + return parser + + +def get_multi_channel_label(subject_label_file, subject_image_file, sub_ses_name, thr=0.5): + # define path for sc seg file + subject_seg_file = subject_label_file.replace('_label-lesion', '_label-SC_mask-manual') + + # check if the seg file exists + if not os.path.exists(subject_seg_file): + logger.info(f"Spinal cord segmentation file for subject {sub_ses_name} does not exist. Skipping.") + return None + + # create label for the multi-channel training (makes sure that the lesion seg is part of the spinal cord seg + # (the spinal cord seg is the first channel)) + seg_lesion_nii = create_multi_channel_label(subject_label_file, subject_seg_file, subject_image_file, + sub_ses_name, thr=thr) + + # save the label + combined_seg_file = subject_label_file.replace('_label-lesion', '_SC-lesion') + nib.save(seg_lesion_nii, combined_seg_file) + + return combined_seg_file + + +def create_directories(path_out, site): + """Create test directories for a specified site. + + Args: + path_out (str): Base output directory. 
+ site (str): Site identifier, such as 'dcm-zurich-lesions + """ + paths = [Path(path_out, f'imagesTs_{site}'), + Path(path_out, f'labelsTs_{site}')] + + for path in paths: + path.mkdir(parents=True, exist_ok=True) + + +def find_site_in_path(path): + """Extracts site identifier from the given path. + + Args: + path (str): Input path containing a site identifier. + + Returns: + str: Extracted site identifier or None if not found. + """ + # Find 'dcm-zurich-lesions' or 'dcm-zurich-lesions-20231115' + match = re.search(r'dcm-zurich-lesions(-\d{8})?', path) + return match.group(0) if match else None + + +def create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits): + # create a yaml file containing the list of training and test niftis + niftis_dict = { + f"train": sorted(train_niftis), + f"test": sorted(test_nifitis) + } + + # write the train and test niftis to a yaml file + with open(os.path.join(path_out, f"train_test_split_seed{args.seed}.yaml"), "w") as outfile: + yaml.dump(niftis_dict, outfile, default_flow_style=False) + + # c.f. dataset json generation + # In nnUNet V2, dataset.json file has become much shorter. 
The description of the fields and changes + # can be found here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md#datasetjson + # this file can be automatically generated using the following code here: + # https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/dataset_conversion/generate_dataset_json.py + # example: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/dataset_conversion/Task055_SegTHOR.py + + json_dict = OrderedDict() + json_dict['name'] = args.dataset_name + json_dict['description'] = args.dataset_name + json_dict['reference'] = "TBD" + json_dict['licence'] = "TBD" + json_dict['release'] = "0.0" + json_dict['numTraining'] = train_ctr + json_dict['numTest'] = test_ctr + json_dict['seed_used'] = args.seed + json_dict['dataset_versions'] = dataset_commits + json_dict['image_orientation'] = "RPI" + + # The following keys are the most important ones. + """ + channel_names: + Channel names must map the index to the name of the channel. For BIDS, this refers to the contrast suffix. + { + 0: 'T1', + 1: 'CT' + } + Note that the channel names may influence the normalization scheme!! Learn more in the documentation. + + labels: + This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. + Example regular labels: + { + 'background': 0, + 'left atrium': 1, + 'some other label': 2 + } + Example region-based training: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/region_based_training.md + { + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': 3 + } + Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! + """ + + json_dict['channel_names'] = { + 0: "acq-ax_T2w", + 1: "SC_seg", + } + + json_dict['labels'] = { + "background": 0, + "lesion": 1, + } + + # Needed for finding the files correctly. IMPORTANT! File endings must match between images and segmentations! 
+ json_dict['file_ending'] = ".nii.gz" + + # create dataset_description.json + json_object = json.dumps(json_dict, indent=4) + # write to dataset description + # nn-unet requires it to be "dataset.json" + dataset_dict_name = f"dataset.json" + with open(os.path.join(path_out, dataset_dict_name), "w") as outfile: + outfile.write(json_object) + + +def main(): + parser = get_parser() + args = parser.parse_args() + + train_ratio, test_ratio = args.split + path_out = Path(os.path.join(os.path.abspath(args.path_out), f'Dataset{args.dataset_number}_{args.dataset_name}' + f'Seed{args.seed}')) + + # create individual directories for train and test images and labels + path_out_imagesTr = Path(os.path.join(path_out, 'imagesTr')) + path_out_labelsTr = Path(os.path.join(path_out, 'labelsTr')) + # create the training directories + Path(path_out).mkdir(parents=True, exist_ok=True) + Path(path_out_imagesTr).mkdir(parents=True, exist_ok=True) + Path(path_out_labelsTr).mkdir(parents=True, exist_ok=True) + + # save output to a log file + logger.add(os.path.join(path_out, "logs.txt"), rotation="10 MB", level="INFO") + + # Check if dataset paths exist + for path in args.path_data: + if not os.path.exists(path): + raise ValueError(f"Path {path} does not exist.") + + # Get sites from the input paths + sites = set(find_site_in_path(path) for path in args.path_data if find_site_in_path(path)) + # Single site + if len(sites) == 1: + create_directories(path_out, sites.pop()) + # Multiple sites + else: + for site in sites: + create_directories(path_out, site) + + all_lesion_files, train_images, test_images = [], {}, {} + # temp dict for storing dataset commits + dataset_commits = {} + + # loop over the datasets + for dataset in args.path_data: + root = Path(dataset) + + # get the git branch and commit ID of the dataset + dataset_name = os.path.basename(os.path.normpath(dataset)) + branch, commit = get_git_branch_and_commit(dataset) + dataset_commits[dataset_name] = f"git-{branch}-{commit}" + + 
# get recursively all GT '_label-lesion' files + lesion_files = [str(path) for path in root.rglob('*_label-lesion.nii.gz')] + + # add to the list of all subjects + all_lesion_files.extend(lesion_files) + + # Get the training and test splits + tr_subs, te_subs = train_test_split(lesion_files, test_size=test_ratio, random_state=args.seed) + + # update the train and test images dicts with the key as the subject and value as the path to the subject + train_images.update({sub: os.path.join(root, sub) for sub in tr_subs}) + test_images.update({sub: os.path.join(root, sub) for sub in te_subs}) + + logger.info(f"Found subjects in the training set (combining all datasets): {len(train_images)}") + logger.info(f"Found subjects in the test set (combining all datasets): {len(test_images)}") + # Print test images for each site + for site in sites: + logger.info(f"Test subjects in {site}: {len([sub for sub in test_images if site in sub])}") + + # print version of each dataset in a separate line + for dataset_name, dataset_commit in dataset_commits.items(): + logger.info(f"{dataset_name} dataset version: {dataset_commit}") + + # Counters for train and test sets + train_ctr, test_ctr = 0, 0 + train_niftis, test_nifitis = [], [] + # Loop over all images + for subject_label_file in tqdm(all_lesion_files, desc="Iterating over all images"): + + # Construct path to the background image + subject_image_file = subject_label_file.replace('/derivatives/labels', '').replace('_label-lesion', '') + + # Train images + if subject_label_file in train_images.keys(): + + train_ctr += 1 + # add the subject image file to the list of training niftis + train_niftis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + # channel 0: T2w + subject_image_file_nnunet = os.path.join(path_out_imagesTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}_0000.nii.gz") + # channel 1: SC seg + 
subject_sc_file_nnunet = os.path.join(path_out_imagesTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}_0001.nii.gz") + # lesion label (lesion is part of SC) + subject_label_file_nnunet = os.path.join(path_out_labelsTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}.nii.gz") + + # overwritten the subject_sc_file_nnunet with the label for multi-channel training (lesion is part of SC) + subject_sc_file = get_multi_channel_label(subject_label_file, subject_image_file, sub_name, thr=0.5) + + # copy the files to new structure + # channel 0: T2w + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + # channel 1: SC seg (lesion is part of SC) + shutil.copyfile(subject_sc_file, subject_sc_file_nnunet) + # lesion label + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + sc = Image(subject_sc_file_nnunet) + sc.change_orientation("RPI") + sc.save(subject_sc_file_nnunet) + + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + # Test images + elif subject_label_file in test_images: + + test_ctr += 1 + # add the image file to the list of testing niftis + test_nifitis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + # channel 0: T2w + subject_image_file_nnunet = os.path.join(Path(path_out, + f'imagesTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}_0000.nii.gz') + # channel 1: SC seg (lesion is part of SC) + subject_sc_file_nnunet = os.path.join(Path(path_out, + f'imagesTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}_0001.nii.gz') + # lesion label + 
subject_label_file_nnunet = os.path.join(Path(path_out, + f'labelsTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}.nii.gz') + + # overwritten the subject_label_file with the region-based label + subject_sc_file = get_multi_channel_label(subject_label_file, subject_image_file, sub_name, thr=0.5) + + # copy the files to new structure + # channel 0: T2w + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + print(f"\nCopying {subject_image_file} to {subject_image_file_nnunet}") + # channel 1: SC seg (lesion is part of SC) + shutil.copyfile(subject_sc_file, subject_sc_file_nnunet) + print(f"\nCopying {subject_sc_file} to {subject_sc_file_nnunet}") + # lesion label + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + print(f"\nCopying {subject_label_file} to {subject_label_file_nnunet}") + + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + sc = Image(subject_sc_file_nnunet) + sc.change_orientation("RPI") + sc.save(subject_sc_file_nnunet) + + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + else: + print("Skipping file, could not be located in the Train or Test splits split.", subject_label_file) + + logger.info(f"----- Dataset conversion finished! 
-----") + logger.info(f"Number of training and validation images (across all sites): {train_ctr}") + # Get number of train and val images per site + train_images_per_site = {} + for train_subject in train_images: + site = find_site_in_path(train_subject) + if site in train_images_per_site: + train_images_per_site[site] += 1 + else: + train_images_per_site[site] = 1 + # Print number of train images per site + for site, num_images in train_images_per_site.items(): + logger.info(f"Number of training and validation images in {site}: {num_images}") + + logger.info(f"Number of test images (across all sites): {test_ctr}") + # Get number of test images per site + test_images_per_site = {} + for test_subject in test_images: + site = find_site_in_path(test_subject) + if site in test_images_per_site: + test_images_per_site[site] += 1 + else: + test_images_per_site[site] = 1 + # Print number of test images per site + for site, num_images in test_images_per_site.items(): + logger.info(f"Number of test images in {site}: {num_images}") + + # create the yaml file containing the train and test niftis + create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits) + + +if __name__ == "__main__": + main() diff --git a/dataset_conversion/convert_bids_to_nnUNetv2_region-based.py b/dataset_conversion/convert_bids_to_nnUNetv2_region-based.py new file mode 100644 index 0000000..cc43a04 --- /dev/null +++ b/dataset_conversion/convert_bids_to_nnUNetv2_region-based.py @@ -0,0 +1,431 @@ +""" +Convert BIDS-structured datasets (dcm-zurich-lesions, dcm-zurich-lesions-20231115) to the nnUNetv2 REGION-BASED format. 
+ +dataset.json: + +```json + "channel_names": { + "0": "acq-ax_T2w" + }, + "labels": { + "background": 0, + "sc": [ + 1, + 2 + ], + "lesion": 2 + }, + "regions_class_order": [ + 1, + 2 + ], +``` + +Full details about the format can be found here: +https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md + +The script to be used on a single dataset or multiple datasets. + +The script in default creates region-based labels for segmenting both lesion and the spinal cord. + +Currently only supports the conversion of a single contrast. In case of multiple contrasts, the script should be +modified to include those as well. + +Note: the script performs RPI reorientation of the images and labels + +Usage example multiple datasets: + python convert_bids_to_nnUNetv2_region-based.py + --path-data ~/data/dcm-zurich-lesions ~/data/dcm-zurich-lesions-20231115 + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + --region-based + +Usage example single dataset: + python convert_bids_to_nnUNetv2_region-based.py + --path-data ~/data/dcm-zurich-lesions + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + --region-based + +Authors: Naga Karthik, Jan Valosek +""" + +import argparse +from pathlib import Path +import json +import os +import re +import shutil +import yaml +from collections import OrderedDict +from loguru import logger +from sklearn.model_selection import train_test_split +from utils import binarize_label, create_region_based_label, get_git_branch_and_commit, Image +from tqdm import tqdm + +import nibabel as nib + + +def get_parser(): + # parse command line arguments + parser = argparse.ArgumentParser(description='Convert BIDS-structured dataset to nnUNetV2 REGION-BASED format.') + parser.add_argument('--path-data', nargs='+', required=True, type=str, + help='Path to BIDS dataset(s) (list).') + parser.add_argument('--path-out', help='Path to output directory.', required=True) 
+ parser.add_argument('--dataset-name', '-dname', default='DCMlesionsRegionBased', type=str, + help='Specify the task name.') + parser.add_argument('--dataset-number', '-dnum', default=601, type=int, + help='Specify the task number, has to be greater than 500 but less than 999. e.g 502') + parser.add_argument('--seed', default=42, type=int, + help='Seed to be used for the random number generator split into training and test sets.') + parser.add_argument('--region-based', action='store_true', default=True, + help='If set, the script will create labels for region-based nnUNet training. Default: True') + # argument that accepts a list of floats as train val test splits + parser.add_argument('--split', nargs='+', type=float, default=[0.8, 0.2], + help='Ratios of training (includes validation) and test splits lying between 0-1. Example: ' + '--split 0.8 0.2') + return parser + + +def get_region_based_label(subject_label_file, subject_image_file, sub_ses_name, thr=0.5): + # define path for sc seg file + subject_seg_file = subject_label_file.replace('_label-lesion', '_label-SC_mask-manual') + + # check if the seg file exists + if not os.path.exists(subject_seg_file): + logger.info(f"Spinal cord segmentation file for subject {sub_ses_name} does not exist. Skipping.") + return None + + # create region-based label + seg_lesion_nii = create_region_based_label(subject_label_file, subject_seg_file, subject_image_file, + sub_ses_name, thr=thr) + + # save the region-based label + combined_seg_file = subject_label_file.replace('_label-lesion', '_SC-lesion') + nib.save(seg_lesion_nii, combined_seg_file) + + return combined_seg_file + + +def create_directories(path_out, site): + """Create test directories for a specified site. + + Args: + path_out (str): Base output directory. 
+ site (str): Site identifier, such as 'dcm-zurich-lesions + """ + paths = [Path(path_out, f'imagesTs_{site}'), + Path(path_out, f'labelsTs_{site}')] + + for path in paths: + path.mkdir(parents=True, exist_ok=True) + + +def find_site_in_path(path): + """Extracts site identifier from the given path. + + Args: + path (str): Input path containing a site identifier. + + Returns: + str: Extracted site identifier or None if not found. + """ + # Find 'dcm-zurich-lesions' or 'dcm-zurich-lesions-20231115' + match = re.search(r'dcm-zurich-lesions(-\d{8})?', path) + return match.group(0) if match else None + + +def create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits): + # create a yaml file containing the list of training and test niftis + niftis_dict = { + f"train": sorted(train_niftis), + f"test": sorted(test_nifitis) + } + + # write the train and test niftis to a yaml file + with open(os.path.join(path_out, f"train_test_split_seed{args.seed}.yaml"), "w") as outfile: + yaml.dump(niftis_dict, outfile, default_flow_style=False) + + # c.f. dataset json generation + # In nnUNet V2, dataset.json file has become much shorter. 
The description of the fields and changes + # can be found here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md#datasetjson + # this file can be automatically generated using the following code here: + # https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/dataset_conversion/generate_dataset_json.py + # example: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/dataset_conversion/Task055_SegTHOR.py + + json_dict = OrderedDict() + json_dict['name'] = args.dataset_name + json_dict['description'] = args.dataset_name + json_dict['reference'] = "TBD" + json_dict['licence'] = "TBD" + json_dict['release'] = "0.0" + json_dict['numTraining'] = train_ctr + json_dict['numTest'] = test_ctr + json_dict['seed_used'] = args.seed + json_dict['dataset_versions'] = dataset_commits + json_dict['image_orientation'] = "RPI" + + # The following keys are the most important ones. + """ + channel_names: + Channel names must map the index to the name of the channel. For BIDS, this refers to the contrast suffix. + { + 0: 'T1', + 1: 'CT' + } + Note that the channel names may influence the normalization scheme!! Learn more in the documentation. + + labels: + This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. + Example regular labels: + { + 'background': 0, + 'left atrium': 1, + 'some other label': 2 + } + Example region-based training: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/region_based_training.md + { + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': 3 + } + Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! 
+ """ + + json_dict['channel_names'] = { + 0: "acq-ax_T2w", + } + + if not args.region_based: + json_dict['labels'] = { + "background": 0, + "lesion": 1, + } + else: + json_dict['labels'] = { + "background": 0, + "sc": [1, 2], + # "sc": 1, + "lesion": 2, + } + json_dict['regions_class_order'] = [1, 2] + + # Needed for finding the files correctly. IMPORTANT! File endings must match between images and segmentations! + json_dict['file_ending'] = ".nii.gz" + + # create dataset_description.json + json_object = json.dumps(json_dict, indent=4) + # write to dataset description + # nn-unet requires it to be "dataset.json" + dataset_dict_name = f"dataset.json" + with open(os.path.join(path_out, dataset_dict_name), "w") as outfile: + outfile.write(json_object) + + +def main(): + parser = get_parser() + args = parser.parse_args() + + train_ratio, test_ratio = args.split + path_out = Path(os.path.join(os.path.abspath(args.path_out), f'Dataset{args.dataset_number}_{args.dataset_name}')) + + # create individual directories for train and test images and labels + path_out_imagesTr = Path(os.path.join(path_out, 'imagesTr')) + path_out_labelsTr = Path(os.path.join(path_out, 'labelsTr')) + # create the training directories + Path(path_out).mkdir(parents=True, exist_ok=True) + Path(path_out_imagesTr).mkdir(parents=True, exist_ok=True) + Path(path_out_labelsTr).mkdir(parents=True, exist_ok=True) + + # save output to a log file + logger.add(os.path.join(path_out, "logs.txt"), rotation="10 MB", level="INFO") + + # Check if dataset paths exist + for path in args.path_data: + if not os.path.exists(path): + raise ValueError(f"Path {path} does not exist.") + + # Get sites from the input paths + sites = set(find_site_in_path(path) for path in args.path_data if find_site_in_path(path)) + # Single site + if len(sites) == 1: + create_directories(path_out, sites.pop()) + # Multiple sites + else: + for site in sites: + create_directories(path_out, site) + + all_lesion_files, train_images, 
test_images = [], {}, {} + # temp dict for storing dataset commits + dataset_commits = {} + + # loop over the datasets + for dataset in args.path_data: + root = Path(dataset) + + # get the git branch and commit ID of the dataset + dataset_name = os.path.basename(os.path.normpath(dataset)) + branch, commit = get_git_branch_and_commit(dataset) + dataset_commits[dataset_name] = f"git-{branch}-{commit}" + + # get recursively all GT '_label-lesion' files + lesion_files = [str(path) for path in root.rglob('*_label-lesion.nii.gz')] + + # add to the list of all subjects + all_lesion_files.extend(lesion_files) + + # Get the training and test splits + tr_subs, te_subs = train_test_split(lesion_files, test_size=test_ratio, random_state=args.seed) + + # update the train and test images dicts with the key as the subject and value as the path to the subject + train_images.update({sub: os.path.join(root, sub) for sub in tr_subs}) + test_images.update({sub: os.path.join(root, sub) for sub in te_subs}) + + logger.info(f"Found subjects in the training set (combining all datasets): {len(train_images)}") + logger.info(f"Found subjects in the test set (combining all datasets): {len(test_images)}") + # Print test images for each site + for site in sites: + logger.info(f"Test subjects in {site}: {len([sub for sub in test_images if site in sub])}") + + # print version of each dataset in a separate line + for dataset_name, dataset_commit in dataset_commits.items(): + logger.info(f"{dataset_name} dataset version: {dataset_commit}") + + # Counters for train and test sets + train_ctr, test_ctr = 0, 0 + train_niftis, test_nifitis = [], [] + # Loop over all images + for subject_label_file in tqdm(all_lesion_files, desc="Iterating over all images"): + + # Construct path to the background image + subject_image_file = subject_label_file.replace('/derivatives/labels', '').replace('_label-lesion', '') + + # Train images + if subject_label_file in train_images.keys(): + + train_ctr += 1 + # add the 
subject image file to the list of training niftis + train_niftis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + subject_image_file_nnunet = os.path.join(path_out_imagesTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}_0000.nii.gz") + subject_label_file_nnunet = os.path.join(path_out_labelsTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}.nii.gz") + + # use region-based labels if required + if args.region_based: + # overwritten the subject_label_file with the region-based label + subject_label_file = get_region_based_label(subject_label_file, + subject_image_file, sub_name, thr=0.5) + if subject_label_file is None: + print(f"Skipping since the region-based label could not be generated") + continue + + # copy the files to new structure + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + # binarize the label file only if region-based training is not set (since the region-based labels are + # already binarized) + if not args.region_based: + binarize_label(subject_image_file_nnunet, subject_label_file_nnunet) + + # Test images + elif subject_label_file in test_images: + + test_ctr += 1 + # add the image file to the list of testing niftis + test_nifitis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + subject_image_file_nnunet = os.path.join(Path(path_out, + f'imagesTs_{find_site_in_path(test_images[subject_label_file])}'), + 
f'{args.dataset_name}_{sub_name}_{test_ctr:03d}_0000.nii.gz') + subject_label_file_nnunet = os.path.join(Path(path_out, + f'labelsTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}.nii.gz') + + # use region-based labels if required + if args.region_based: + # overwritten the subject_label_file with the region-based label + subject_label_file = get_region_based_label(subject_label_file, + subject_image_file, sub_name, thr=0.5) + if subject_label_file is None: + continue + + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + print(f"\nCopying {subject_label_file} to {subject_label_file_nnunet}") + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + # copy the files to new structure + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + print(f"\nCopying {subject_image_file} to {subject_image_file_nnunet}") + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + # binarize the label file only if region-based training is not set (since the region-based labels are + # already binarized) + if not args.region_based: + binarize_label(subject_image_file_nnunet, subject_label_file_nnunet) + + else: + print("Skipping file, could not be located in the Train or Test splits split.", subject_label_file) + + logger.info(f"----- Dataset conversion finished! 
-----") + logger.info(f"Number of training and validation images (across all sites): {train_ctr}") + # Get number of train and val images per site + train_images_per_site = {} + for train_subject in train_images: + site = find_site_in_path(train_subject) + if site in train_images_per_site: + train_images_per_site[site] += 1 + else: + train_images_per_site[site] = 1 + # Print number of train images per site + for site, num_images in train_images_per_site.items(): + logger.info(f"Number of training and validation images in {site}: {num_images}") + + logger.info(f"Number of test images (across all sites): {test_ctr}") + # Get number of test images per site + test_images_per_site = {} + for test_subject in test_images: + site = find_site_in_path(test_subject) + if site in test_images_per_site: + test_images_per_site[site] += 1 + else: + test_images_per_site[site] = 1 + # Print number of test images per site + for site, num_images in test_images_per_site.items(): + logger.info(f"Number of test images in {site}: {num_images}") + + # create the yaml file containing the train and test niftis + create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits) + + +if __name__ == "__main__": + main() diff --git a/dataset_conversion/requirements.txt b/dataset_conversion/requirements.txt new file mode 100644 index 0000000..69b0718 --- /dev/null +++ b/dataset_conversion/requirements.txt @@ -0,0 +1,6 @@ +loguru +scikit-learn +tqdm +nibabel +pyyaml +collections-extended diff --git a/dataset_conversion/segment_sc.sh b/dataset_conversion/segment_sc.sh new file mode 100644 index 0000000..e561338 --- /dev/null +++ b/dataset_conversion/segment_sc.sh @@ -0,0 +1,127 @@ +#!/bin/bash +# +# Run the SCIseg nnUNet model on T2w images to segment spinal cord. +# +# Note: conda environment with nnUNetV2 is required to run this script. 
+# For details how to install nnUNetV2, see: +# https://github.com/ivadomed/utilities/blob/main/quick_start_guides/nnU-Net_quick_start_guide.md#installation +# +# Usage: +# sct_run_batch -config config.json +# +# Example of config.json: +# { +# "path_data" : "", +# "path_output" : "_2024-XX-XX", +# "script" : "~/code/model-seg-dcm/dataset_conversion/segment_sc", +# "jobs" : 8, +# "script_args" : "~/code/model_seg_sci/packaging/run_inference_single_subject.py ~/models/sci-multisite-model" +# } +# +# Author: Jan Valosek, Naga Karthik +# + +# Uncomment for full verbose +set -x + +# Immediately exit if error +set -e -o pipefail + +# Exit if user presses CTRL+C (Linux) or CMD+C (OSX) +trap "echo Caught Keyboard Interrupt within script. Exiting now.; exit" INT + +# Print retrieved variables from the sct_run_batch script to the log (to allow easier debug) +echo "Retrieved variables from from the caller sct_run_batch:" +echo "PATH_DATA: ${PATH_DATA}" +echo "PATH_DATA_PROCESSED: ${PATH_DATA_PROCESSED}" +echo "PATH_RESULTS: ${PATH_RESULTS}" +echo "PATH_LOG: ${PATH_LOG}" +echo "PATH_QC: ${PATH_QC}" + +SUBJECT=$1 +PATH_NNUNET_SCRIPT=$2 +PATH_NNUNET_MODEL=$3 + +echo "SUBJECT: ${SUBJECT}" +echo "PATH_NNUNET_SCRIPT: ${PATH_NNUNET_SCRIPT}" +echo "PATH_NNUNET_MODEL: ${PATH_NNUNET_MODEL}" + +# ------------------------------------------------------------------------------ +# CONVENIENCE FUNCTIONS +# ------------------------------------------------------------------------------ + +# Segment spinal cord using our nnUNet model +segment_sc_nnUNet(){ + local file="$1" + local kernel="$2" # 2d or 3d + + # output file name + FILESEG="${file}_seg_nnunet_${kernel}" + + # Get the start time + start_time=$(date +%s) + # Run SC segmentation + python ${PATH_NNUNET_SCRIPT} -i ${file}.nii.gz -o ${FILESEG}.nii.gz -path-model ${PATH_NNUNET_MODEL}/nnUNetTrainer__nnUNetPlans__${kernel}_fullres -pred-type sc -use-gpu + # Get the end time + end_time=$(date +%s) + # Calculate the time difference + 
execution_time=$(python3 -c "print($end_time - $start_time)") + echo "${FILESEG},${execution_time}" >> ${PATH_RESULTS}/execution_time.csv + + # Generate spinal cord QC report + sct_qc -i ${file}.nii.gz -s ${FILESEG}.nii.gz -p sct_deepseg_sc -qc ${PATH_QC} -qc-subject ${SUBJECT} + # Compute ANIMA segmentation performance metrics + #compute_anima_metrics ${FILESEG} ${file}_seg-manual +} + +# ------------------------------------------------------------------------------ +# SCRIPT STARTS HERE +# ------------------------------------------------------------------------------ +# get starting time: +start=`date +%s` + +# Display useful info for the log, such as SCT version, RAM and CPU cores available +sct_check_dependencies -short + +# Go to folder where data will be copied and processed +cd $PATH_DATA_PROCESSED + +# Copy source T2w images +# Note: we use '/./' in order to include the sub-folder 'ses-0X' +# We do a substitution '/' --> '_' in case there is a subfolder 'ses-0X/' +# Note: we copy only axial T2w image to save space +rsync -Ravzh ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}_*acq-ax_T2w.* . + +# Go to subject folder for source images +cd ${SUBJECT}/anat + +# ------------------------------------------------------------------------------ +# T2w axial +# ------------------------------------------------------------------------------ +# We do a substitution '/' --> '_' in case there is a subfolder 'ses-0X/' + +file_t2="${SUBJECT//[\/]/_}"_acq-ax_T2w + +# Check if file_t2 exists +if [[ ! -e ${file_t2}.nii.gz ]]; then + echo "File ${file_t2}.nii.gz does not exist" >> ${PATH_LOG}/missing_files.log + echo "ERROR: File ${file_t2}.nii.gz does not exist. Exiting." 
+ exit 1 +fi + +# Segment SC using the SCIseg nnUNet model +segment_sc_nnUNet "${file_t2}" '3d' + +# ------------------------------------------------------------------------------ +# End +# ------------------------------------------------------------------------------ + +# Display results (to easily compare integrity across SCT versions) +end=`date +%s` +runtime=$((end-start)) +echo +echo "~~~" +echo "SCT version: `sct_version`" +echo "Ran on: `uname -nsr`" +echo "Duration: $(($runtime / 3600))hrs $((($runtime / 60) % 60))min $(($runtime % 60))sec" +echo "~~~" diff --git a/dataset_conversion/utils.py b/dataset_conversion/utils.py new file mode 100644 index 0000000..78ed774 --- /dev/null +++ b/dataset_conversion/utils.py @@ -0,0 +1,818 @@ +import os +import nibabel as nib +import numpy as np +import logging +from copy import deepcopy +import subprocess + +logger = logging.getLogger(__name__) + + +def binarize_label(subject_path, label_path): + label_npy = nib.load(label_path).get_fdata() + threshold = 0.5 + label_npy = np.where(label_npy > threshold, 1, 0) + ref = nib.load(subject_path) + label_bin = nib.Nifti1Image(label_npy, ref.affine, ref.header) + # overwrite the original label file with the binarized version + nib.save(label_bin, label_path) + + +def create_region_based_label(lesion_label_file, seg_label_file, image_file, sub_ses_name, thr=0.5): + """ + Creates region-based labels for REGION-BASED nnUNet training. The regions are: + 0: background + 1: spinal cord seg + 2: lesion seg + """ + # load the labels + lesion_label_npy = nib.load(lesion_label_file).get_fdata() + seg_label_npy = nib.load(seg_label_file).get_fdata() + + # binarize the labels + lesion_label_npy = np.where(lesion_label_npy > thr, 1, 0) + seg_label_npy = np.where(seg_label_npy > thr, 1, 0) + + # check if the shapes of the labels match + assert lesion_label_npy.shape == seg_label_npy.shape, \ + f'Shape mismatch between lesion label and segmentation label for subject {sub_ses_name}. 
Check the labels.' + + # create a new label array with the same shape as the original labels + label_npy = np.zeros(lesion_label_npy.shape, dtype=np.int16) + # spinal cord + label_npy[seg_label_npy == 1] = 1 + # lesion seg + label_npy[lesion_label_npy == 1] = 2 + # TODO: what happens when the subject has no lesion? + + # print unique values in the label array + # print(f'Unique values in the label array for subject {sub_ses_name}: {np.unique(label_npy)}') + + # save the new label file + ref = nib.load(image_file) + label_nii = nib.Nifti1Image(label_npy, ref.affine, ref.header) + + return label_nii + + +def create_multi_channel_label(lesion_label_file, seg_label_file, image_file, sub_ses_name, thr=0.5): + """ + Creates labels for MULTI-CHANNEL nnUNet training. The regions are: + 0: background + 1: lesion seg + The function makes sure that the lesion seg is part of the spinal cord seg (the spinal cord seg is the first + channel). + """ + # load the labels + lesion_label_npy = nib.load(lesion_label_file).get_fdata() + seg_label_npy = nib.load(seg_label_file).get_fdata() + + # binarize the labels + lesion_label_npy = np.where(lesion_label_npy > thr, 1, 0) + seg_label_npy = np.where(seg_label_npy > thr, 1, 0) + + # check if the shapes of the labels match + assert lesion_label_npy.shape == seg_label_npy.shape, \ + f'Shape mismatch between lesion label and segmentation label for subject {sub_ses_name}. Check the labels.' 
+ + # create a new label array with the same shape as the original labels + label_npy = np.zeros(lesion_label_npy.shape, dtype=np.int16) + # spinal cord + label_npy[seg_label_npy == 1] = 1 + # lesion seg + label_npy[lesion_label_npy == 1] = 1 + + # print unique values in the label array + # print(f'Unique values in the label array for subject {sub_ses_name}: {np.unique(label_npy)}') + + # save the new label file + ref = nib.load(image_file) + label_nii = nib.Nifti1Image(label_npy, ref.affine, ref.header) + + return label_nii + + +def get_git_branch_and_commit(dataset_path=None): + """ + :return: git branch and commit ID, with trailing '*' if modified + Taken from: https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/utils/sys.py#L476 + and https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/utils/sys.py#L461 + """ + + # branch info + b = subprocess.Popen(["git", "rev-parse", "--abbrev-ref", "HEAD"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, cwd=dataset_path) + b_output, _ = b.communicate() + b_status = b.returncode + + if b_status == 0: + branch = b_output.decode().strip() + else: + branch = "!?!" + + # commit info + p = subprocess.Popen(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=dataset_path) + output, _ = p.communicate() + status = p.returncode + if status == 0: + commit = output.decode().strip() + else: + commit = "?!?" 
+ + p = subprocess.Popen(["git", "status", "--porcelain"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=dataset_path) + output, _ = p.communicate() + status = p.returncode + if status == 0: + unclean = True + for line in output.decode().strip().splitlines(): + line = line.rstrip() + if line.startswith("??"): # ignore ignored files, they can't hurt + continue + break + else: + unclean = False + if unclean: + commit += "*" + + return branch, commit + + +class Image(object): + """ + Compact version of SCT's Image Class (https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/image.py#L245) + Create an object that behaves similarly to nibabel's image object. Useful additions include: dims, change_orientation and getNonZeroCoordinates. + Taken from: https://github.com/ivadomed/utilities/blob/main/scripts/image.py + Changed default verbosity to 0. + """ + + def __init__(self, param=None, hdr=None, orientation=None, absolutepath=None, dim=None): + """ + :param param: string indicating a path to a image file or an `Image` object. 
+ """ + + # initialization of all parameters + self.affine = None + self.data = None + self._path = None + self.ext = "" + + if absolutepath is not None: + self._path = os.path.abspath(absolutepath) + + # Case 1: load an image from file + if isinstance(param, str): + self.loadFromPath(param) + # Case 2: create a copy of an existing `Image` object + elif isinstance(param, type(self)): + self.copy(param) + # Case 3: create a blank image from a list of dimensions + elif isinstance(param, list): + self.data = np.zeros(param) + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + # Case 4: create an image from an existing data array + elif isinstance(param, (np.ndarray, np.generic)): + self.data = param + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + else: + raise TypeError('Image constructor takes at least one argument.') + + # Fix any mismatch between the array's datatype and the header datatype + self.fix_header_dtype() + + @property + def dim(self): + return get_dimension(self) + + @property + def orientation(self): + return get_orientation(self) + + @property + def absolutepath(self): + """ + Storage path (either actual or potential) + + Notes: + + - As several tools perform chdir() it's very important to have absolute paths + - When set, if relative: + + - If it already existed, it becomes a new basename in the old dirname + - Else, it becomes absolute (shortcut) + + Usually not directly touched (use `Image.save`), but in some cases it's + the best way to set it. 
+ """ + return self._path + + @absolutepath.setter + def absolutepath(self, value): + if value is None: + self._path = None + return + elif not os.path.isabs(value) and self._path is not None: + value = os.path.join(os.path.dirname(self._path), value) + elif not os.path.isabs(value): + value = os.path.abspath(value) + self._path = value + + @property + def header(self): + return self.hdr + + @header.setter + def header(self, value): + self.hdr = value + + def __deepcopy__(self, memo): + return type(self)(deepcopy(self.data, memo), deepcopy(self.hdr, memo), deepcopy(self.orientation, memo), deepcopy(self.absolutepath, memo), deepcopy(self.dim, memo)) + + def copy(self, image=None): + if image is not None: + self.affine = deepcopy(image.affine) + self.data = deepcopy(image.data) + self.hdr = deepcopy(image.hdr) + self._path = deepcopy(image._path) + else: + return deepcopy(self) + + def loadFromPath(self, path): + """ + This function load an image from an absolute path using nibabel library + + :param path: path of the file from which the image will be loaded + :return: + """ + + self.absolutepath = os.path.abspath(path) + im_file = nib.load(self.absolutepath, mmap=True) + self.affine = im_file.affine.copy() + self.data = np.asanyarray(im_file.dataobj) + self.hdr = im_file.header.copy() + if path != self.absolutepath: + logger.debug("Loaded %s (%s) orientation %s shape %s", path, self.absolutepath, self.orientation, self.data.shape) + else: + logger.debug("Loaded %s orientation %s shape %s", path, self.orientation, self.data.shape) + + def change_orientation(self, orientation, inverse=False): + """ + Change orientation on image (in-place). + + :param orientation: orientation string (SCT "from" convention) + + :param inverse: if you think backwards, use this to specify that you actually\ + want to transform *from* the specified orientation, not *to*\ + it. 
+ + """ + change_orientation(self, orientation, self, inverse=inverse) + return self + + def getNonZeroCoordinates(self, sorting=None, reverse_coord=False): + """ + This function return all the non-zero coordinates that the image contains. + Coordinate list can also be sorted by x, y, z, or the value with the parameter sorting='x', sorting='y', sorting='z' or sorting='value' + If reverse_coord is True, coordinate are sorted from larger to smaller. + + Removed Coordinate object + """ + n_dim = 1 + if self.dim[3] == 1: + n_dim = 3 + else: + n_dim = 4 + if self.dim[2] == 1: + n_dim = 2 + + if n_dim == 3: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], Z[i], self.data[X[i], Y[i], Z[i]]] for i in range(0, len(X))] + elif n_dim == 2: + try: + X, Y = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i]]] for i in range(0, len(X))] + except ValueError: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i], 0]] for i in range(0, len(X))] + + if sorting is not None: + if reverse_coord not in [True, False]: + raise ValueError('reverse_coord parameter must be a boolean') + + if sorting == 'x': + list_coordinates = sorted(list_coordinates, key=lambda el: el[0], reverse=reverse_coord) + elif sorting == 'y': + list_coordinates = sorted(list_coordinates, key=lambda el: el[1], reverse=reverse_coord) + elif sorting == 'z': + list_coordinates = sorted(list_coordinates, key=lambda el: el[2], reverse=reverse_coord) + elif sorting == 'value': + list_coordinates = sorted(list_coordinates, key=lambda el: el[3], reverse=reverse_coord) + else: + raise ValueError("sorting parameter must be either 'x', 'y', 'z' or 'value'") + + return list_coordinates + + def change_type(self, dtype): + """ + Change data type on image. + + Note: the image path is voided. 
+ """ + change_type(self, dtype, self) + return self + + def fix_header_dtype(self): + """ + Change the header dtype to the match the datatype of the array. + """ + # Using bool for nibabel headers is unsupported, so use uint8 instead: + # `nibabel.spatialimages.HeaderDataError: data dtype "bool" not supported` + dtype_data = self.data.dtype + if dtype_data == bool: + dtype_data = np.uint8 + + dtype_header = self.hdr.get_data_dtype() + if dtype_header != dtype_data: + logger.warning(f"Image header specifies datatype '{dtype_header}', but array is of type " + f"'{dtype_data}'. Header metadata will be overwritten to use '{dtype_data}'.") + self.hdr.set_data_dtype(dtype_data) + + def save(self, path=None, dtype=None, verbose=0, mutable=False): + """ + Write an image in a nifti file + + :param path: Where to save the data, if None it will be taken from the\ + absolutepath member.\ + If path is a directory, will save to a file under this directory\ + with the basename from the absolutepath member. 
+ + :param dtype: if not set, the image is saved in the same type as input data\ + if 'minimize', image storage space is minimized\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + + :param mutable: whether to update members with newly created path or dtype + """ + if mutable: # do all modifications in-place + # Case 1: `path` not specified + if path is None: + if self.absolutepath: # Fallback to the original filepath + path = self.absolutepath + else: + raise ValueError("Don't know where to save the image (no absolutepath or path parameter)") + # Case 2: `path` points to an existing directory + elif os.path.isdir(path): + if self.absolutepath: # Use the original filename, but save to the directory specified by `path` + path = os.path.join(os.path.abspath(path), os.path.basename(self.absolutepath)) + else: + raise ValueError("Don't know where to save the image (path parameter is dir, but absolutepath is " + "missing)") + # Case 3: `path` points to a file (or a *nonexistent* directory) so use its value as-is + # (We're okay with letting nonexistent directories slip through, because it's difficult to distinguish + # between nonexistent directories and nonexistent files. Plus, `nibabel` will catch any further errors.) 
+ else: + pass + + if os.path.isfile(path) and verbose: + logger.warning("File %s already exists. Will overwrite it.", path) + if os.path.isabs(path): + logger.debug("Saving image to %s orientation %s shape %s", + path, self.orientation, self.data.shape) + else: + logger.debug("Saving image to %s (%s) orientation %s shape %s", + path, os.path.abspath(path), self.orientation, self.data.shape) + + # Now that `path` has been set and log messages have been written, we can assign it to the image itself + self.absolutepath = os.path.abspath(path) + + if dtype is not None: + self.change_type(dtype) + + if self.hdr is not None: + self.hdr.set_data_shape(self.data.shape) + self.fix_header_dtype() + + # nb. that copy() is important because if it were a memory map, save() would corrupt it + dataobj = self.data.copy() + affine = None + header = self.hdr.copy() if self.hdr is not None else None + nib.save(nib.nifti1.Nifti1Image(dataobj, affine, header), self.absolutepath) + if not os.path.isfile(self.absolutepath): + raise RuntimeError(f"Couldn't save image to {self.absolutepath}") + else: + # if we're not operating in-place, then make any required modifications on a throw-away copy + self.copy().save(path, dtype, verbose, mutable=True) + return self + + +class SlicerOneAxis(object): + """ + Image slicer to use when you don't care about the 2D slice orientation, + and don't want to specify them. + The slicer will just iterate through the right axis that corresponds to + its specification. + + Can help getting ranges and slice indices. 
+ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + + def __init__(self, im, axis="IS"): + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + axis_labels = "LRPAIS" + if len(axis) != 2: + raise ValueError() + if axis[0] not in axis_labels: + raise ValueError() + if axis[1] not in axis_labels: + raise ValueError() + if axis[0] != opposite_character[axis[1]]: + raise ValueError() + + for idx_axis in range(2): + dim_nr = im.orientation.find(axis[idx_axis]) + if dim_nr != -1: + break + if dim_nr == -1: + raise ValueError() + + # SCT convention + from_dir = im.orientation[dim_nr] + self.direction = +1 if axis[0] == from_dir else -1 + self.nb_slices = im.dim[dim_nr] + self.im = im + self.axis = axis + self._slice = lambda idx: tuple([(idx if x in axis else slice(None)) for x in im.orientation]) + + def __len__(self): + return self.nb_slices + + def __getitem__(self, idx): + """ + + :return: an image slice, at slicing index idx + :param idx: slicing index (according to the slicing direction) + """ + if isinstance(idx, slice): + raise NotImplementedError() + + if idx >= self.nb_slices: + raise IndexError("I just have {} slices!".format(self.nb_slices)) + + if self.direction == -1: + idx = self.nb_slices - 1 - idx + + return self.im.data[self._slice(idx)] + +def get_dimension(im_file, verbose=1): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + Get dimension from Image or nibabel object. Manages 2D, 3D or 4D images. 
+ + :param: im_file: Image or nibabel object + :return: nx, ny, nz, nt, px, py, pz, pt + """ + if not isinstance(im_file, (nib.nifti1.Nifti1Image, Image)): + raise TypeError("The provided image file is neither a nibabel.nifti1.Nifti1Image instance nor an Image instance") + # initializating ndims [nx, ny, nz, nt] and pdims [px, py, pz, pt] + ndims = [1, 1, 1, 1] + pdims = [1, 1, 1, 1] + data_shape = im_file.header.get_data_shape() + zooms = im_file.header.get_zooms() + for i in range(min(len(data_shape), 4)): + ndims[i] = data_shape[i] + pdims[i] = zooms[i] + return *ndims, *pdims + + +def change_orientation(im_src, orientation, im_dst=None, inverse=False): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im_src: source image + :param orientation: orientation string (SCT "from" convention) + :param im_dst: destination image (can be the source image for in-place + operation, can be unset to generate one) + :param inverse: if you think backwards, use this to specify that you actually + want to transform *from* the specified orientation, not *to* it. + :return: an image with changed orientation + + .. 
note:: + - the resulting image has no path member set + - if the source image is < 3D, it is reshaped to 3D and the destination is 3D + """ + + if len(im_src.data.shape) < 3: + pass # Will reshape to 3D + elif len(im_src.data.shape) == 3: + pass # OK, standard 3D volume + elif len(im_src.data.shape) == 4: + pass # OK, standard 4D volume + elif len(im_src.data.shape) == 5 and im_src.header.get_intent()[0] == "vector": + pass # OK, physical displacement field + else: + raise NotImplementedError("Don't know how to change orientation for this image") + + im_src_orientation = im_src.orientation + im_dst_orientation = orientation + if inverse: + im_src_orientation, im_dst_orientation = im_dst_orientation, im_src_orientation + + perm, inversion = _get_permutations(im_src_orientation, im_dst_orientation) + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + im_src_data = im_src.data + if len(im_src_data.shape) < 3: + im_src_data = im_src_data.reshape(tuple(list(im_src_data.shape) + ([1] * (3 - len(im_src_data.shape))))) + + # Update data by performing inversions and swaps + + # axes inversion (flip) + data = im_src_data[::inversion[0], ::inversion[1], ::inversion[2]] + + # axes manipulations (transpose) + if perm == [1, 0, 2]: + data = np.swapaxes(data, 0, 1) + elif perm == [2, 1, 0]: + data = np.swapaxes(data, 0, 2) + elif perm == [0, 2, 1]: + data = np.swapaxes(data, 1, 2) + elif perm == [2, 0, 1]: + data = np.swapaxes(data, 0, 2) # transform [2, 0, 1] to [1, 0, 2] + data = np.swapaxes(data, 0, 1) # transform [1, 0, 2] to [0, 1, 2] + elif perm == [1, 2, 0]: + data = np.swapaxes(data, 0, 2) # transform [1, 2, 0] to [0, 2, 1] + data = np.swapaxes(data, 1, 2) # transform [0, 2, 1] to [0, 1, 2] + elif perm == [0, 1, 2]: + # do nothing + pass + else: + raise NotImplementedError() + + # Update header + + im_src_aff = im_src.hdr.get_best_affine() + aff = nib.orientations.inv_ornt_aff( + np.array((perm, inversion)).T, + im_src_data.shape) + im_dst_aff = 
np.matmul(im_src_aff, aff)

    im_dst.header.set_qform(im_dst_aff)
    im_dst.header.set_sform(im_dst_aff)
    im_dst.header.set_data_shape(data.shape)
    im_dst.data = data

    return im_dst


def _get_permutations(im_src_orientation, im_dst_orientation):
    """
    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/

    :param im_src_orientation str: Orientation of source image. Example: 'RPI'
    :param im_dst_orientation str: Orientation of destination image. Example: 'SAL'
    :return: list of axes permutations and list of inversions to achieve an orientation change
    """

    opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}

    perm = [0, 1, 2]
    inversion = [1, 1, 1]
    for i, character in enumerate(im_src_orientation):
        try:
            # same axis letter present in destination: plain permutation, no flip
            perm[i] = im_dst_orientation.index(character)
        except ValueError:
            # only the opposite letter is present: same anatomical axis, but flipped
            perm[i] = im_dst_orientation.index(opposite_character[character])
            inversion[i] = -1

    return perm, inversion


def get_orientation(im):
    """
    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/

    :param im: an Image
    :return: reference space string (ie.
what's in Image.orientation) + """ + res = "".join(nib.orientations.aff2axcodes(im.hdr.get_best_affine())) + return orientation_string_nib2sct(res) + + +def orientation_string_nib2sct(s): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :return: SCT reference space code from nibabel one + """ + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + return "".join([opposite_character[x] for x in s]) + + +def change_type(im_src, dtype, im_dst=None): + """ + Change the voxel type of the image + + :param dtype: if not set, the image is saved in standard type\ + if 'minimize', image space is minimize\ + if 'minimize_int', image space is minimize and values are approximated to integers\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + :return: + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + """ + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + if dtype is None: + return im_dst + + # get min/max from input image + min_in = np.nanmin(im_src.data) + max_in = np.nanmax(im_src.data) + + # find optimum type for the input image + if dtype in ('minimize', 'minimize_int'): + # warning: does not take intensity resolution into account, neither complex voxels + + # check if voxel values 
are real or integer + isInteger = True + if dtype == 'minimize': + for vox in im_src.data.flatten(): + if int(vox) != vox: + isInteger = False + break + + if isInteger: + if min_in >= 0: # unsigned + if max_in <= np.iinfo(np.uint8).max: + dtype = np.uint8 + elif max_in <= np.iinfo(np.uint16): + dtype = np.uint16 + elif max_in <= np.iinfo(np.uint32).max: + dtype = np.uint32 + elif max_in <= np.iinfo(np.uint64).max: + dtype = np.uint64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + if max_in <= np.iinfo(np.int8).max and min_in >= np.iinfo(np.int8).min: + dtype = np.int8 + elif max_in <= np.iinfo(np.int16).max and min_in >= np.iinfo(np.int16).min: + dtype = np.int16 + elif max_in <= np.iinfo(np.int32).max and min_in >= np.iinfo(np.int32).min: + dtype = np.int32 + elif max_in <= np.iinfo(np.int64).max and min_in >= np.iinfo(np.int64).min: + dtype = np.int64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + # if max_in <= np.finfo(np.float16).max and min_in >= np.finfo(np.float16).min: + # type = 'np.float16' # not supported by nibabel + if max_in <= np.finfo(np.float32).max and min_in >= np.finfo(np.float32).min: + dtype = np.float32 + elif max_in <= np.finfo(np.float64).max and min_in >= np.finfo(np.float64).min: + dtype = np.float64 + + dtype = to_dtype(dtype) + else: + dtype = to_dtype(dtype) + + # if output type is int, check if it needs intensity rescaling + if "int" in dtype.name: + # get min/max from output type + min_out = np.iinfo(dtype).min + max_out = np.iinfo(dtype).max + # before rescaling, check if there would be an intensity overflow + + if (min_in < min_out) or (max_in > max_out): + # This condition is important for binary images since we do not want to scale them + logger.warning(f"To avoid intensity overflow due to convertion to +{dtype.name}+, intensity will be rescaled to the maximum quantization scale") + # rescale intensity + data_rescaled = im_src.data 
* (max_out - min_out) / (max_in - min_in) + im_dst.data = data_rescaled - (data_rescaled.min() - min_out) + + # change type of data in both numpy array and nifti header + im_dst.data = getattr(np, dtype.name)(im_dst.data) + im_dst.hdr.set_data_dtype(dtype) + return im_dst + + +def to_dtype(dtype): + """ + Take a dtypeification and return an np.dtype + + :param dtype: dtypeification (string or np.dtype or None are supported for now) + :return: dtype or None + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + """ + # TODO add more or filter on things supported by nibabel + + if dtype is None: + return None + if isinstance(dtype, type): + if isinstance(dtype(0).dtype, np.dtype): + return dtype(0).dtype + if isinstance(dtype, np.dtype): + return dtype + if isinstance(dtype, str): + return np.dtype(dtype) + + raise TypeError("data type {}: {} not understood".format(dtype.__class__, dtype)) + + +def zeros_like(img, dtype=None): + """ + + :param img: reference image + :param dtype: desired data type (optional) + :return: an Image with the same shape and header, filled with zeros + + Similar to numpy.zeros_like(), the goal of the function is to show the developer's + intent and avoid doing a copy, which is slower than initialization with a constant. + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + zimg = Image(np.zeros_like(img.data), hdr=img.hdr.copy()) + if dtype is not None: + zimg.change_type(dtype) + return zimg + + +def empty_like(img, dtype=None): + """ + :param img: reference image + :param dtype: desired data type (optional) + :return: an Image with the same shape and header, whose data is uninitialized + + Similar to numpy.empty_like(), the goal of the function is to show the developer's + intent and avoid touching the allocated memory, because it will be written to + afterwards. 
+ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + dst = change_type(img, dtype) + return dst + + +def find_zmin_zmax(im, threshold=0.1): + """ + Find the min (and max) z-slice index below which (and above which) slices only have voxels below a given threshold. + + :param im: Image object + :param threshold: threshold to apply before looking for zmin/zmax, typically corresponding to noise level. + :return: [zmin, zmax] + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + slicer = SlicerOneAxis(im, axis="IS") + + # Make sure image is not empty + if not np.any(slicer): + logger.error('Input image is empty') + + # Iterate from bottom to top until we find data + for zmin in range(0, len(slicer)): + if np.any(slicer[zmin] > threshold): + break + + # Conversely from top to bottom + for zmax in range(len(slicer) - 1, zmin, -1): + if np.any(slicer[zmax] > threshold): + break + + return zmin, zmax diff --git a/training/01_run_training_dcm-zurich-lesions.sh b/training/01_run_training_dcm-zurich-lesions.sh new file mode 100755 index 0000000..57e02f5 --- /dev/null +++ b/training/01_run_training_dcm-zurich-lesions.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# +# Run nnUNet training and testing on dcm-zurich-lesions and dcm-zurich-lesions-20231115 datasets +# +# Usage: +# cd ~/code/model-seg-dcm +# ./training/01_run_training_dcm-zurich-lesions.sh +# +# Author: Jan Valosek, Naga Karthik +# + +# Uncomment for full verbose +set -x + +# Immediately exit if error +set -e -o pipefail + +# Exit if user presses CTRL+C (Linux) or CMD+C (OSX) +trap "echo Caught Keyboard Interrupt within script. 
Exiting now.; exit" INT


# define arguments for nnUNet
dataset_num="603"
seed="710"
dataset_name="Dataset${dataset_num}_DCMlesionsSeed${seed}"
nnunet_trainer="nnUNetTrainerDiceCELoss_noSmooth"
# NOTE(review): the next line overrides the DiceCELoss trainer chosen above — confirm
# which trainer is actually intended before running.
nnunet_trainer="nnUNetTrainer"
#nnunet_trainer="nnUNetTrainer_2000epochs"              # default: nnUNetTrainer
configuration="3d_fullres"                              # for 2D training, use "2d"
cuda_visible_devices=1
folds=(1 2)
#folds=(3)
sites=(dcm-zurich-lesions dcm-zurich-lesions-20231115)
region_based="--region-based"
#region_based=""


echo "-------------------------------------------------------"
echo "Running preprocessing and verifying dataset integrity"
echo "-------------------------------------------------------"
nnUNetv2_plan_and_preprocess -d ${dataset_num} --verify_dataset_integrity -c ${configuration}


for fold in ${folds[@]}; do
    echo "-------------------------------------------"
    echo "Training on Fold $fold"
    echo "-------------------------------------------"

    # training
    CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_num} ${configuration} ${fold} -tr ${nnunet_trainer}

    echo ""
    echo "-------------------------------------------"
    echo "Training completed, Testing on Fold $fold"
    echo "-------------------------------------------"

    # run inference on testing sets for each site
    for site in ${sites[@]}; do
        CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name}/imagesTs_${site} -tr ${nnunet_trainer} -o ${nnUNet_results}/${dataset_name}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num} -f ${fold} -c ${configuration}  # -step_size 0.9 --disable_tta

        echo "-------------------------------------------------------"
        echo "Running ANIMA evaluation on Test set for ${site} "
        echo "-------------------------------------------------------"

        python training/02_compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name}/labelsTs_${site} --dataset-name ${site} ${region_based}

    done

done
diff --git a/training/02_compute_anima_metrics.py b/training/02_compute_anima_metrics.py
new file mode 100644
index 0000000..3363972
--- /dev/null
+++ b/training/02_compute_anima_metrics.py
@@ -0,0 +1,370 @@
"""
This script evaluates the reference segmentations and model predictions
using the "animaSegPerfAnalyzer" command

****************************************************************************************
SegPerfAnalyser (Segmentation Performance Analyzer) provides different marks, metrics
and scores for segmentation evaluation.
3 categories are available:
    - SEGMENTATION EVALUATION:
        Dice, the mean overlap
        Jaccard, the union overlap
        Sensitivity
        Specificity
        NPV (Negative Predictive Value)
        PPV (Positive Predictive Value)
        RVE (Relative Volume Error) in percentage
    - SURFACE DISTANCE EVALUATION:
        Hausdorff distance
        Contour mean distance
        Average surface distance
    - DETECTION LESIONS EVALUATION:
        PPVL (Positive Predictive Value for Lesions)
        SensL, Lesion detection sensitivity
        F1 Score, a F1 Score between PPVL and SensL

Results are provided as follows:
Jaccard; Dice; Sensitivity; Specificity; PPV; NPV; RelativeVolumeError;
HausdorffDistance; ContourMeanDistance; SurfaceDistance; PPVL; SensL; F1_score;

NbTestedLesions; VolTestedLesions; --> These metrics are computed for images that
                                       have no lesions in the GT
****************************************************************************************

Mathematical details on how these metrics are computed can be found here:
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6135867/pdf/41598_2018_Article_31911.pdf

and in Section 4 of this paper (for how the subjects with no lesions are handled):
https://portal.fli-iam.irisa.fr/files/2021/06/MS_Challenge_Evaluation_Challengers.pdf

INSTALLATION:
##### STEP 0: Install git lfs via apt if you don't already have it.
##### STEP 1: Install ANIMA #####
cd ~
mkdir anima/
cd anima/
wget -q https://github.com/Inria-Visages/Anima-Public/releases/download/v4.2/Anima-Ubuntu-4.2.zip (change version to latest)
unzip Anima-Ubuntu-4.2.zip
git lfs install
git clone --depth 1 https://github.com/Inria-Visages/Anima-Scripts-Public.git
git clone --depth 1 https://github.com/Inria-Visages/Anima-Scripts-Data-Public.git

##### STEP 2: Configure directories #####
# Variable names and section titles should stay the same
# Put this file in ${HOME}/.anima/config.txt
# Make the anima variable point to your Anima public build
# Make the extra-data-root point to the data folder of Anima-Scripts
# The last folder separator for each path is crucial, do not forget them
# Use full paths, nothing relative or using tildes

cd ~
mkdir .anima/
touch .anima/config.txt

echo "[anima-scripts]" >> .anima/config.txt
echo "anima = ${HOME}/anima/Anima-Binaries-4.2/" >> .anima/config.txt
echo "anima-scripts-public-root = ${HOME}/anima/Anima-Scripts-Public/" >> .anima/config.txt
echo "extra-data-root = ${HOME}/anima/Anima-Scripts-Data-Public/" >> .anima/config.txt

USAGE:
python 02_compute_anima_metrics.py
    --pred-folder
    --gt-folder
    --label-type
    --dataset-name
    --region-based


NOTE 1: For checking all the available options run the following command from your terminal:
    <anima_binaries_path>/animaSegPerfAnalyzer -h
NOTE 2: We use certain additional arguments below with the following purposes:
    -i -> input image, -r -> reference image, -o -> output folder
    -d -> evaluates surface distance, -l -> evaluates the detection of lesions
    -a -> intra-lesion evaluation (advanced), -s -> segmentation evaluation,
    -X -> save as XML file -A -> prints details on output metrics and exits

Authors: Naga Karthik, Jan Valosek
"""

import 
os +import glob +import subprocess +import argparse +from collections import defaultdict +import xml.etree.ElementTree as ET +import numpy as np +import nibabel as nib +from test_utils import fetch_filename_details + + +def get_parser(): + # parse command line arguments + parser = argparse.ArgumentParser(description='Compute test metrics using animaSegPerfAnalyzer') + + # Arguments for model, data, and training + parser.add_argument('--pred-folder', required=True, type=str, + help='Path to the folder containing nifti images of test predictions') + parser.add_argument('--gt-folder', required=True, type=str, + help='Path to the folder containing nifti images of GT labels') + parser.add_argument('--dataset-name', required=True, type=str, + help='Dataset name used for storing on git-annex. Example: "dcm-zurich-lesions"') + parser.add_argument('--region-based', required=False, action='store_true', default=False, + help='If the training was done on region-based datasets, set this flag to True. ' + 'Region-based means that the output segmentation is a multi-class segmentation (1: SC, ' + '2: Lesion)') + parser.add_argument('--label-type', required=False, type=str, choices=['sc', 'lesion'], default='sc', + help='Type of prediction and GT label to be used for ANIMA evaluation. 
' + 'Options: "sc" for spinal cord segmentation, "lesion" for lesion segmentation' + 'NOTE: when label-type is "lesion", additional lesion detection metrics, namely,' + 'Lesion PPV, Lesion Sensitivity, and F1_score are computed' + 'NOTE: this argument is ignored when region-based is set to True, and both "sc" and ' + '"lesion" are computed for each subject.') + + return parser + + +def get_test_metrics_by_dataset(pred_folder, gt_folder, output_folder, anima_binaries_path, region_based, label_type): + """ + Computes the test metrics given folders containing nifti images of test predictions + and GT images by running the "animaSegPerfAnalyzer" command + """ + + # glob all the predictions and GTs and get the last three digits of the filename + pred_files = sorted(glob.glob(os.path.join(pred_folder, "*.nii.gz"))) + gt_files = sorted(glob.glob(os.path.join(gt_folder, "*.nii.gz"))) + + dataset_name_nnunet = fetch_filename_details(pred_files[0])[0] + + if region_based: + # loop over the predictions and compute the metrics + for pred_file, gt_file in zip(pred_files, gt_files): + + _, sub_pred, ses_pred, idx_pred, _, _ = fetch_filename_details(pred_file) + _, sub_gt, ses_gt, idx_gt, _, _ = fetch_filename_details(gt_file) + + # make sure the subject and session IDs match + print(f"Subject and session IDs for Preds and GTs: {sub_pred}_{ses_pred}_{idx_pred}, {sub_gt}_{ses_gt}_{idx_gt}") + assert idx_pred == idx_gt, 'Subject and session IDs for Preds and GTs do not match. Please check the filenames.' + + if ses_gt == "": + sub_ses_pred, sub_ses_gt = f"{sub_pred}", f"{sub_gt}" + else: + sub_ses_pred, sub_ses_gt = f"{sub_pred}_{ses_pred}", f"{sub_gt}_{ses_gt}" + assert sub_ses_pred == sub_ses_gt, 'Subject and session IDs for Preds and GTs do not match. Please check the filenames.' 

                # evaluate the SC (value 1) and the lesion (value 2) channels separately
                for seg in ['sc', 'lesion']:
                    # load the predictions and GTs
                    pred_npy = nib.load(pred_file).get_fdata()
                    gt_npy = nib.load(gt_file).get_fdata()

                    if seg == 'sc':
                        pred_npy = np.array(pred_npy == 1, dtype=float)
                        gt_npy = np.array(gt_npy == 1, dtype=float)

                    elif seg == 'lesion':
                        pred_npy = np.array(pred_npy == 2, dtype=float)
                        gt_npy = np.array(gt_npy == 2, dtype=float)

                    # Save the binarized predictions and GTs
                    # NOTE(review): an identity affine is used here, i.e. the original
                    # image geometry is discarded — confirm ANIMA only needs the voxel grid.
                    pred_nib = nib.Nifti1Image(pred_npy, affine=np.eye(4))
                    gtc_nib = nib.Nifti1Image(gt_npy, affine=np.eye(4))
                    nib.save(img=pred_nib, filename=os.path.join(pred_folder, f"{dataset_name_nnunet}_{sub_ses_pred}_{idx_pred}_{seg}.nii.gz"))
                    nib.save(img=gtc_nib, filename=os.path.join(gt_folder, f"{dataset_name_nnunet}_{sub_ses_gt}_{idx_gt}_{seg}.nii.gz"))

                    # Run ANIMA segmentation performance metrics on the predictions
                    if seg == 'sc':
                        seg_perf_analyzer_cmd = '%s -i %s -r %s -o %s -d -s -X'
                    elif seg == 'lesion':  # add lesion evaluation metrics with `-l`
                        seg_perf_analyzer_cmd = '%s -i %s -r %s -o %s -d -s -l -X'

                    os.system(seg_perf_analyzer_cmd %
                              (os.path.join(anima_binaries_path, 'animaSegPerfAnalyzer'),
                               os.path.join(pred_folder, f"{dataset_name_nnunet}_{sub_ses_pred}_{idx_pred}_{seg}.nii.gz"),
                               os.path.join(gt_folder, f"{dataset_name_nnunet}_{sub_ses_gt}_{idx_gt}_{seg}.nii.gz"),
                               os.path.join(output_folder, f"{idx_pred}_{seg}")))

                    # Delete temporary binarized NIfTI files
                    os.remove(os.path.join(pred_folder, f"{dataset_name_nnunet}_{sub_ses_pred}_{idx_pred}_{seg}.nii.gz"))
                    os.remove(os.path.join(gt_folder, f"{dataset_name_nnunet}_{sub_ses_gt}_{idx_gt}_{seg}.nii.gz"))

            # Get all XML filepaths where ANIMA performance metrics are saved for each hold-out subject
            subject_sc_filepaths = [os.path.join(output_folder, f) for f in
                                    os.listdir(output_folder) if f.endswith('.xml') and 'sc' in f]
            subject_lesion_filepaths = [os.path.join(output_folder, f) for f in
                                        os.listdir(output_folder) if f.endswith('.xml') and 'lesion' in f]

            # NOTE: returns a 2-tuple here, while the non-region-based branch below
            # returns a single list — callers must match on the region_based flag.
            return subject_sc_filepaths, subject_lesion_filepaths

    else:
        # loop over the predictions and compute the metrics
        for pred_file, gt_file in zip(pred_files, gt_files):

            _, sub_pred, ses_pred, idx_pred, _, _ = fetch_filename_details(pred_file)
            _, sub_gt, ses_gt, idx_gt, _, _ = fetch_filename_details(gt_file)

            # make sure the subject and session IDs match
            print(f"Subject and session IDs for Preds and GTs: {sub_pred}_{ses_pred}_{idx_pred}, {sub_gt}_{ses_gt}_{idx_gt}")
            assert idx_pred == idx_gt, 'Subject and session IDs for Preds and GTs do not match. Please check the filenames.'

            if ses_gt == "":
                sub_ses_pred, sub_ses_gt = f"{sub_pred}", f"{sub_gt}"
            else:
                sub_ses_pred, sub_ses_gt = f"{sub_pred}_{ses_pred}", f"{sub_gt}_{ses_gt}"
            assert sub_ses_pred == sub_ses_gt, 'Subject and session IDs for Preds and GTs do not match. Please check the filenames.'

            # load the predictions and GTs
            pred_npy = nib.load(pred_file).get_fdata()
            gt_npy = nib.load(gt_file).get_fdata()

            # make sure the predictions are binary because ANIMA accepts binarized inputs only
            pred_npy = np.array(pred_npy > 0.5, dtype=float)
            gt_npy = np.array(gt_npy > 0.5, dtype=float)

            # Save the binarized predictions and GTs
            pred_nib = nib.Nifti1Image(pred_npy, affine=np.eye(4))
            gtc_nib = nib.Nifti1Image(gt_npy, affine=np.eye(4))
            nib.save(img=pred_nib, filename=os.path.join(pred_folder, f"{dataset_name_nnunet}_{idx_pred}_bin.nii.gz"))
            nib.save(img=gtc_nib, filename=os.path.join(gt_folder, f"{dataset_name_nnunet}_{idx_gt}_bin.nii.gz"))

            # Run ANIMA segmentation performance metrics on the predictions
            if label_type == 'lesion':
                seg_perf_analyzer_cmd = '%s -i %s -r %s -o %s -d -l -s -X'
            elif label_type == 'sc':
                seg_perf_analyzer_cmd = '%s -i %s -r %s -o %s -d -s -X'
            else:
                raise ValueError('Please specify a valid label type: lesion or sc')

            os.system(seg_perf_analyzer_cmd %
                      (os.path.join(anima_binaries_path, 'animaSegPerfAnalyzer'),
                       os.path.join(pred_folder, f"{dataset_name_nnunet}_{idx_pred}_bin.nii.gz"),
                       os.path.join(gt_folder, f"{dataset_name_nnunet}_{idx_gt}_bin.nii.gz"),
                       os.path.join(output_folder, f"{idx_pred}")))

            # Delete temporary binarized NIfTI files
            os.remove(os.path.join(pred_folder, f"{dataset_name_nnunet}_{idx_pred}_bin.nii.gz"))
            os.remove(os.path.join(gt_folder, f"{dataset_name_nnunet}_{idx_gt}_bin.nii.gz"))

        # Get all XML filepaths where ANIMA performance metrics are saved for each hold-out subject
        subject_filepaths = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if f.endswith('.xml')]

        return subject_filepaths


def main():

    # get the ANIMA binaries path from the user's ~/.anima/config.txt
    cmd = r'''grep "^anima = " ~/.anima/config.txt | sed "s/.* = //"'''
    anima_binaries_path = subprocess.check_output(cmd, shell=True).decode('utf-8').strip('\n')
    print('ANIMA Binaries Path:', anima_binaries_path)
    # version = subprocess.check_output(anima_binaries_path + 'animaSegPerfAnalyzer --version', shell=True).decode('utf-8').strip('\n')
    print('Running ANIMA version:',
          subprocess.check_output(anima_binaries_path + 'animaSegPerfAnalyzer --version', shell=True).decode(
              'utf-8').strip('\n'))

    parser = get_parser()
    args = parser.parse_args()

    # define variables
    pred_folder, gt_folder = args.pred_folder, args.gt_folder
    label_type = args.label_type
    dataset_name = args.dataset_name
    region_based = args.region_based

    output_folder = os.path.join(pred_folder, f"anima_stats")
    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)
    print(f"Saving ANIMA performance metrics to {output_folder}")

    if not region_based:

        # Get all XML filepaths where ANIMA performance metrics are saved for each hold-out subject
        subject_filepaths = get_test_metrics_by_dataset(pred_folder, gt_folder, output_folder, anima_binaries_path,
                                                        region_based, label_type=label_type)

        test_metrics = defaultdict(list)

        # Update the test metrics dictionary by iterating over all subjects
        for subject_filepath in subject_filepaths:
            subject = os.path.split(subject_filepath)[-1].split('_')[0]
            root_node = ET.parse(source=subject_filepath).getroot()

            # if GT is empty then metrics aren't calculated, hence the only entries in the XML file
            # NbTestedLesions and VolTestedLesions, both of which are zero. Hence, we can skip subjects
            # with empty GTs by checking if the length of the .xml file is 2
            if len(root_node) == 2:
                print(f"Skipping Subject={int(subject):03d} ENTIRELY Due to Empty GT!")
                continue

            for metric in list(root_node):
                name, value = metric.get('name'), float(metric.text)

                if np.isinf(value) or np.isnan(value):
                    print(f'Skipping Metric={name} for Subject={int(subject):03d} Due to INF or NaNs!')
                    continue

                test_metrics[name].append(value)

        # Print aggregation of each metric via mean and standard dev.
        with open(os.path.join(output_folder, f'log_{dataset_name}.txt'), 'a') as f:
            print(f'Test Phase Metrics [ANIMA], n={len(subject_filepaths)}: ', file=f)

        print(f'Test Phase Metrics [ANIMA], n={len(subject_filepaths)}: ')
        for key in test_metrics:
            print('\t%s -> Mean: %0.4f Std: %0.2f' % (key, np.mean(test_metrics[key]), np.std(test_metrics[key])))

            # save the metrics to a log file
            with open(os.path.join(output_folder, f'log_{dataset_name}.txt'), 'a') as f:
                print("\t%s --> Mean: %0.3f, Std: %0.3f" %
                      (key, np.mean(test_metrics[key]), np.std(test_metrics[key])), file=f)

    # Region-based training
    else:
        # Get all XML filepaths where ANIMA performance metrics are saved for each hold-out subject
        subject_sc_filepaths, subject_lesion_filepaths = \
            get_test_metrics_by_dataset(pred_folder, gt_folder, output_folder, anima_binaries_path,
                                        region_based, label_type=label_type)

        # loop through the sc and lesion filepaths and get the metrics
        for subject_filepaths in [subject_sc_filepaths, subject_lesion_filepaths]:

            test_metrics = defaultdict(list)

            # Update the test metrics dictionary by iterating over all subjects
            for subject_filepath in subject_filepaths:

                subject = os.path.split(subject_filepath)[-1].split('_')[0]
                seg_type = os.path.split(subject_filepath)[-1].split('_')[1]
                root_node = ET.parse(source=subject_filepath).getroot()

                # if GT is empty then metrics aren't calculated, hence the only entries in the XML file
                # NbTestedLesions and VolTestedLesions, both of which are zero. Hence, we can skip subjects
                # with empty GTs by checking if the length of the .xml file is 2
                if len(root_node) == 2:
                    print(f"Skipping Subject={int(subject):03d} ENTIRELY Due to Empty GT!")
                    continue

                for metric in list(root_node):
                    name, value = metric.get('name'), float(metric.text)

                    if np.isinf(value) or np.isnan(value):
                        print(f'Skipping Metric={name} for Subject={int(subject):03d} Due to INF or NaNs!')
                        continue

                    test_metrics[name].append(value)

            # Print aggregation of each metric via mean and standard dev.
            # NOTE(review): `seg_type` below is set inside the subject loop above; if every
            # subject was skipped (or the list is empty) this raises NameError — confirm.
            with open(os.path.join(output_folder, f'log_{dataset_name}.txt'), 'a') as f:
                print(f'Test Phase Metrics [ANIMA] for {seg_type}, n={len(subject_filepaths)}: ', file=f)

            print(f'Test Phase Metrics [ANIMA] for {seg_type}, n={len(subject_filepaths)}: ')
            for key in test_metrics:
                print('\t%s -> Mean: %0.4f Std: %0.2f' % (key, np.mean(test_metrics[key]), np.std(test_metrics[key])))

                # save the metrics to a log file
                with open(os.path.join(output_folder, f'log_{dataset_name}.txt'), 'a') as f:
                    print("\t%s --> Mean: %0.3f, Std: %0.3f" %
                          (key, np.mean(test_metrics[key]), np.std(test_metrics[key])), file=f)


if __name__ == '__main__':
    main()
diff --git a/training/test_utils.py b/training/test_utils.py
new file mode 100644
index 0000000..a3b7737
--- /dev/null
+++ b/training/test_utils.py
@@ -0,0 +1,42 @@
import os
import re


def fetch_filename_details(filename_path):
    """
    Get dataset name, subject name, session number (if exists), file ID and filename from the input 
nnUNet-compatible
    filename or file path. The function works both on absolute file path as well as filename
    :param filename_path: input nifti filename (e.g., sub-001_ses-01_T1w.nii.gz) or file path
    (e.g., /home/user/MRI/bids/derivatives/labels/sub-001/ses-01/anat/sub-001_ses-01_T1w.nii.gz
    :return: datasetName: dataset name
    :return: subjectID: subject ID (e.g., sub-001); empty string if not found
    :return: sessionID: session ID (e.g., ses-01); empty string if not found
    :return: fileID: file ID (e.g., 001); empty string if not found
    :return: fileName: filename (e.g., sub-001_ses-01_T1w.nii.gz)
    :return: seg_type: segmentation type (e.g., sc or lesion); empty string if not found

    Adapted from: https://github.com/spinalcordtoolbox/manual-correction/blob/main/utils.py#L24
    """

    _, fileName = os.path.split(filename_path)  # Get just the filename (i.e., remove the path)
    datasetName = fileName.split('_')[0]  # Get the dataset name (i.e., remove the filename)

    subject = re.search('sub-(.*?)[_/]', filename_path)
    subjectID = subject.group(0)[:-1] if subject else ""  # [:-1] removes the trailing underscore or slash

    # NOTE(review): the pattern matches exactly two characters after 'ses-' (e.g. 'ses-01');
    # longer session IDs would be truncated — confirm all sessions are two characters.
    session = re.findall(r'ses-..', filename_path)
    sessionID = session[0] if session else ""  # Return empty string if there is no session

    fID = re.search('_\d{3}', fileName)
    fileID = fID.group(0)[1:] if fID else ""  # [1:] removes the leading underscore

    # Fetch segtype (sc or lesion)
    seg_type = re.search('sc|lesion', filename_path)
    seg_type = seg_type.group(0) if seg_type else ""

    # REGEX explanation
    # \d - digit
    # \d? - no or one occurrence of digit
    # *? - match the previous element as few times as possible (zero or more times)

    return datasetName, subjectID, sessionID, fileID, fileName, seg_type