Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions pipe_random_affine_transform_dataset.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/bin/bash
#
# Apply random affine transformations to all the T2w volumes of a dataset following BIDS convention
#
# Usage:
#   sct_run_batch -jobs 1 -path-data bids_dataset -path-out res_registration -script pipe_random_affine_transform_dataset.sh
#
# NOTE: sct_run_batch invokes this script once per subject, passing the subject
# folder as the first argument; PATH_DATA and PATH_DATA_PROCESSED are assumed to
# be exported by sct_run_batch before this script runs.
#
# Author: Evan Béal

# Echo each command to the log before executing it
set -x
# Immediately exit if error
set -e -o pipefail

# Exit if user presses CTRL+C (Linux, OSX)
trap "echo Caught Keyboard Interrupt within script. Exiting now.; exit" INT

# Retrieve input params
SUBJECT=$1

# Save script path
# NOTE(review): this records the *current working directory*, assumed to contain
# random_affine_transform_dataset.py — confirm sct_run_batch launches this script
# from the folder that holds the Python helper.
PATH_SCRIPT=$PWD

# get starting time:
start=`date +%s`

# SCRIPT STARTS HERE
# ==============================================================================
# Display useful info for the log, such as SCT version, RAM and CPU cores available
sct_check_dependencies -short

# SUBJECT_ID=$(dirname "$SUBJECT")
# Last path component of the subject folder, used as the BIDS filename prefix
# (<SES>_T2w.nii.gz)
SES=$(basename "$SUBJECT")

# Choose whether to keep original naming and location of input volumes for the transformed volumes.
# 1 = overwrite <SES>_T2w.nii.gz with the transformed volume; any other value keeps
# the original and the *_aff_transformed volume side by side.
KEEP_ORI_NAMING_LOC=1

# Go to folder where data will be copied and processed
cd ${PATH_DATA_PROCESSED}
# Copy source images
mkdir -p ${SUBJECT}
rsync -avzh $PATH_DATA/$SUBJECT/ ${SUBJECT}
# Go to anat folder where all structural data are located
echo $PWD
cd ${SUBJECT}/anat/

# Base name (without the .nii.gz extension) of the T2w volume to transform
file_mov_before_aff_transfo="${SES}_T2w"

# Activate the conda environment providing the Python dependencies (numpy, nibabel, scipy)
CONDA_BASE=$(conda info --base)
source $CONDA_BASE/etc/profile.d/conda.sh
conda activate smenv
# Transform the volume
python $PATH_SCRIPT/random_affine_transform_dataset.py --mov-img-path $file_mov_before_aff_transfo --sub-id ${SES} --out-file $PATH_DATA_PROCESSED/summary_transform.csv
conda deactivate

# Output name produced by random_affine_transform_dataset.py
file_mov_transformed="${file_mov_before_aff_transfo}_aff_transformed"

# Optionally replace the original volume by its transformed version
if [ $KEEP_ORI_NAMING_LOC == 1 ]
then
rm -rf "${PATH_DATA_PROCESSED}/${SUBJECT}/anat/${file_mov_before_aff_transfo}.nii.gz"
mv "${PATH_DATA_PROCESSED}/${SUBJECT}/anat/${file_mov_transformed}.nii.gz" "${PATH_DATA_PROCESSED}/${SUBJECT}/anat/${file_mov_before_aff_transfo}.nii.gz"
fi

# Display useful info for the log
end=`date +%s`
runtime=$((end-start))
echo
echo "~~~"
echo "SCT version: `sct_version`"
echo "Ran on: `uname -nsr`"
echo "Duration: $(($runtime / 3600))hrs $((($runtime / 60) % 60))min $(($runtime % 60))sec"
echo "~~~"
72 changes: 72 additions & 0 deletions pipe_random_deformable_transform_dataset.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/bin/bash
#
# Apply random non-linear transformations to all the T2w volumes of a dataset following BIDS convention
#
# Usage:
#   sct_run_batch -jobs 1 -path-data bids_dataset -path-out res_registration -script pipe_random_deformable_transform_dataset.sh
#
# NOTE: sct_run_batch invokes this script once per subject, passing the subject
# folder as the first argument; PATH_DATA and PATH_DATA_PROCESSED are assumed to
# be exported by sct_run_batch before this script runs.
#
# Author: Evan Béal

# Echo each command to the log before executing it
set -x
# Immediately exit if error
set -e -o pipefail

