|
| 1 | +import logging |
| 2 | + |
| 3 | +from dipy.core.gradients import (gradient_table, |
| 4 | + unique_bvals_tolerance, get_bval_indices) |
| 5 | +from dipy.io.gradients import read_bvals_bvecs |
| 6 | +import nibabel as nib |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from scilpy.utils.bvec_bval_tools import (normalize_bvecs, is_normalized_bvecs, |
| 10 | + extract_dwi_shell) |
| 11 | + |
| 12 | + |
# Mapping from b_delta (tensor anisotropy of the encoding) to the
# conventional b-tensor shape label:
#   1 -> LTE (linear), 0 -> STE (spherical),
#   0.5 -> CTE (cigar/prolate), -0.5 -> PTE (planar).
bshapes = {1: "LTE", 0.5: "CTE", 0: "STE", -0.5: "PTE"}
| 14 | + |
| 15 | + |
def _concat_or_init(acc, new, axis=0):
    """Concatenate `new` onto accumulator `acc`, or start a fresh array.

    The accumulators below start as empty 1-d placeholders, which
    np.concatenate cannot combine with n-d data (e.g. 4-d shells or
    (N, 3) bvecs), so the first real contribution replaces the
    placeholder instead of being concatenated to it.
    """
    if acc.size:
        return np.concatenate([acc, new], axis=axis)
    return np.asarray(new)


