Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions src/scilpy/cli/scil_fodf_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from dipy.data import get_sphere
from dipy.direction.peaks import reshape_peaks_for_visualization

from scilpy.io.image import get_data_as_mask
from scilpy.io.image import get_data_as_mask, load_nifti_reorient, save_nifti_reorient
from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args,
add_processes_arg, add_verbose_arg,
assert_inputs_exist, assert_outputs_exist,
Expand Down Expand Up @@ -140,10 +140,9 @@ def main():
assert_headers_compatible(parser, args.in_fODF, args.mask)

# Loading
vol = nib.load(args.in_fODF)
vol, flip_vector = load_nifti_reorient(args.in_fODF, return_flip_vector=True)
data = vol.get_fdata(dtype=np.float32)
affine = vol.affine
mask = get_data_as_mask(nib.load(args.mask),
mask = get_data_as_mask(load_nifti_reorient(args.mask),
dtype=bool) if args.mask else None

sphere = get_sphere(name=args.sphere)
Expand All @@ -169,27 +168,28 @@ def main():
sphere, nbr_processes=args.nbr_processes)

# Save result
affine = vol.affine
if args.nufo:
nib.save(nib.Nifti1Image(nufo_map.astype(np.float32), affine),
args.nufo)
out_img = nib.Nifti1Image(nufo_map.astype(np.float32), affine)
save_nifti_reorient(out_img, flip_vector, args.nufo)

if args.afd_max:
nib.save(nib.Nifti1Image(afd_max.astype(np.float32), affine),
args.afd_max)
out_img = nib.Nifti1Image(afd_max.astype(np.float32), affine)
save_nifti_reorient(out_img, flip_vector, args.afd_max)

if args.afd_total:
# this is the analytical afd total
afd_tot = data[:, :, :, 0]
nib.save(nib.Nifti1Image(afd_tot.astype(np.float32), affine),
args.afd_total)
out_img = nib.Nifti1Image(afd_tot.astype(np.float32), affine)
save_nifti_reorient(out_img, flip_vector, args.afd_total)

if args.afd_sum:
nib.save(nib.Nifti1Image(afd_sum.astype(np.float32), affine),
args.afd_sum)
out_img = nib.Nifti1Image(afd_sum.astype(np.float32), affine)
save_nifti_reorient(out_img, flip_vector, args.afd_sum)

if args.rgb:
nib.save(nib.Nifti1Image(rgb_map.astype('uint8'), affine),
args.rgb)
out_img = nib.Nifti1Image(rgb_map.astype('uint8'), affine)
save_nifti_reorient(out_img, flip_vector, args.rgb)

if args.peaks or args.peak_values:
if not args.abs_peaks_and_values:
Expand All @@ -198,15 +198,17 @@ def main():
where=peak_values[..., 0, None] != 0)
peak_dirs[...] *= peak_values[..., :, None]
if args.peaks:
nib.save(nib.Nifti1Image(
out_img = nib.Nifti1Image(
reshape_peaks_for_visualization(peak_dirs),
affine), args.peaks)
affine)
save_nifti_reorient(out_img, flip_vector, args.peaks)
if args.peak_values:
nib.save(nib.Nifti1Image(peak_values, vol.affine),
args.peak_values)
out_img = nib.Nifti1Image(peak_values, vol.affine)
save_nifti_reorient(out_img, flip_vector, args.peak_values)

if args.peak_indices:
nib.save(nib.Nifti1Image(peak_indices, vol.affine), args.peak_indices)
out_img = nib.Nifti1Image(peak_indices, vol.affine)
save_nifti_reorient(out_img, flip_vector, args.peak_indices)


if __name__ == "__main__":
Expand Down
13 changes: 8 additions & 5 deletions src/scilpy/cli/scil_fodf_ssst.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from scilpy.gradients.bvec_bval_tools import (check_b0_threshold,
normalize_bvecs,
is_normalized_bvecs)
from scilpy.io.image import get_data_as_mask
from scilpy.io.image import get_data_as_mask, load_nifti_reorient, save_nifti_reorient
from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg,
add_processes_arg, add_sh_basis_args,
add_skip_b0_check_arg, add_verbose_arg,
Expand Down Expand Up @@ -79,12 +79,14 @@ def main():

# Loading data
full_frf = np.loadtxt(args.frf_file)
vol = nib.load(args.in_dwi)
vol, flip_vector = load_nifti_reorient(args.in_dwi, return_flip_vector=True)
data = vol.get_fdata(dtype=np.float32)

# Loading bvals and bvecs and flipping signs for RAS orientation
bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