# Exit if user presses CTRL+C (Linux, OSX)
trap "echo Caught Keyboard Interrupt within script. Exiting now.; exit" INT

# Retrieve input params
SUBJECT=$1

# Save script path
# NOTE(review): this records the *current working directory*, assumed to contain
# random_deformable_transform_dataset.py — confirm sct_run_batch launches this
# script from the folder that holds the Python helper.
PATH_SCRIPT=$PWD

# get starting time:
start=`date +%s`

# SCRIPT STARTS HERE
# ==============================================================================
# Display useful info for the log, such as SCT version, RAM and CPU cores available
sct_check_dependencies -short

# SUBJECT_ID=$(dirname "$SUBJECT")
# Last path component of the subject folder, used as the BIDS filename prefix
# (<SES>_T2w.nii.gz)
SES=$(basename "$SUBJECT")

# Choose whether to keep original naming and location of input volumes for the transformed volumes.
# 1 = overwrite <SES>_T2w.nii.gz with the transformed volume; any other value keeps
# the original and the *_def_transformed volume side by side.
KEEP_ORI_NAMING_LOC=1

# Go to folder where data will be copied and processed
cd ${PATH_DATA_PROCESSED}
# Copy source images
mkdir -p ${SUBJECT}
rsync -avzh $PATH_DATA/$SUBJECT/ ${SUBJECT}
# Go to anat folder where all structural data are located
echo $PWD
cd ${SUBJECT}/anat/

# Base name (without the .nii.gz extension) of the T2w volume to transform
file_mov_before_def_transfo="${SES}_T2w"

# Activate the conda environment providing the Python dependencies (neurite, voxelmorph, tensorflow)
CONDA_BASE=$(conda info --base)
source $CONDA_BASE/etc/profile.d/conda.sh
conda activate smenv
# Transform the volume
python $PATH_SCRIPT/random_deformable_transform_dataset.py --mov-img-path $file_mov_before_def_transfo --sub-id ${SES} --out-file $PATH_DATA_PROCESSED/summary_transform.csv
conda deactivate

# Output name produced by random_deformable_transform_dataset.py
file_mov_transformed="${file_mov_before_def_transfo}_def_transformed"

# Optionally replace the original volume by its transformed version
if [ $KEEP_ORI_NAMING_LOC == 1 ]
then
rm -rf "${PATH_DATA_PROCESSED}/${SUBJECT}/anat/${file_mov_before_def_transfo}.nii.gz"
mv "${PATH_DATA_PROCESSED}/${SUBJECT}/anat/${file_mov_transformed}.nii.gz" "${PATH_DATA_PROCESSED}/${SUBJECT}/anat/${file_mov_before_def_transfo}.nii.gz"
fi