def generate_btensor_input(in_dwis, in_bvals, in_bvecs,
                           in_bdeltas, force_b0_threshold,
                           do_pa_signals=False, tol=20):
    """Generate b-tensor input from an ensemble of data, bvals and bvecs files.
    This generated input is mandatory for all scripts using b-tensor encoding
    data. Also generate the powder-averaged (PA) data if set.

    Parameters
    ----------
    in_dwis : list of strings (file paths)
        Diffusion data files for each b-tensor encodings.
    in_bvals : list of strings (file paths)
        All of the bvals files associated.
    in_bvecs : list of strings (file paths)
        All of the bvecs files associated.
    in_bdeltas : list of floats
        All of the b_deltas (describing the type of encoding) files associated.
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.
        NOTE(review): currently unused inside this function; kept for
        interface compatibility.
    do_pa_signals : bool, optional
        If set, will compute the powder_averaged input instead of the regular
        one. This means that the signal is averaged over all directions for
        each bvals.
    tol : int
        tolerance gap for b-values clustering. Defaults to 20

    Returns
    -------
    gtab_full : GradientTable
        A single gradient table containing the information of all encodings.
    data_full : np.ndarray (4d)
        All concatenated diffusion data from the different encodings.
    ubvals_full : array
        All the unique bvals from the different encodings, but with a single
        b0. If two or more encodings have the same bvalue, then they are
        differentiate by +1.
    ub_deltas_full : array
        All the b_delta values associated with `ubvals_full`.
    pa_signals : np.ndarray (4d) (if `do_pa_signals`)
        Powder-averaged diffusion data.
    gtab_infos : np.ndarray (if `do_pa_signals`)
        Contains information about the gtab, such as the unique bvals, the
        encoding types, the number of directions and the acquisition index.
    """
    # Accumulators; _concat_or_init swaps these empty placeholders for the
    # first real contribution of each kind.
    data_full = np.empty(0)
    bvals_full = np.empty(0)
    bvecs_full = np.empty(0)
    b_shapes = np.empty(0)
    ubvals_full = np.empty(0)
    ub_deltas_full = np.empty(0)
    nb_bvecs_full = np.empty(0)
    acq_index_full = np.empty(0)
    ubvals_divide = np.empty(0)
    acq_index = 0
    for inputf, bvalsf, bvecsf, b_delta in zip(in_dwis, in_bvals,
                                               in_bvecs, in_bdeltas):
        if inputf:  # verifies if the input file exists
            vol = nib.load(inputf)
            bvals, bvecs = read_bvals_bvecs(bvalsf, bvecsf)
            # Round bvals only when the acquisition has non-b0 volumes.
            if np.any(bvals > tol):
                bvals = np.round(bvals)
            if not is_normalized_bvecs(bvecs):
                logging.warning('Your b-vectors do not seem normalized...')
                bvecs = normalize_bvecs(bvecs)
            ubvals = unique_bvals_tolerance(bvals, tol=tol)
            for ubval in ubvals:  # Loop over all unique bvals
                # Extracting the data for the ubval shell
                indices, shell_data, _, output_bvecs = \
                    extract_dwi_shell(vol, bvals, bvecs, [ubval], tol=tol)
                nb_bvecs = len(indices)
                # Adding the current data to each array of interest
                acq_index_full = _concat_or_init(acq_index_full, [acq_index])
                ubvals_divide = _concat_or_init(ubvals_divide, [ubval])
                # Differentiate bvals shared between encodings by bumping +1
                # until the value is unique across all encodings seen so far.
                while np.isin(ubval, ubvals_full):
                    ubval += 1
                ubvals_full = _concat_or_init(ubvals_full, [ubval])
                ub_deltas_full = _concat_or_init(ub_deltas_full, [b_delta])
                nb_bvecs_full = _concat_or_init(nb_bvecs_full, [nb_bvecs])
                data_full = _concat_or_init(data_full, shell_data, axis=-1)
                bvals_full = _concat_or_init(bvals_full,
                                             np.repeat([ubval], nb_bvecs))
                bvecs_full = _concat_or_init(bvecs_full, output_bvecs)
                b_shapes = _concat_or_init(b_shapes,
                                           np.repeat([bshapes[b_delta]],
                                                     nb_bvecs))
            acq_index += 1
    # In the case that the PA data is wanted, there is a different return
    if do_pa_signals:
        # One powder-averaged volume per (differentiated) unique bval.
        pa_signals = np.zeros(data_full.shape[:-1] + (len(ubvals_full),))
        for i, ubval in enumerate(ubvals_full):
            indices = get_bval_indices(bvals_full, ubval, tol=0)
            pa_signals[..., i] = np.nanmean(data_full[..., indices], axis=-1)
        gtab_infos = np.empty((4, len(ubvals_full)))
        gtab_infos[0] = ubvals_divide
        gtab_infos[1] = ub_deltas_full
        gtab_infos[2] = nb_bvecs_full
        gtab_infos[3] = acq_index_full
        # Zero the acquisition indices when fewer than acq_index - 1 shells
        # are b0-like — presumably they carry no usable information then;
        # verify against callers.
        if np.sum(ubvals_full < tol) < acq_index - 1:
            gtab_infos[3] *= 0
        return pa_signals, gtab_infos
    # Removing the duplicate b0s from ubvals_full: keep the smallest bval
    # plus every bval above the b0 tolerance.
    duplicate_b0_ind = np.union1d(np.argwhere(ubvals_full == min(ubvals_full)),
                                  np.argwhere(ubvals_full > tol))
    ubvals_full = ubvals_full[duplicate_b0_ind]
    ub_deltas_full = ub_deltas_full[duplicate_b0_ind]
    # Sorting the data by bvals
    sorted_indices = np.argsort(bvals_full, axis=0)
    bvals_full = np.take_along_axis(bvals_full, sorted_indices, axis=0)
    # Collapse all b0-like bvals onto the single kept b0 value.
    bvals_full[bvals_full < tol] = min(ubvals_full)
    bvecs_full = np.take_along_axis(bvecs_full,
                                    sorted_indices.reshape(len(bvals_full), 1),
                                    axis=0)
    b_shapes = np.take_along_axis(b_shapes, sorted_indices, axis=0)
    data_full = np.take_along_axis(data_full,
                                   sorted_indices.reshape(1, 1, 1,
                                                          len(bvals_full)),
                                   axis=-1)
    # Sorting the ubvals
    sorted_indices = np.argsort(np.asarray(ubvals_full), axis=0)
    ubvals_full = np.take_along_axis(np.asarray(ubvals_full), sorted_indices,
                                     axis=0)
    ub_deltas_full = np.take_along_axis(np.asarray(ub_deltas_full),
                                        sorted_indices, axis=0)
    # Creating the corresponding gtab
    gtab_full = gradient_table(bvals_full, bvecs_full,
                               b0_threshold=bvals_full.min(),
                               btens=b_shapes)

    return gtab_full, data_full, ubvals_full, ub_deltas_full
0 commit comments