diff --git a/src/scilpy/cli/scil_fodf_metrics.py b/src/scilpy/cli/scil_fodf_metrics.py index 92df1b578..2684c473c 100755 --- a/src/scilpy/cli/scil_fodf_metrics.py +++ b/src/scilpy/cli/scil_fodf_metrics.py @@ -39,7 +39,7 @@ from dipy.data import get_sphere from dipy.direction.peaks import reshape_peaks_for_visualization -from scilpy.io.image import get_data_as_mask +from scilpy.io.image import get_data_as_mask, load_nifti_reorient, save_nifti_reorient from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, @@ -138,10 +138,9 @@ def main(): assert_headers_compatible(parser, args.in_fODF, args.mask) # Loading - vol = nib.load(args.in_fODF) + vol, flip_vector = load_nifti_reorient(args.in_fODF, return_flip_vector=True) data = vol.get_fdata(dtype=np.float32) - affine = vol.affine - mask = get_data_as_mask(nib.load(args.mask), + mask = get_data_as_mask(load_nifti_reorient(args.mask), dtype=bool) if args.mask else None sphere = get_sphere(name=args.sphere) @@ -167,27 +166,28 @@ def main(): sphere, nbr_processes=args.nbr_processes) # Save result + affine = vol.affine if args.nufo: - nib.save(nib.Nifti1Image(nufo_map.astype(np.float32), affine), - args.nufo) + out_img = nib.Nifti1Image(nufo_map.astype(np.float32), affine) + save_nifti_reorient(out_img, flip_vector, args.nufo) if args.afd_max: - nib.save(nib.Nifti1Image(afd_max.astype(np.float32), affine), - args.afd_max) + out_img = nib.Nifti1Image(afd_max.astype(np.float32), affine) + save_nifti_reorient(out_img, flip_vector, args.afd_max) if args.afd_total: # this is the analytical afd total afd_tot = data[:, :, :, 0] - nib.save(nib.Nifti1Image(afd_tot.astype(np.float32), affine), - args.afd_total) + out_img = nib.Nifti1Image(afd_tot.astype(np.float32), affine) + save_nifti_reorient(out_img, flip_vector, args.afd_total) if args.afd_sum: - nib.save(nib.Nifti1Image(afd_sum.astype(np.float32), affine), - args.afd_sum) + out_img = nib.Nifti1Image(afd_sum.astype(np.float32), affine) + save_nifti_reorient(out_img, flip_vector, args.afd_sum) if args.rgb: - nib.save(nib.Nifti1Image(rgb_map.astype('uint8'), affine), - args.rgb) + out_img = nib.Nifti1Image(rgb_map.astype('uint8'), affine) + save_nifti_reorient(out_img, flip_vector, args.rgb) if args.peaks or args.peak_values: if not args.abs_peaks_and_values: @@ -196,15 +196,17 @@ def main(): where=peak_values[..., 0, None] != 0) peak_dirs[...] *= peak_values[..., :, None] if args.peaks: - nib.save(nib.Nifti1Image( + out_img = nib.Nifti1Image( reshape_peaks_for_visualization(peak_dirs), - affine), args.peaks) + affine) + save_nifti_reorient(out_img, flip_vector, args.peaks) if args.peak_values: - nib.save(nib.Nifti1Image(peak_values, vol.affine), - args.peak_values) + out_img = nib.Nifti1Image(peak_values, vol.affine) + save_nifti_reorient(out_img, flip_vector, args.peak_values) if args.peak_indices: - nib.save(nib.Nifti1Image(peak_indices, vol.affine), args.peak_indices) + out_img = nib.Nifti1Image(peak_indices, vol.affine) + save_nifti_reorient(out_img, flip_vector, args.peak_indices) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_fodf_ssst.py b/src/scilpy/cli/scil_fodf_ssst.py index 73a420e08..7618f03fa 100755 --- a/src/scilpy/cli/scil_fodf_ssst.py +++ b/src/scilpy/cli/scil_fodf_ssst.py @@ -20,7 +20,7 @@ from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, normalize_bvecs, is_normalized_bvecs) -from scilpy.io.image import get_data_as_mask +from scilpy.io.image import get_data_as_mask, load_nifti_reorient, save_nifti_reorient from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_processes_arg, add_sh_basis_args, add_skip_b0_check_arg, add_verbose_arg, @@ -77,12 +77,14 @@ def main(): # Loading data full_frf = np.loadtxt(args.frf_file) - vol = nib.load(args.in_dwi) + vol, flip_vector = load_nifti_reorient(args.in_dwi, return_flip_vector=True) data = vol.get_fdata(dtype=np.float32) + + # Loading bvals and bvecs and flipping signs for RAS orientation bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) # Checking mask - mask = get_data_as_mask(nib.load(args.mask), + mask = get_data_as_mask(load_nifti_reorient(args.mask), dtype=bool) if args.mask else None sh_order = args.sh_order @@ -134,9 +136,10 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(shm_coeff.astype(np.float32), + out_img = nib.Nifti1Image(shm_coeff.astype(np.float32), affine=vol.affine, - header=vol.header), args.out_fODF) + header=vol.header) + save_nifti_reorient(out_img, flip_vector, args.out_fODF) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_tracking_local.py b/src/scilpy/cli/scil_tracking_local.py index fed1ab1d9..b0f95835e 100755 --- a/src/scilpy/cli/scil_tracking_local.py +++ b/src/scilpy/cli/scil_tracking_local.py @@ -66,7 +66,7 @@ from dipy.tracking import utils as track_utils from dipy.tracking.local_tracking import LocalTracking from dipy.tracking.stopping_criterion import BinaryStoppingCriterion -from scilpy.io.image import get_data_as_mask +from scilpy.io.image import get_data_as_mask, load_nifti_reorient from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_headers_compatible, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, @@ -187,14 +187,19 @@ def main(): # when providing information to dipy (i.e. working as if in voxel space) # will not yield correct results. Tracking is performed in voxel space # in both the GPU and CPU cases. - odf_sh_img = nib.load(args.in_odf) + # odf_sh_img = nib.load(args.in_odf) + odf_sh_img, flip_vector = load_nifti_reorient(args.in_odf, return_flip_vector=True) + odf_data = odf_sh_img.get_fdata(dtype=np.float32) + if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]), odf_sh_img.header.get_zooms()[0], atol=1e-03): parser.error( 'ODF SH file is not isotropic. Tracking cannot be ran robustly.') logging.debug("Loading masks and finding seeds.") - mask_data = get_data_as_mask(nib.load(args.in_mask), dtype=bool) + # mask_img = nib.load(args.in_mask) + mask_img = load_nifti_reorient(args.in_mask) + mask_data = get_data_as_mask(mask_img, dtype=bool) if args.npv: nb_seeds = args.npv @@ -208,7 +213,8 @@ def main(): voxel_size = odf_sh_img.header.get_zooms()[0] vox_step_size = args.step_size / voxel_size - seed_img = nib.load(args.in_seed) + # seed_img = nib.load(args.in_seed) + seed_img = load_nifti_reorient(args.in_seed) sh_basis, is_legacy = parse_sh_basis_arg(args) @@ -239,9 +245,9 @@ def main(): logging.info("Starting CPU local tracking.") streamlines_generator = LocalTracking( get_direction_getter( - args.in_odf, args.algo, args.sphere, + odf_data, args.algo, args.sphere, args.sub_sphere, args.theta, sh_basis, - voxel_size, args.sf_threshold, args.sh_to_pmf, + args.sf_threshold, args.sh_to_pmf, args.probe_length, args.probe_radius, args.probe_quality, args.probe_count, args.support_exponent, is_legacy=is_legacy), @@ -282,7 +288,7 @@ def main(): save_tractogram(streamlines_generator, tracts_format, odf_sh_img, total_nb_seeds, args.out_tractogram, args.min_length, args.max_length, args.compress_th, - args.save_seeds, args.verbose) + args.save_seeds, args.verbose, flip_vector=flip_vector) # Final logging logging.info('Saved tractogram to {0}.'.format(args.out_tractogram)) diff --git a/src/scilpy/cli/scil_viz_fodf.py b/src/scilpy/cli/scil_viz_fodf.py index 753eae950..9a4ca9323 100755 --- a/src/scilpy/cli/scil_viz_fodf.py +++ b/src/scilpy/cli/scil_viz_fodf.py @@ -35,7 +35,7 @@ assert_outputs_exist, parse_sh_basis_arg, assert_headers_compatible) -from scilpy.io.image import assert_same_resolution, get_data_as_mask +from scilpy.io.image import assert_same_resolution, get_data_as_mask, load_nifti_reorient from scilpy.utils.spatial import RAS_AXES_NAMES from scilpy.version import version_string from scilpy.viz.backends.fury import (create_interactive_window, @@ -220,7 +220,7 @@ def _get_data_from_inputs(args): Load data given by args. Perform checks to ensure dimensions agree between the data for mask, background, peaks and fODF. """ - fodf = nib.load(args.in_fodf).get_fdata(dtype=np.float32) + fodf = load_nifti_reorient(args.in_fodf).get_fdata(dtype=np.float32) # Optional: bg = None @@ -231,16 +231,16 @@ def _get_data_from_inputs(args): variance = None if args.background: assert_same_resolution([args.background, args.in_fodf]) - bg = nib.load(args.background).get_fdata() + bg = load_nifti_reorient(args.background).get_fdata() if args.in_transparency_mask: transparency_mask = get_data_as_mask( - nib.load(args.in_transparency_mask), dtype=bool) + load_nifti_reorient(args.in_transparency_mask), dtype=bool) if args.mask: assert_same_resolution([args.mask, args.in_fodf]) - mask = get_data_as_mask(nib.load(args.mask), dtype=bool) + mask = get_data_as_mask(load_nifti_reorient(args.mask), dtype=bool) if args.peaks: assert_same_resolution([args.peaks, args.in_fodf]) - peaks = nib.load(args.peaks).get_fdata() + peaks = load_nifti_reorient(args.peaks).get_fdata() if len(peaks.shape) == 4: last_dim = peaks.shape[-1] if last_dim % 3 == 0: @@ -253,10 +253,10 @@ def _get_data_from_inputs(args): if args.peaks_values: assert_same_resolution([args.peaks_values, args.in_fodf]) peak_vals =\ - nib.load(args.peaks_values).get_fdata() + load_nifti_reorient(args.peaks_values).get_fdata() if args.variance: assert_same_resolution([args.variance, args.in_fodf]) - variance = nib.load(args.variance).get_fdata(dtype=np.float32) + variance = load_nifti_reorient(args.variance).get_fdata(dtype=np.float32) if len(variance.shape) == 3: variance = np.reshape(variance, variance.shape + (1,)) if variance.shape != fodf.shape: diff --git a/src/scilpy/io/image.py b/src/scilpy/io/image.py index 64072bdb7..c36934392 100644 --- a/src/scilpy/io/image.py +++ b/src/scilpy/io/image.py @@ -9,6 +9,36 @@ from scilpy.utils import is_float +def load_nifti_reorient(file_path, return_flip_vector=False): + vol = nib.load(file_path) + + # Compute the image orientation (axis codes) + axcodes = nib.orientations.aff2axcodes(vol.affine) + target = ('R', 'A', 'S') + flip_vector = [1 if axcodes[i] == target[i] else -1 for i in range(3)] + + ras_order = [[i, flip_vector[i]] for i in range(0, 3)] + filename = vol.get_filename() + vol = vol.as_reoriented(ras_order) + vol.set_filename(filename) + + if return_flip_vector: + return vol, flip_vector + return vol + + +def nifti_reorient(img, flip_vector): + original_order = [[i, flip_vector[i]] for i in range(0, 3)] + img = img.as_reoriented(original_order) + return img + + +def save_nifti_reorient(img, flip_vector, file_path): + original_order = [[i, flip_vector[i]] for i in range(0, 3)] + img = img.as_reoriented(original_order) + nib.save(img, file_path) + + def load_img(arg): """ Function to create the variable for scil_volume_math main function. diff --git a/src/scilpy/tracking/utils.py b/src/scilpy/tracking/utils.py index 551d80959..02afc2e5b 100644 --- a/src/scilpy/tracking/utils.py +++ b/src/scilpy/tracking/utils.py @@ -16,6 +16,7 @@ from dipy.io.utils import create_tractogram_header, get_reference_info from dipy.reconst.shm import sh_to_sf_matrix from dipy.tracking.streamlinespeed import compress_streamlines, length +from scilpy.io.image import nifti_reorient from scilpy.io.utils import (add_compression_arg, add_overwrite_arg, add_sh_basis_args) from scilpy.reconst.utils import find_order_from_nb_coeff, get_maximas @@ -179,7 +180,8 @@ def tqdm_if_verbose(generator: Iterable, verbose: bool, *args, **kwargs): def save_tractogram( streamlines_generator, tracts_format, ref_img, total_nb_seeds, - out_tractogram, min_length, max_length, compress, save_seeds, verbose + out_tractogram, min_length, max_length, compress, save_seeds, + verbose, flip_vector=None, ): """ Save the streamlines on-the-fly using a generator. Tracts are filtered according to their length and compressed if requested. Seeds @@ -209,8 +211,12 @@ def save_tractogram( data_per_streamline property. verbose : bool If True, display progression bar. + flip_vector : np.ndarray, shape (3,) + If provided, the reference image will be flipped according to this vector. """ + if flip_vector is not None: + ref_img = nifti_reorient(ref_img, flip_vector) voxel_size = ref_img.header.get_zooms()[0] @@ -265,16 +271,16 @@ def tracks_generator_wrapper(): nib.streamlines.save(tractogram, out_tractogram, header=header) -def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, - voxel_size, sf_threshold, sh_to_pmf, +def get_direction_getter(img_data, algo, sphere, sub_sphere, theta, sh_basis, + sf_threshold, sh_to_pmf, probe_length, probe_radius, probe_quality, probe_count, support_exponent, is_legacy=True): """ Return the direction getter object. Parameters ---------- - in_img: str - Path to the input odf file. + img_data: np.ndarray + The input odf data as a numpy array (float32). algo: str Algorithm to use for tracking. Can be 'det', 'prob', 'ptt' or 'eudx'. sphere: str @@ -319,8 +325,6 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, dg: dipy.direction.DirectionGetter The direction getter object. """ - img_data = nib.load(in_img).get_fdata(dtype=np.float32) - sphere = HemiSphere.from_sphere( get_sphere(name=sphere)).subdivide(n=sub_sphere)