diff --git a/super_resolution/README.md b/super_resolution/README.md new file mode 100644 index 0000000..aeb6f90 --- /dev/null +++ b/super_resolution/README.md @@ -0,0 +1,6 @@ +# Super-resolution experiment + +In this folder we have the code used for the super-resolution experiment. The super-resolution algorithm is stored here: https://gitlab.com/iacl/smore + +To run the super-resolution algo: `run-smore --in-fpath img_path --out-dir output_dir` + diff --git a/super_resolution/perform_sct_inference.py b/super_resolution/perform_sct_inference.py new file mode 100644 index 0000000..d8e8ad4 --- /dev/null +++ b/super_resolution/perform_sct_inference.py @@ -0,0 +1,70 @@ +""" +This code performs inference on the super-resolved images using the SCT lesion_ms model. +The model used is the model from this release: https://github.com/ivadomed/ms-lesion-agnostic/releases/tag/r20250909 +Inference is performed with GPU to speed up the process. +Instruction to use GPU inference with SCT: https://github.com/spinalcordtoolbox/spinalcordtoolbox/issues/4360#issuecomment-2035294418 + +Input: + -input: The directory containing the super-resolved images. + -output: The directory to save the inference results. + +Output: + None + +Example: + python perform_sct_inference.py --input /path/to/superres_images --output /path/to/output + +Author: Pierre-Louis Benveniste +""" +import argparse +import os +from pathlib import Path +from tqdm import tqdm +import shutil + + +def parse_args(): + parser = argparse.ArgumentParser(description="Perform SCT inference on super-resolved images.") + parser.add_argument("--input", type=str, required=True, help="Input directory containing super-resolved images.") + parser.add_argument("--output", type=str, required=True, help="Output directory for inference results.") + return parser.parse_args() + + +def main(): + # Parse arguments + args = parse_args() + input_dir = args.input + output_dir = args.output + + # Build the output_dir if it does not exist + os.makedirs(output_dir, exist_ok=True) + + # List all .nii.gz files in the input directory + input_files = list(Path(input_dir).rglob("*.nii.gz")) + input_files = sorted(input_files) + input_files = [str(f) for f in input_files] + + # Run the SCT model on each file + for input_file in tqdm(input_files): + # Build a temp folder to store intermediate results + temp_folder = Path(output_dir) / "temp" + os.makedirs(temp_folder, exist_ok=True) + + # Build file names + output_temp_file = os.path.join(temp_folder, str(input_file).split('/')[-1].replace("_0000.nii.gz", ".nii.gz")) + output_label_file = os.path.join(Path(output_dir), output_temp_file.split('/')[-1]) + + # Run the inference command using GPU + command = f"SCT_USE_GPU=1 sct_deepseg lesion_ms -i {input_file} -o {output_temp_file}" + assert os.system(command) == 0 + + # Move the resulting label file to the output directory + assert os.path.exists(output_temp_file), f"Error: Output file {output_temp_file} not found." + shutil.move(str(output_temp_file), str(output_label_file)) + + # Remove the temporary folder + shutil.rmtree(temp_folder) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/super_resolution/register_labels_to_superres.py b/super_resolution/register_labels_to_superres.py new file mode 100644 index 0000000..fb57a6f --- /dev/null +++ b/super_resolution/register_labels_to_superres.py @@ -0,0 +1,92 @@ +""" +This script registers the labels from the nnunet folder to the super-resolved images. +It uses sct_register_multimodal -identity 1 from the spinal cord toolbox. + +Input: + -labels: The directory containing the labels to register. + -images: The directory containing the super-resolved images. + -output: The directory to save the registered labels. + -min-idx: Minimum index of the files to process (inclusive). + -max-idx: Maximum index of the files to process (exclusive). + +Output: + None + +Example: + python register_labels_to_superres.py --labels /path/to/labels --images /path/to/superres_images --output /path/to/output + +Author: Pierre-Louis Benveniste +""" +import argparse +import os +from pathlib import Path +from tqdm import tqdm +import shutil + + +def parse_args(): + parser = argparse.ArgumentParser(description="Register labels to super-resolved images.") + parser.add_argument("--labels", type=str, required=True, help="Directory containing the labels to register.") + parser.add_argument("--images", type=str, required=True, help="Directory containing the super-resolved images.") + parser.add_argument("--output", type=str, required=True, help="Directory to save the registered labels.") + parser.add_argument("--min-idx", type=int, default=0, help="Minimum index of the files to process (inclusive).") + parser.add_argument("--max-idx", type=int, default=None, help="Maximum index of the files to process (exclusive).") + return parser.parse_args() + + +def main(): + # Parse arguments + args = parse_args() + labels_dir = args.labels + images_dir = args.images + output_dir = args.output + + # Build the output_dir if it does not exist + os.makedirs(output_dir, exist_ok=True) + + # List all .nii.gz files in the labels directory + label_files = list(Path(labels_dir).rglob("*.nii.gz")) + label_files = sorted(label_files) + label_files = [str(f) for f in label_files] + + # Keep only file with file_00X.nii.gz with X between min_idx and max_idx + if args.min_idx is not None and args.max_idx is not None: + label_files = [f for f in label_files if int(f.split("_")[-1].replace(".nii.gz", "")) >= args.min_idx and int(f.split("_")[-1].replace(".nii.gz", "")) < args.max_idx] + + # For each label file, find the corresponding super-resolved image and register + for label_file in tqdm(label_files): + # Extract the base name to find the corresponding image + image_name = label_file.split("/")[-1].replace(".nii.gz", "_0000.nii.gz") + image_file = Path(images_dir) / image_name + + if not image_file.exists(): + print(f"Warning: Corresponding image for {label_file} not found. Skipping.") + break + + # Build the output label file path + output_label_file = Path(output_dir) / Path(label_file).name + if output_label_file.exists(): + print(f"Warning: Output file {output_label_file} already exists. Skipping.") + continue + + # Build a temporary folder to store intermediate results + label_number = str(label_file).split("_")[-1].replace(".nii.gz", "") + temp_folder = Path(output_dir) / f"temp_{label_number}" + os.makedirs(temp_folder, exist_ok=True) + output_temp_file = temp_folder / Path(label_file).name + + # Run the registration command + command = f"sct_register_multimodal -i {label_file} -d {image_file} -o {output_temp_file} -identity 1 -x nn" + assert os.system(command) == 0, f"Error in registration for {label_file}" + + # Move the registered label to the output directory + assert os.system(f"mv {output_temp_file} {output_label_file}") == 0, f"Error moving file {output_temp_file} to {output_label_file}" + + # Clean up the temporary folder + shutil.rmtree(temp_folder) + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/super_resolution/run_super_res_nnunet_data.py b/super_resolution/run_super_res_nnunet_data.py new file mode 100644 index 0000000..f6e4cd2 --- /dev/null +++ b/super_resolution/run_super_res_nnunet_data.py @@ -0,0 +1,78 @@ +""" +This algorithm runs the super-resolution model on some preprocessed data. +The nnUNet data is either some train or test data in the nnUNet format. +We give as input a min and max index to be a able to run the super-resolution in parallel. +This script requires to activate the smorve conda environment. + +Input: +- input_dir: The directory containing the nnUNet data. +- output_dir: The directory where the super-resolution results will be saved. +- min_index: The minimum index of the data to process. +- max_index: The maximum index of the data to process. + +Output: + None + +Example: + python run_super_res_nnunet_data.py --input_dir /path/to/nnUNet_data --output_dir /path/to/output --min_index 0 --max_index 100 + +Author: Pierre-Louis Benveniste +""" +import argparse +import os +from pathlib import Path +from tqdm import tqdm +import nibabel as nib + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run super-resolution on nnUNet data.") + parser.add_argument("--input-dir", type=str, required=True, help="Input directory containing nnUNet data.") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory for super-resolution results.") + parser.add_argument("--min-index", type=int, required=True, help="Minimum index of the data to process.") + parser.add_argument("--max-index", type=int, required=True, help="Maximum index of the data to process.") + return parser.parse_args() + + +def main(): + # Parse arguments + args = parse_args() + input_dir = args.input_dir + output_dir = args.output_dir + + # Build the output_dir if it does not exist + os.makedirs(output_dir, exist_ok=True) + + # List all .nii.gz files in the input directory + input_files = list(Path(input_dir).rglob("*.nii.gz")) + input_files = sorted(input_files) + + # Images will have the following standard: name_XXX_0000.nii.gz with XXX being the index + # We will process files from min_index to max_index + input_files = [f for f in input_files if int(f.stem.split("_")[-2]) >= args.min_index and int(f.stem.split("_")[-2]) <= args.max_index] + + # Run the super-resolution model on each file + for input_file in tqdm(input_files): + # If an image is isotropic we just copy it to the destination folder + img = nib.load(str(input_file)) + resolution = img.header.get_zooms()[:3] # Get the first three dimensions (x, y, z) + # If all three are clode to 1e-2 then we consider it isotropic + if max(resolution) - min(resolution) < 1e-2: + # Then we just copy the file to the output directory + assert os.system(f"cp {str(input_file)} {output_dir}")==0 + else: + # Build a temp folder + temp_dir = Path(output_dir) / f"temp_{input_file.stem.split('_')[-2]}" + os.makedirs(temp_dir, exist_ok=True) + os.system(f"run-smore --in-fpath {str(input_file)} --out-dir {temp_dir}") + # Then we copy the result to the output directory + temp_files = list(temp_dir.rglob("*.nii.gz")) + assert os.system(f"cp {str(temp_files[0])} {str(Path(output_dir) / temp_files[0].name.replace('_smore4', ''))}") == 0 + # Remove the temp directory + assert os.system(f"rm -rf {temp_dir}") == 0 + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file