# Checking mask
mask = get_data_as_mask(nib.load(args.mask),
mask = get_data_as_mask(load_nifti_reorient(args.mask),
dtype=bool) if args.mask else None

sh_order = args.sh_order
Expand Down Expand Up @@ -136,9 +138,10 @@ def main():
is_input_legacy=True,
is_output_legacy=is_legacy,
nbr_processes=args.nbr_processes)
nib.save(nib.Nifti1Image(shm_coeff.astype(np.float32),
out_img = nib.Nifti1Image(shm_coeff.astype(np.float32),
affine=vol.affine,
header=vol.header), args.out_fODF)
header=vol.header)
save_nifti_reorient(out_img, flip_vector, args.out_fODF)


if __name__ == "__main__":
Expand Down
20 changes: 13 additions & 7 deletions src/scilpy/cli/scil_tracking_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from dipy.tracking import utils as track_utils
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.stopping_criterion import BinaryStoppingCriterion
from scilpy.io.image import get_data_as_mask
from scilpy.io.image import get_data_as_mask, load_nifti_reorient
from scilpy.io.utils import (add_sphere_arg, add_verbose_arg,
assert_headers_compatible, assert_inputs_exist,
assert_outputs_exist, parse_sh_basis_arg,
Expand Down Expand Up @@ -189,14 +189,19 @@ def main():
# when providing information to dipy (i.e. working as if in voxel space)
# will not yield correct results. Tracking is performed in voxel space
# in both the GPU and CPU cases.
odf_sh_img = nib.load(args.in_odf)
# odf_sh_img = nib.load(args.in_odf)
odf_sh_img, flip_vector = load_nifti_reorient(args.in_odf, return_flip_vector=True)
odf_data = odf_sh_img.get_fdata(dtype=np.float32)

if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]),
odf_sh_img.header.get_zooms()[0], atol=1e-03):
parser.error(
'ODF SH file is not isotropic. Tracking cannot be ran robustly.')

logging.debug("Loading masks and finding seeds.")
mask_data = get_data_as_mask(nib.load(args.in_mask), dtype=bool)
# mask_img = nib.load(args.in_mask)
mask_img = load_nifti_reorient(args.in_mask)
mask_data = get_data_as_mask(mask_img, dtype=bool)

if args.npv:
nb_seeds = args.npv
Expand All @@ -210,7 +215,8 @@ def main():

voxel_size = odf_sh_img.header.get_zooms()[0]
vox_step_size = args.step_size / voxel_size
seed_img = nib.load(args.in_seed)
# seed_img = nib.load(args.in_seed)
seed_img = load_nifti_reorient(args.in_seed)

sh_basis, is_legacy = parse_sh_basis_arg(args)

Expand Down Expand Up @@ -241,9 +247,9 @@ def main():
logging.info("Starting CPU local tracking.")
streamlines_generator = LocalTracking(
get_direction_getter(
args.in_odf, args.algo, args.sphere,
odf_data, args.algo, args.sphere,
args.sub_sphere, args.theta, sh_basis,
voxel_size, args.sf_threshold, args.sh_to_pmf,
args.sf_threshold, args.sh_to_pmf,
args.probe_length, args.probe_radius,
args.probe_quality, args.probe_count,
args.support_exponent, is_legacy=is_legacy),
Expand Down Expand Up @@ -284,7 +290,7 @@ def main():
save_tractogram(streamlines_generator, tracts_format,
odf_sh_img, total_nb_seeds, args.out_tractogram,
args.min_length, args.max_length, args.compress_th,
args.save_seeds, args.verbose)
args.save_seeds, args.verbose, flip_vector=flip_vector)
# Final logging
logging.info('Saved tractogram to {0}.'.format(args.out_tractogram))

