Skip to content

Commit 5845eb2

Browse files
authored
Merge pull request #425 from karanphil/memsmt_csd
[WIP] Adding b-tensor encoding scripts like compute_fodf/frf and metrics
2 parents 9eea119 + 09fa0c0 commit 5845eb2

15 files changed

+1869
-36
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
Instructions for tensor-valued dMRI scripts (b-tensor)
2+
======================================================
3+
4+
5+
The scripts for multi-encoding multi-shell multi-tissue CSD (memsmt-CSD) are based on P. Karan et al., Bridging the gap between constrained spherical deconvolution and diffusional variance decomposition via tensor-valued diffusion MRI. Medical Image Analysis (2022). We recommend reading it to understand the scope of the memsmt-CSD problem.
6+
7+
If you want to do CSD with b-tensor data, you should start by computing the fiber response functions. This script should run fast (less than 5 minutes on a full brain).
8+
::
9+
10+
scil_compute_memsmt_frf.py wm_frf.txt gm_frf.txt csf_frf.txt --in_dwis dwi_linear.nii.gz dwi_planar.nii.gz dwi_spherical.nii.gz --in_bvals dwi_linear.bval dwi_planar.bval dwi_spherical.bval --in_bvecs dwi_linear.bvec dwi_planar.bvec dwi_spherical.bvec --in_bdeltas 1 -0.5 0 --mask mask.nii.gz --mask_wm wm_mask.nii.gz --mask_gm gm_mask.nii.gz --mask_csf csf_mask.nii.gz -f
11+
12+
Then, you should compute the fODFs and volume fractions. The following command will save an fODF file for each tissue and a volume fractions file. This script should run in about 1-2 hours for a full brain.
13+
::
14+
15+
scil_compute_memsmt_fodf.py wm_frf.txt gm_frf.txt csf_frf.txt --in_dwis dwi_linear.nii.gz dwi_planar.nii.gz dwi_spherical.nii.gz --in_bvals dwi_linear.bval dwi_planar.bval dwi_spherical.bval --in_bvecs dwi_linear.bvec dwi_planar.bvec dwi_spherical.bvec --in_bdeltas 1 -0.5 0 --mask mask.nii.gz --processes 8 -f
16+
17+
If you want to do DIVIDE with b-tensor data, you should use the following command. It will save files for the MD, uFA, OP, MK_I, MK_A and MK_T. This script should run in about 1-2 hours for a full brain.
18+
::
19+
20+
scil_compute_divide.py --in_dwis dwi_linear.nii.gz dwi_planar.nii.gz dwi_spherical.nii.gz --in_bvals dwi_linear.bval dwi_planar.bval dwi_spherical.bval --in_bvecs dwi_linear.bvec dwi_planar.bvec dwi_spherical.bvec --in_bdeltas 1 -0.5 0 --mask mask.nii.gz --fa fa.nii.gz --processes 8 -f

scilpy/image/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,25 @@ def volume_iterator(img, blocksize=1, start=0, end=0):
7676
yield list(range(stop, end)), img.dataobj[..., stop:end]
7777

7878

79+
def extract_affine(input_files):
    """Extract the affine from a list of nifti files.

    The affine of the first non-empty entry in the list is returned;
    entries that are empty/falsy are skipped.

    Parameters
    ----------
    input_files : list of strings (file paths)
        Diffusion data files.

    Returns
    -------
    affine : np.ndarray
        Affine of the nifti volume.
    """
    for path in input_files:
        if not path:
            # Skip empty placeholders in the list.
            continue
        return nib.load(path).affine