# Display useful info for the log
end=`date +%s`
runtime=$((end-start))
echo
echo "~~~"
echo "SCT version: `sct_version`"
echo "Ran on: `uname -nsr`"
echo "Duration: $(($runtime / 3600))hrs $((($runtime / 60) % 60))min $(($runtime % 60))sec"
echo "~~~"
124 changes: 124 additions & 0 deletions random_affine_transform_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
File to randomly affine transform the volumes
"""

import os
import argparse

import numpy as np
import nibabel as nib

from scipy.ndimage.interpolation import affine_transform
import csv
import datetime


def random_affine_transform(im, sub_id, out_file):
    """Apply a random affine transform (rotation, isotropic scaling, translation) to a volume.

    Parameters are drawn uniformly at random:
      * rotation: angle in [-5, 5] degrees within a randomly chosen plane,
      * scaling: isotropic factor in [0.95, 1.05],
      * translation: up to 5% of the volume size along each axis (whole voxels).

    The drawn parameters are appended as one row to the CSV summary at
    ``out_file`` (the header row is written on first use).

    Parameters
    ----------
    im : np.ndarray
        3D volume to transform.
    sub_id : str
        Subject identifier recorded in the CSV summary.
    out_file : str
        Path to the CSV file summarizing the transforms applied.

    Returns
    -------
    np.ndarray
        Transformed volume with the same shape and dtype as ``im``.
    """
    import math
    import random

    # Draw the rotation: random angle and a random plane of rotation
    angle_degree = 5  # max absolute rotation, in degrees
    angle_d = np.random.uniform(-angle_degree, angle_degree)
    angle = math.radians(angle_d)
    # Two sorted axes define the plane of rotation (matched against the cases below)
    axes = sorted(random.sample(range(3), 2))

    # Draw the isotropic scaling factor
    scale_factor = 0.05
    scale_axis = random.uniform(1 - scale_factor, 1 + scale_factor)

    # Draw the translation: up to `translation` fraction of the shape per axis,
    # rounded to whole voxels
    data_shape = im.shape
    translation = [0.05, 0.05, 0.05]
    max_dx = translation[0] * data_shape[0]
    max_dy = translation[1] * data_shape[1]
    max_dz = translation[2] * data_shape[2]
    translations = (np.round(np.random.uniform(-max_dx, max_dx)),
                    np.round(np.random.uniform(-max_dy, max_dy)),
                    np.round(np.random.uniform(-max_dz, max_dz)))

    # Build the rotation matrix for the selected plane
    shape = 0.5 * np.array(data_shape)  # volume center, used to rotate about the middle
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    if axes == [0, 1]:
        rotate = np.array([[cos_a, -sin_a, 0],
                           [sin_a, cos_a, 0],
                           [0, 0, 1]])
    elif axes == [0, 2]:
        rotate = np.array([[cos_a, 0, sin_a],
                           [0, 1, 0],
                           [-sin_a, 0, cos_a]])
    elif axes == [1, 2]:
        rotate = np.array([[1, 0, 0],
                           [0, cos_a, -sin_a],
                           [0, sin_a, cos_a]])
    else:
        raise ValueError("Unknown axes value")

    # Inverse scaling is baked into the matrix: affine_transform maps output
    # coordinates to input coordinates
    scale = np.array([[1 / scale_axis, 0, 0], [0, 1 / scale_axis, 0], [0, 0, 1 / scale_axis]])
    transforms = scale.dot(rotate)

    # Offset chosen so the rotation/scaling is centered on the volume, then translated
    offset = shape - shape.dot(transforms) + translations

    data_out = affine_transform(im, transforms.T, order=1, offset=offset,
                                output_shape=data_shape).astype(im.dtype)

    # Applied parameters, listed in the same order as the CSV header below
    summary_transfo = {
        'subject': sub_id,
        'rotation_angle_degree': angle_d,
        'rotation_axes': axes,
        'scaling': scale_axis,
        'translation': translations,
        'im_shape': im.shape,
    }

    # Write the header once. Files handed to the csv module must be opened with
    # newline='' (otherwise spurious blank rows appear on Windows).
    if not os.path.isfile(out_file):
        with open(out_file, 'w', newline='') as csvfile:
            header = ['Timestamp', 'Subject', 'rotation_angle_degree', 'rotation_axes', 'scaling', 'translation', 'im_shape']
            csv.writer(csvfile).writerow(header)

    # Append one row: timestamp first, then the parameters in header order
    with open(out_file, 'a', newline='') as csvfile:
        row = [datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")]  # Timestamp
        row.extend(str(value) for value in summary_transfo.values())
        csv.writer(csvfile, delimiter=',').writerow(row)

    return data_out


def run_main(mov_im_path, sub_id, out_file):
    """Load the moving volume, apply a random affine transform, and save the result.

    The transformed volume is written next to the input as
    ``<mov_im_path>_aff_transformed.nii.gz``, reusing the input's affine matrix.
    """
    src_nii = nib.load(f'{mov_im_path}.nii.gz')

    # Transform the voxel data; the applied parameters are logged to out_file
    transformed_data = random_affine_transform(src_nii.get_fdata(), sub_id, out_file)

    dst_nii = nib.Nifti1Image(transformed_data, src_nii.affine)
    nib.save(dst_nii, f'{mov_im_path}_aff_transformed.nii.gz')


if __name__ == "__main__":

    # Build the command-line interface
    cli_parser = argparse.ArgumentParser()

    # parameters to be specified by the user
    cli_parser.add_argument('--mov-img-path', required=True, help='path to the moving image')
    cli_parser.add_argument('--sub-id', required=True, help='Subject ID')
    cli_parser.add_argument('--out-file', required=False, default='summary_transform.csv',
                            help='path to the output csv summarizing the affine transform applied')

    cli_args = cli_parser.parse_args()

    # Run the transformation pipeline on the requested volume
    run_main(cli_args.mov_img_path, cli_args.sub_id, cli_args.out_file)

92 changes: 92 additions & 0 deletions random_deformable_transform_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
File to randomly transform the volumes using warping field generated from noise distribution
"""

import os
import argparse

