63 changes: 63 additions & 0 deletions monai/transforms.py
@@ -2,10 +2,14 @@
import numpy as np
from typing import Dict, Hashable, Mapping
from scipy.ndimage.morphology import binary_erosion
import scipy.ndimage as ndi
import torch
import monai.transforms as transforms
from monai.config import KeysCollection
from monai.transforms import MapTransform
import torchio as tio

rs = np.random.RandomState()


class SpinalCordContourd(MapTransform):
@@ -59,6 +63,65 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
return data


def train_transforms_totalspineseg(crop_size, lbl_key="label", pad_mode="zero", device="cuda"):

transforms_monai = [
# pre-processing
transforms.LoadImaged(keys=["image", lbl_key]),
transforms.EnsureChannelFirstd(keys=["image", lbl_key]),
transforms.Orientationd(keys=["image", lbl_key], axcodes="RPI"),
# NOTE: interpolation with order=2 is spline, order=1 is linear
transforms.Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)),
transforms.ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,
mode="constant" if pad_mode == "zero" else pad_mode),
# convert the data to Tensor without meta, move to GPU and cache it to avoid CPU -> GPU sync in every epoch
transforms.EnsureTyped(keys=["image", lbl_key], device=device, track_meta=False),
# Contrast augmentation
transforms.RandLambdad(keys=["image"], func=lambda x: ndi.laplace(x)), # laplacian
transforms.RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.3), # this is monai's RandomGamma
transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
transforms.HistogramNormalized(keys=["image"], num_bins=256, min=0.0, max=1.0),
transforms.RandLambdad(keys=["image"], func=lambda x: torch.log(1 + x), prob=0.05), # log
transforms.RandLambdad(keys=["image"], func=lambda x: torch.sqrt(x), prob=0.05), # square root
transforms.RandLambdad(keys=["image"], func=lambda x: torch.exp(x), prob=0.05), # exponential
transforms.RandLambdad(keys=["image"], func=lambda x: torch.sin(x), prob=0.05), # sine
transforms.RandLambdad(keys=["image"], func=lambda x: 1/(1+torch.exp(-x)), prob=0.05), # sigmoid
# transforms.RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.1), # nnUNet's BrightnessMultiplicativeTransform
# transforms.RandGaussianSharpen(keys=["image"], prob=0.1),
]
# todo: add inverse color augmentation

# artifacts augmentation
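# NOTE: rs.rand()/rs.choice below are evaluated once, when this function builds the pipeline,
# so the artifact/spatial choices are fixed per Compose (per call), not re-drawn per sample;
# the selected tio transforms still sample their own parameters per image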
if rs.rand() < 0.7:
transforms_monai.append(rs.choice([
tio.RandomMotion(include=["image", lbl_key]),
tio.RandomGhosting(include=["image", lbl_key]),
tio.RandomSpike(intensity=(1,2), include=["image"]),
tio.RandomBiasField(include=["image"]),
tio.RandomBlur(include=["image"]),
]))
transforms_monai.append(transforms.RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1))
transforms_monai.append(transforms.RandGaussianSharpend(keys=["image"], prob=0.1))

# spatial augmentation
transforms_monai.append(tio.RandomFlip(axes=('LR'), flip_probability=0.3, include=["image", lbl_key]))

if rs.rand() < 0.7:
transforms_monai.append(rs.choice([
tio.RandomAffine(image_interpolation='bspline', label_interpolation='linear', include=["image", lbl_key]),
tio.RandomAffine(image_interpolation='linear', label_interpolation='nearest', include=["image", lbl_key]),
tio.RandomElasticDeformation(max_displacement=30, include=["image", lbl_key]),
]))

# simulate low resolution
if rs.rand() < 0.7:
transforms_monai.append(tio.RandomAnisotropy(downsampling=(1.5, 5), include=["image", lbl_key]))

transforms_monai.append(tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(0.5, 99.5), include=["image"]))

return transforms.Compose(transforms_monai)


def train_transforms(crop_size, lbl_key="label", pad_mode="zero", device="cuda"):

monai_transforms = [
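For context on how the new pipeline is meant to be consumed, here is a minimal usage sketch; the import path, file names, batch size, and crop size are illustrative assumptions, not part of this PR:

```python
# Minimal sketch of wiring the new transform pipeline into a training loader.
# All paths, the crop size, and the import path are placeholders.
from monai.data import DataLoader, Dataset

# assuming the function is importable from the module changed in this PR
from monai.transforms import train_transforms_totalspineseg

data_dicts = [
    {"image": "sub-001_T2w.nii.gz", "label": "sub-001_T2w_seg.nii.gz"},
]
train_tf = train_transforms_totalspineseg(crop_size=(64, 192, 320), lbl_key="label", device="cuda")
train_ds = Dataset(data=data_dicts, transform=train_tf)
# num_workers=0 because EnsureTyped moves tensors to the GPU inside the pipeline,
# which does not mix well with CUDA tensors created in DataLoader worker processes
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)

for batch in train_loader:
    images, labels = batch["image"], batch["label"]  # shape: (B, C, *crop_size)
```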
96 changes: 96 additions & 0 deletions scripts/create_difficult_subjects_list.py
@@ -0,0 +1,96 @@
"""
This script takes as input a dataset, a list of difficult subjects, and contrasts, and creates a folder containing them.
It does the following:
1. creates a folder "difficult-cases" in the output path and copies the subjects into it
2. outputs a YAML file with the list of subjects grouped by dataset
NOTE: this script keeps the BIDS folder structure intact so that it can be used with sct_run_batch

The idea is to build a dataset of difficult subjects for benchmarking our segmentation models (current ones and new ones we develop)

Author: Naga Karthik

"""

import os
import yaml
import argparse
import glob
import subprocess

def get_parser():

parser = argparse.ArgumentParser(description='Create a list of difficult subjects')
parser.add_argument('--dataset', type=str, required=True,
help='root path to the dataset containing the subjects')
parser.add_argument('--include', type=str, required=True, nargs='+',
help='list of difficult subjects to be included')
parser.add_argument('--contrasts', type=str, required=True, nargs='+',
help='list of contrasts to be copied for each subject. If "all" is provided, all files will be copied')
parser.add_argument('--path-out', type=str, required=True,
help='path to the output directory where the folder "difficult-cases" will be created')
return parser

def main():

args = get_parser().parse_args()

path_dataset = args.dataset
contrasts = args.contrasts
path_out = os.path.join(args.path_out, 'difficult-cases-temp')
if not os.path.exists(path_out):
print(f'Creating folder at: {path_out}')
os.makedirs(path_out, exist_ok=True)
else:
print(f'Folder already exists! Adding subjects to: {path_out}')

# check if a yaml file already exists and load it into the dictionary
if os.path.exists(os.path.join(path_out, 'difficult_cases.yaml')):
with open(os.path.join(path_out, 'difficult_cases.yaml'), 'r') as file:
difficult_cases_dict = yaml.load(file, Loader=yaml.FullLoader)
else:
difficult_cases_dict = {}

# loop through all subjects
for subject_id in args.include:

dataset = os.path.basename(path_dataset)
if dataset not in difficult_cases_dict:
difficult_cases_dict[dataset] = []

subject_path = os.path.join(path_dataset, subject_id)

for contrast in contrasts:

if contrast == 'all':
# find all image files for all contrasts
files = subprocess.run(f'find {subject_path} -name "*.nii.gz"', shell=True, capture_output=True, text=True).stdout.split('\n')
else:
# find all image files for the contrast
files = subprocess.run(f'find {subject_path} -name "*{contrast}*.nii.gz"', shell=True, capture_output=True, text=True).stdout.split('\n')

if len(files) == 1 and not files[0]:
print(f'No files found for {subject_id} contrast {contrast}')
else:
# get the relative path between subject_id and file
for file in files:
if file:
relative_path = os.path.relpath(file, subject_path)
rel_path_with_subject = os.path.join(subject_id, relative_path)
if rel_path_with_subject in difficult_cases_dict[dataset]:
print(f'Subject {subject_id} contrast {contrast} already exists in the difficult cases list')
continue
# copy the file to the output directory
output_path = os.path.join(path_out, rel_path_with_subject)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
print(f'Copying {file} to {output_path}')
subprocess.run(f'cp {file} {output_path}', shell=True)
difficult_cases_dict[dataset].append(rel_path_with_subject)

# save the difficult cases list
with open(os.path.join(path_out, 'difficult_cases.yaml'), 'w') as file:
yaml.dump(difficult_cases_dict, file)

if __name__ == '__main__':
main()
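A minimal invocation sketch for the new script; the dataset path, subject IDs, and contrast suffix below are placeholders:

```bash
python scripts/create_difficult_subjects_list.py \
    --dataset /path/to/data-multi-subject \
    --include sub-001 sub-002 \
    --contrasts T2w \
    --path-out /path/to/benchmarks
```

The script writes difficult_cases.yaml inside the output folder, mapping each dataset name (the basename of --dataset) to the subject-relative paths of the copied files.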