Expand Down
16 changes: 8 additions & 8 deletions src/scilpy/cli/scil_viz_fodf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
assert_outputs_exist,
parse_sh_basis_arg,
assert_headers_compatible)
from scilpy.io.image import assert_same_resolution, get_data_as_mask
from scilpy.io.image import assert_same_resolution, get_data_as_mask, load_nifti_reorient
from scilpy.utils.spatial import RAS_AXES_NAMES
from scilpy.version import version_string
from scilpy.viz.backends.fury import (create_interactive_window,
Expand Down Expand Up @@ -220,7 +220,7 @@ def _get_data_from_inputs(args):
Load data given by args. Perform checks to ensure dimensions agree
between the data for mask, background, peaks and fODF.
"""
fodf = nib.load(args.in_fodf).get_fdata(dtype=np.float32)
fodf = load_nifti_reorient(args.in_fodf).get_fdata(dtype=np.float32)

# Optional:
bg = None
Expand All @@ -231,16 +231,16 @@ def _get_data_from_inputs(args):
variance = None
if args.background:
assert_same_resolution([args.background, args.in_fodf])
bg = nib.load(args.background).get_fdata()
bg = load_nifti_reorient(args.background).get_fdata()
if args.in_transparency_mask:
transparency_mask = get_data_as_mask(
nib.load(args.in_transparency_mask), dtype=bool)
load_nifti_reorient(args.in_transparency_mask), dtype=bool)
if args.mask:
assert_same_resolution([args.mask, args.in_fodf])
mask = get_data_as_mask(nib.load(args.mask), dtype=bool)
mask = get_data_as_mask(load_nifti_reorient(args.mask), dtype=bool)
if args.peaks:
assert_same_resolution([args.peaks, args.in_fodf])
peaks = nib.load(args.peaks).get_fdata()
peaks = load_nifti_reorient(args.peaks).get_fdata()
if len(peaks.shape) == 4:
last_dim = peaks.shape[-1]
if last_dim % 3 == 0:
Expand All @@ -253,10 +253,10 @@ def _get_data_from_inputs(args):
if args.peaks_values:
assert_same_resolution([args.peaks_values, args.in_fodf])
peak_vals =\
nib.load(args.peaks_values).get_fdata()
load_nifti_reorient(args.peaks_values).get_fdata()
if args.variance:
assert_same_resolution([args.variance, args.in_fodf])
variance = nib.load(args.variance).get_fdata(dtype=np.float32)
variance = load_nifti_reorient(args.variance).get_fdata(dtype=np.float32)
if len(variance.shape) == 3:
variance = np.reshape(variance, variance.shape + (1,))
if variance.shape != fodf.shape:
Expand Down
30 changes: 30 additions & 0 deletions src/scilpy/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,36 @@
from scilpy.utils import is_float


def load_nifti_reorient(file_path, return_flip_vector=False):
vol = nib.load(file_path)

# Compute the image orientation (axis codes)
axcodes = nib.orientations.aff2axcodes(vol.affine)
target = ('R', 'A', 'S')
flip_vector = [1 if axcodes[i] == target[i] else -1 for i in range(3)]

ras_order = [[i, flip_vector[i]] for i in range(0, 3)]
filename = vol.get_filename()
vol = vol.as_reoriented(ras_order)
vol.set_filename(filename)

if return_flip_vector:
return vol, flip_vector
return vol


def nifti_reorient(img, flip_vector):
original_order = [[i, flip_vector[i]] for i in range(0, 3)]
img = img.as_reoriented(original_order)
return img


def save_nifti_reorient(img, flip_vector, file_path):
original_order = [[i, flip_vector[i]] for i in range(0, 3)]
img = img.as_reoriented(original_order)
nib.save(img, file_path)


def load_img(arg):
if is_float(arg):
img = float(arg)
Expand Down
18 changes: 11 additions & 7 deletions src/scilpy/tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dipy.io.utils import create_tractogram_header, get_reference_info
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.tracking.streamlinespeed import compress_streamlines, length
from scilpy.io.image import nifti_reorient
from scilpy.io.utils import (add_compression_arg, add_overwrite_arg,
add_sh_basis_args)
from scilpy.reconst.utils import find_order_from_nb_coeff, get_maximas
Expand Down Expand Up @@ -179,7 +180,8 @@ def tqdm_if_verbose(generator: Iterable, verbose: bool, *args, **kwargs):

def save_tractogram(
streamlines_generator, tracts_format, ref_img, total_nb_seeds,
out_tractogram, min_length, max_length, compress, save_seeds, verbose
out_tractogram, min_length, max_length, compress, save_seeds,
verbose, flip_vector=None,
):
""" Save the streamlines on-the-fly using a generator. Tracts are
filtered according to their length and compressed if requested. Seeds
Expand Down Expand Up @@ -209,8 +211,12 @@ def save_tractogram(
data_per_streamline property.
verbose : bool
If True, display progression bar.
flip_vector : np.ndarray, shape (3,)
If provided, the reference image will be flipped according to this vector.
"""
if flip_vector is not None:
ref_img = nifti_reorient(ref_img, flip_vector)

voxel_size = ref_img.header.get_zooms()[0]

Expand Down Expand Up @@ -265,16 +271,16 @@ def tracks_generator_wrapper():
nib.streamlines.save(tractogram, out_tractogram, header=header)


def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis,
voxel_size, sf_threshold, sh_to_pmf,
def get_direction_getter(img_data, algo, sphere, sub_sphere, theta, sh_basis,
sf_threshold, sh_to_pmf,
probe_length, probe_radius, probe_quality,
probe_count, support_exponent, is_legacy=True):
""" Return the direction getter object.
Parameters
----------
in_img: str
Path to the input odf file.
img_data: np.ndarray
The input odf data as a numpy array (float32).
algo: str
Algorithm to use for tracking. Can be 'det', 'prob', 'ptt' or 'eudx'.
sphere: str
Expand Down Expand Up @@ -319,8 +325,6 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis,
dg: dipy.direction.DirectionGetter
The direction getter object.
"""
img_data = nib.load(in_img).get_fdata(dtype=np.float32)

sphere = HemiSphere.from_sphere(
get_sphere(name=sphere)).subdivide(n=sub_sphere)

Expand Down