Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions super_resolution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Super-resolution experiment

In this folder we have the code used for the super-resolution experiment. The super-resolution algorithm is stored here: https://gitlab.com/iacl/smore

To run the super-resolution algo: `run-smore --in-fpath img_path --out-dir output_dir`

70 changes: 70 additions & 0 deletions super_resolution/perform_sct_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
This code performs inference on the super-resolved images using the SCT lesion_ms model.
The model used is the model from this release: https://github.com/ivadomed/ms-lesion-agnostic/releases/tag/r20250909
Inference is performed with GPU to speed up the process.
Instruction to use GPU inference with SCT: https://github.com/spinalcordtoolbox/spinalcordtoolbox/issues/4360#issuecomment-2035294418

Input:
-input: The directory containing the super-resolved images.
-output: The directory to save the inference results.

Output:
None

Example:
python perform_sct_inference.py --input /path/to/superres_images --output /path/to/output

Author: Pierre-Louis Benveniste
"""
import argparse
import os
from pathlib import Path
from tqdm import tqdm
import shutil


def parse_args():
parser = argparse.ArgumentParser(description="Perform SCT inference on super-resolved images.")
parser.add_argument("--input", type=str, required=True, help="Input directory containing super-resolved images.")
parser.add_argument("--output", type=str, required=True, help="Output directory for inference results.")
return parser.parse_args()


def main():
# Parse arguments
args = parse_args()
input_dir = args.input
output_dir = args.output

# Build the output_dir if it does not exist
os.makedirs(output_dir, exist_ok=True)

# List all .nii.gz files in the input directory
input_files = list(Path(input_dir).rglob("*.nii.gz"))
input_files = sorted(input_files)
input_files = [str(f) for f in input_files]

# Run the SCT model on each file
for input_file in tqdm(input_files):
# Build a temp folder to store intermediate results
temp_folder = Path(output_dir) / "temp"
os.makedirs(temp_folder, exist_ok=True)

# Build file names
output_temp_file = os.path.join(temp_folder, str(input_file).split('/')[-1].replace("_0000.nii.gz", ".nii.gz"))
output_label_file = os.path.join(Path(output_dir), output_temp_file.split('/')[-1])

# Run the inference command using GPU
command = f"SCT_USE_GPU=1 sct_deepseg lesion_ms -i {input_file} -o {output_temp_file}"
assert os.system(command) == 0

# Move the resulting label file to the output directory
assert os.path.exists(output_temp_file), f"Error: Output file {output_temp_file} not found."
shutil.move(str(output_temp_file), str(output_label_file))

# Remove the temporary folder
shutil.rmtree(temp_folder)


if __name__ == "__main__":
main()
92 changes: 92 additions & 0 deletions super_resolution/register_labels_to_superres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
This script registers the labels from the nnunet folder to the super-resolved images.
It uses sct_register_multimodal -identity 1 from the spinal cord toolbox.

Input:
-labels: The directory containing the labels to register.
-images: The directory containing the super-resolved images.
-output: The directory to save the registered labels.
-min-idx: Minimum index of the files to process (inclusive).
-max-idx: Maximum index of the files to process (exclusive).

Output:
None

Example:
python register_labels_to_superres.py --labels /path/to/labels --images /path/to/superres_images --output /path/to/output

Author: Pierre-Louis Benveniste
"""
import argparse
import os
from pathlib import Path
from tqdm import tqdm
import shutil


def parse_args():
parser = argparse.ArgumentParser(description="Register labels to super-resolved images.")
parser.add_argument("--labels", type=str, required=True, help="Directory containing the labels to register.")
parser.add_argument("--images", type=str, required=True, help="Directory containing the super-resolved images.")
parser.add_argument("--output", type=str, required=True, help="Directory to save the registered labels.")
parser.add_argument("--min-idx", type=int, default=0, help="Minimum index of the files to process (inclusive).")
parser.add_argument("--max-idx", type=int, default=None, help="Maximum index of the files to process (exclusive).")
return parser.parse_args()


def main():
# Parse arguments
args = parse_args()
labels_dir = args.labels
images_dir = args.images
output_dir = args.output

# Build the output_dir if it does not exist
os.makedirs(output_dir, exist_ok=True)

# List all .nii.gz files in the labels directory
label_files = list(Path(labels_dir).rglob("*.nii.gz"))
label_files = sorted(label_files)
label_files = [str(f) for f in label_files]

# Keep only file with file_00X.nii.gz with X between min_idx and max_idx
if args.min_idx is not None and args.max_idx is not None:
label_files = [f for f in label_files if int(f.split("_")[-1].replace(".nii.gz", "")) >= args.min_idx and int(f.split("_")[-1].replace(".nii.gz", "")) < args.max_idx]