import numpy as np
import nibabel as nib
import neurite as ne
import voxelmorph as vxm
import tensorflow.keras.backend as K

import csv
import datetime

if __name__ == "__main__":

    # parse the commandline
    parser = argparse.ArgumentParser()

    # parameters to be specified by the user
    parser.add_argument('--mov-img-path', required=True, help='path to the moving image')

    parser.add_argument('--sub-id', required=True, help='Subject ID')
    # Help text fixed: it previously said "affine" (copy/paste from the affine
    # script), but this script applies a deformable (non-linear) transform.
    parser.add_argument('--out-file', required=False, default='summary_transform.csv',
                        help='path to the output csv summarizing the deformable transform applied')

    args = parser.parse_args()

    # -------------------------------------------------------------------------------------------------------- #
    # ---- LOADING THE VOLUME AND GETTING THE ASSOCIATED AFFINE MATRIX ---- #
    # -------------------------------------------------------------------------------------------------------- #

    im = nib.load(f"{args.mov_img_path}.nii.gz")
    affine = im.affine

    # -------------------------------------------------------------------------------------------------------- #
    # ---- GENERATING THE DEFORMATION FIELD ---- #
    # -------------------------------------------------------------------------------------------------------- #

    # Draw two random Perlin-noise displacement fields: a coarse one (scale 16) and
    # a finer multi-scale one (scales 32 and 64). The second return value (named
    # std/std2, logged to the CSV below) is presumably the noise std drawn per
    # scale — TODO confirm against neurite's draw_perlin documentation.
    def_field, std = ne.utils.augment.draw_perlin(out_shape=(im.shape[0], im.shape[1], im.shape[2], 1, 3),
                                                  scales=[16], max_std=2.5)

    def_field2, std2 = ne.utils.augment.draw_perlin(out_shape=(im.shape[0], im.shape[1], im.shape[2], 1, 3),
                                                    scales=[32, 64], max_std=5)

    # Compose the two fields into a single warp and save it with the input's affine
    warp = vxm.utils.compose([K.constant(def_field[..., 0, :]), K.constant(def_field2[..., 0, :])])
    warp_data = K.eval(warp)
    def_field_nii = nib.Nifti1Image(np.array(warp_data), affine=affine)

    out_def_path = f"{args.mov_img_path}_warp_to_transform.nii.gz"
    nib.save(def_field_nii, out_def_path)

    # -------------------------------------------------------------------------------------------------------- #
    # ---- APPLYING THE DEFORMATION FIELD TO THE IMAGE TO PRODUCE THE MOVED IMAGE ---- #
    # -------------------------------------------------------------------------------------------------------- #

    moving = vxm.py.utils.load_volfile(f"{args.mov_img_path}.nii.gz", add_batch_axis=True, add_feat_axis=True)
    deform = vxm.py.utils.load_volfile(out_def_path, add_batch_axis=True, ret_affine=True)

    moved = vxm.networks.Transform(moving.shape[1:-1],
                                   interp_method='linear',
                                   nb_feats=moving.shape[-1]).predict([moving, deform[0]])

    # save moved image
    out_im_path = f"{args.mov_img_path}_def_transformed.nii.gz"
    vxm.py.utils.save_volfile(moved.squeeze(), out_im_path, affine)

    # The intermediate warp file is deliberately kept on disk; uncomment to delete it
    # os.remove(out_def_path)

    # Drawn noise parameters, listed in the same order as the CSV header below
    summary_transfo = dict()
    summary_transfo['subject'] = args.sub_id
    summary_transfo['std_for_scale_16'] = std[0]
    summary_transfo['std_for_scale_32'] = std2[0]
    summary_transfo['std_for_scale_64'] = std2[1]

    # Write the header once. Files handed to the csv module must be opened with
    # newline='' (otherwise spurious blank rows appear on Windows).
    if not os.path.isfile(args.out_file):
        with open(args.out_file, 'w', newline='') as csvfile:
            header = ['Timestamp', 'Subject', 'std_for_scale_16', 'std_for_scale_32', 'std_for_scale_64']
            writer = csv.DictWriter(csvfile, fieldnames=header)
            writer.writeheader()

    # populate data: timestamp first, then the values in header order
    with open(args.out_file, 'a', newline='') as csvfile:
        spamwriter = csv.writer(csvfile, delimiter=',')
        line = list()
        line.append(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))  # Timestamp
        for val in summary_transfo.keys():
            line.append(str(summary_transfo[val]))
        spamwriter.writerow(line)