96+
97+
7998
def check_slice_indices(vol_img, axis_name, slice_ids):
8099
"""Check that the given volume can be sliced at the given slice indices
81100
along the requested axis.

scilpy/io/fetcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def get_testing_files_dict():
106106
'anatomical_filtering.zip':
107107
['1Li8DdySnMnO9Gich4pilhXisjkjz1-Dy',
108108
'6f0eff5154ff0973a3dc26db00e383ea'],
109+
'btensor_testdata.zip':
110+
['1AMsKlbOZyPnT9TAbxcFzHS1b29aJWKDg',
111+
'7c68524fca01268203dc8bfee340f037'],
109112
'fodf_filtering.zip':
110113
['1iyoX2ltLOoLer-v-49LHOzopHCFZ_Tv6',
111114
'e79c4291af584fdb25814aa7b403a6ce']}

scilpy/reconst/b_tensor_utils.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import logging
2+
3+
from dipy.core.gradients import (gradient_table,
4+
unique_bvals_tolerance, get_bval_indices)
5+
from dipy.io.gradients import read_bvals_bvecs
6+
import nibabel as nib
7+
import numpy as np
8+
9+
from scilpy.utils.bvec_bval_tools import (normalize_bvecs, is_normalized_bvecs,
10+
extract_dwi_shell)
11+
12+
13+
bshapes = {0: "STE", 1: "LTE", -0.5: "PTE", 0.5: "CTE"}
14+
15+
16+
def generate_btensor_input(in_dwis, in_bvals, in_bvecs,
                           in_bdeltas, force_b0_threshold,
                           do_pa_signals=False, tol=20):
    """Generate b-tensor input from an ensemble of data, bvals and bvecs files.

    This generated input is mandatory for all scripts using b-tensor encoding
    data. Also generates the powder-averaged (PA) data if set.

    Parameters
    ----------
    in_dwis : list of strings (file paths)
        Diffusion data files, one per b-tensor encoding.
    in_bvals : list of strings (file paths)
        All of the associated bvals files.
    in_bvecs : list of strings (file paths)
        All of the associated bvecs files.
    in_bdeltas : list of floats
        All of the associated b_delta values (describing the type of
        encoding).
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.
        NOTE(review): not referenced anywhere in this function body — confirm
        whether it is consumed downstream or can be removed.
    do_pa_signals : bool, optional
        If set, will compute the powder-averaged input instead of the regular
        one. This means that the signal is averaged over all directions for
        each bval.
    tol : int
        Tolerance gap for b-values clustering. Defaults to 20.

    Returns
    -------
    gtab_full : GradientTable
        A single gradient table containing the information of all encodings.
    data_full : np.ndarray (4d)
        All concatenated diffusion data from the different encodings.
    ubvals_full : array
        All the unique bvals from the different encodings, but with a single
        b0. If two or more encodings have the same bvalue, then they are
        differentiated by +1.
    ub_deltas_full : array
        All the b_delta values associated with `ubvals_full`.
    pa_signals : np.ndarray (4d) (if `do_pa_signals`)
        Powder-averaged diffusion data.
    gtab_infos : np.ndarray (if `do_pa_signals`)
        Contains information about the gtab, such as the unique bvals, the
        encoding types, the number of directions and the acquisition index.
    """
    # Accumulators across all encodings. They start empty and are grown with
    # the `np.concatenate(...) if arr.size else ...` pattern below, which
    # avoids concatenating into an empty float array (this matters for
    # `b_shapes`, whose entries are strings, and for `data_full`, which is 4d).
    data_full = np.empty(0)
    bvals_full = np.empty(0)
    bvecs_full = np.empty(0)
    b_shapes = np.empty(0)
    ubvals_full = np.empty(0)
    ub_deltas_full = np.empty(0)
    nb_bvecs_full = np.empty(0)
    acq_index_full = np.empty(0)
    ubvals_divide = np.empty(0)
    acq_index = 0
    for inputf, bvalsf, bvecsf, b_delta in zip(in_dwis, in_bvals,
                                               in_bvecs, in_bdeltas):
        if inputf:  # skip empty/falsy entries in the input list
            vol = nib.load(inputf)
            bvals, bvecs = read_bvals_bvecs(bvalsf, bvecsf)
            # Round bvals only when at least one is above the b0 tolerance.
            if np.sum([bvals > tol]) != 0:
                bvals = np.round(bvals)
            if not is_normalized_bvecs(bvecs):
                logging.warning('Your b-vectors do not seem normalized...')
                bvecs = normalize_bvecs(bvecs)
            ubvals = unique_bvals_tolerance(bvals, tol=tol)
            for ubval in ubvals:  # Loop over all unique bvals
                # Extracting the data for the ubval shell
                indices, shell_data, _, output_bvecs = \
                    extract_dwi_shell(vol, bvals, bvecs, [ubval], tol=tol)
                nb_bvecs = len(indices)
                # Adding the current data to each array of interest
                acq_index_full = np.concatenate([acq_index_full,
                                                 [acq_index]]) \
                    if acq_index_full.size else np.array([acq_index])
                # `ubvals_divide` keeps the ORIGINAL (pre-renumbering) bval.
                ubvals_divide = np.concatenate([ubvals_divide, [ubval]]) \
                    if ubvals_divide.size else np.array([ubval])
                # Differentiate bvals shared by several encodings by bumping
                # the duplicate by +1 until it is unique (see docstring).
                while np.isin(ubval, ubvals_full):
                    ubval += 1
                ubvals_full = np.concatenate([ubvals_full, [ubval]]) \
                    if ubvals_full.size else np.array([ubval])
                ub_deltas_full = np.concatenate([ub_deltas_full, [b_delta]]) \
                    if ub_deltas_full.size else np.array([b_delta])
                nb_bvecs_full = np.concatenate([nb_bvecs_full, [nb_bvecs]]) \
                    if nb_bvecs_full.size else np.array([nb_bvecs])
                data_full = np.concatenate([data_full, shell_data], axis=-1) \
                    if data_full.size else shell_data
                # Per-volume bvals use the (possibly renumbered) ubval so the
                # final gtab matches `ubvals_full`.
                bvals_full = np.concatenate([bvals_full,
                                             np.repeat([ubval], nb_bvecs)]) \
                    if bvals_full.size else np.repeat([ubval], nb_bvecs)
                bvecs_full = np.concatenate([bvecs_full, output_bvecs]) \
                    if bvecs_full.size else output_bvecs
                # One encoding-shape label (e.g. "LTE") per volume, used as
                # `btens` when building the gradient table below.
                b_shapes = np.concatenate([b_shapes,
                                           np.repeat([bshapes[b_delta]],
                                                     nb_bvecs)]) \
                    if b_shapes.size else np.repeat([bshapes[b_delta]],
                                                    nb_bvecs)
            acq_index += 1
    # In the case that the PA data is wanted, there is a different return
    if do_pa_signals:
        # One PA volume per (renumbered) unique bval: mean over directions.
        pa_signals = np.zeros(((data_full.shape[:-1])+(len(ubvals_full),)))
        for i, ubval in enumerate(ubvals_full):
            indices = get_bval_indices(bvals_full, ubval, tol=0)
            pa_signals[..., i] = np.nanmean(data_full[..., indices], axis=-1)
        # gtab_infos rows: 0 = original unique bvals, 1 = b_deltas,
        # 2 = number of directions per shell, 3 = acquisition index.
        gtab_infos = np.ndarray((4, len(ubvals_full)))
        gtab_infos[0] = ubvals_divide
        gtab_infos[1] = ub_deltas_full
        gtab_infos[2] = nb_bvecs_full
        gtab_infos[3] = acq_index_full
        # Zero out acquisition indices when fewer b0 shells than acquisitions
        # were found (presumably meaning b0s are not available per encoding —
        # TODO confirm intent with downstream DIVIDE fitting code).
        if np.sum([ubvals_full < tol]) < acq_index - 1:
            gtab_infos[3] *= 0
        return(pa_signals, gtab_infos)
    # Removing the duplicate b0s from ubvals_full: keep the smallest b0 plus
    # everything above the b0 tolerance.
    duplicate_b0_ind = np.union1d(np.argwhere(ubvals_full == min(ubvals_full)),
                                  np.argwhere(ubvals_full > tol))
    ubvals_full = ubvals_full[duplicate_b0_ind]
    ub_deltas_full = ub_deltas_full[duplicate_b0_ind]
    # Sorting the data by bvals
    sorted_indices = np.argsort(bvals_full, axis=0)
    bvals_full = np.take_along_axis(bvals_full, sorted_indices, axis=0)
    # Collapse all remaining sub-threshold bvals onto the single kept b0.
    bvals_full[bvals_full < tol] = min(ubvals_full)
    bvecs_full = np.take_along_axis(bvecs_full,
                                    sorted_indices.reshape(len(bvals_full), 1),
                                    axis=0)
    b_shapes = np.take_along_axis(b_shapes, sorted_indices, axis=0)
    # NOTE(review): the reshape below assumes 4d diffusion data (x, y, z, vol)
    # — confirm against callers.
    data_full = np.take_along_axis(data_full,
                                   sorted_indices.reshape(1, 1, 1,
                                                          len(bvals_full)),
                                   axis=-1)
    # Sorting the ubvals (and their b_deltas) the same way
    sorted_indices = np.argsort(np.asarray(ubvals_full), axis=0)
    ubvals_full = np.take_along_axis(np.asarray(ubvals_full), sorted_indices,
                                     axis=0)
    ub_deltas_full = np.take_along_axis(np.asarray(ub_deltas_full),
                                        sorted_indices, axis=0)
    # Creating the corresponding gtab, carrying the encoding shape of each
    # volume through `btens`.
    gtab_full = gradient_table(bvals_full, bvecs_full,
                               b0_threshold=bvals_full.min(),
                               btens=b_shapes)

    return(gtab_full, data_full, ubvals_full, ub_deltas_full)

0 commit comments

Comments
 (0)