# For each label file, find the corresponding super-resolved image and register
for label_file in tqdm(label_files):
# Extract the base name to find the corresponding image
image_name = label_file.split("/")[-1].replace(".nii.gz", "_0000.nii.gz")
image_file = Path(images_dir) / image_name

if not image_file.exists():
print(f"Warning: Corresponding image for {label_file} not found. Skipping.")
break

# Build the output label file path
output_label_file = Path(output_dir) / Path(label_file).name
if output_label_file.exists():
print(f"Warning: Output file {output_label_file} already exists. Skipping.")
continue

# Build a temporary folder to store intermediate results
label_number = str(label_file).split("_")[-1].replace(".nii.gz", "")
temp_folder = Path(output_dir) / f"temp_{label_number}"
os.makedirs(temp_folder, exist_ok=True)
output_temp_file = temp_folder / Path(label_file).name

# Run the registration command
command = f"sct_register_multimodal -i {label_file} -d {image_file} -o {output_temp_file} -identity 1 -x nn"
assert os.system(command) == 0, f"Error in registration for {label_file}"

# Move the registered label to the output directory
assert os.system(f"mv {output_temp_file} {output_label_file}") == 0, f"Error moving file {output_temp_file} to {output_label_file}"

# Clean up the temporary folder
shutil.rmtree(temp_folder)

return None


if __name__ == "__main__":
main()
78 changes: 78 additions & 0 deletions super_resolution/run_super_res_nnunet_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
This algorithm runs the super-resolution model on some preprocessed data.
The nnUNet data is either some train or test data in the nnUNet format.
We give as input a min and max index to be a able to run the super-resolution in parallel.
This script requires to activate the smorve conda environment.

Input:
- input_dir: The directory containing the nnUNet data.
- output_dir: The directory where the super-resolution results will be saved.
- min_index: The minimum index of the data to process.
- max_index: The maximum index of the data to process.

Output:
None

Example:
python run_super_res_nnunet_data.py --input_dir /path/to/nnUNet_data --output_dir /path/to/output --min_index 0 --max_index 100

Author: Pierre-Louis Benveniste
"""
import argparse
import os
from pathlib import Path
from tqdm import tqdm
import nibabel as nib


def parse_args():
parser = argparse.ArgumentParser(description="Run super-resolution on nnUNet data.")
parser.add_argument("--input-dir", type=str, required=True, help="Input directory containing nnUNet data.")
parser.add_argument("--output-dir", type=str, required=True, help="Output directory for super-resolution results.")
parser.add_argument("--min-index", type=int, required=True, help="Minimum index of the data to process.")
parser.add_argument("--max-index", type=int, required=True, help="Maximum index of the data to process.")
return parser.parse_args()


def main():
# Parse arguments
args = parse_args()
input_dir = args.input_dir
output_dir = args.output_dir

# Build the output_dir if it does not exist
os.makedirs(output_dir, exist_ok=True)

# List all .nii.gz files in the input directory
input_files = list(Path(input_dir).rglob("*.nii.gz"))
input_files = sorted(input_files)

# Images will have the following standard: name_XXX_0000.nii.gz with XXX being the index
# We will process files from min_index to max_index
input_files = [f for f in input_files if int(f.stem.split("_")[-2]) >= args.min_index and int(f.stem.split("_")[-2]) <= args.max_index]

# Run the super-resolution model on each file
for input_file in tqdm(input_files):
# If an image is isotropic we just copy it to the destination folder
img = nib.load(str(input_file))
resolution = img.header.get_zooms()[:3] # Get the first three dimensions (x, y, z)
# If all three are clode to 1e-2 then we consider it isotropic
if max(resolution) - min(resolution) < 1e-2:
# Then we just copy the file to the output directory
assert os.system(f"cp {str(input_file)} {output_dir}")==0
else:
# Build a temp folder
temp_dir = Path(output_dir) / f"temp_{input_file.stem.split('_')[-2]}"
os.makedirs(temp_dir, exist_ok=True)
os.system(f"run-smore --in-fpath {str(input_file)} --out-dir {temp_dir}")
# Then we copy the result to the output directory
temp_files = list(temp_dir.rglob("*.nii.gz"))
assert os.system(f"cp {str(temp_files[0])} {str(Path(output_dir) / temp_files[0].name.replace('_smore4', ''))}") == 0
# Remove the temp directory
assert os.system(f"rm -rf {temp_dir}") == 0

return None


if __name__ == "__main__":
main()