Skip to content

Commit 09c9c68

Browse files
authored
Merge pull request #1194 from AntoineTheb/atheb/bundleparc
ENH: BundleParc - continuous and distance-based parcellation
2 parents cb1e113 + 420cff6 commit 09c9c68

File tree

5 files changed

+242
-72
lines changed

5 files changed

+242
-72
lines changed

src/scilpy/cli/scil_fodf_bundleparc.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
This method takes as input fODF maps and outputs 71 bundle label maps. These maps can then be used to perform tractometry/tract profiling/radiomics. The bundle definitions follow TractSeg's minus the whole CC.
77
8-
Inputs are presumed to come from Tractoflow and must be BET and cropped. fODFs must be of basis descoteaux07 and can be of order < 8 but accuracy may be reduced.
8+
**IMPORTANT**: fODF inputs are presumed to come from Tractoflow, must have stride of -1,2,3,4 and must be BET and cropped. fODFs must be in SH format, basis descoteaux07_legacy and can be of order < 8 but accuracy may be reduced.
99
1010
Model weights will be downloaded the first time the script is run, which will require an internet connection at runtime. Otherwise they can be manually downloaded from zenodo [2] and by specifying --checkpoint.
1111
@@ -19,15 +19,16 @@
1919
2020
The default value of 50 for --min_blob_size was found empirically on adult brains at a resolution of 1mm^3. The best value for your dataset may differ.
2121
22-
This script requires a GPU with ~6GB of available memory. If you use half-precision (float16) inference, you may be able to run it with ~3GB of GPU memory available. Otherwise, install the CPU version of PyTorch.
22+
This script requires a GPU with ~6GB of available memory. If you use half-precision (float16) inference, you may be able to run it with ~3GB of GPU memory available. Otherwise, install the CPU version of PyTorch. Execution on MacOS is not supported for now.
2323
2424
Parts of the implementation are based on or lifted from:
2525
SAM-Med3D: https://github.com/uni-medical/SAM-Med3D
2626
Multidimensional Positional Encoding: https://github.com/tatp22/multidim-positional-encoding
2727
2828
To cite: Antoine Théberge, Zineb El Yamani, François Rheault, Maxime Descoteaux, Pierre-Marc Jodoin (2025). LabelSeg. ISMRM Workshop on 40 Years of Diffusion: Past, Present & Future Perspectives, Kyoto, Japan.
2929
30-
[1]: https://zenodo.org/records/15579498
30+
[1]: Descoteaux, M., Deriche, R., Knösche, T. R., & Anwander, A. (2007). Deterministic and probabilistic tractography based on complex fibre orientation distributions. IEEE Transactions on Medical Imaging, 26(11), 1464-1477.
31+
[2]: https://zenodo.org/records/15579498
3132
""" # noqa
3233

3334
import argparse
@@ -37,15 +38,18 @@
3738
import os
3839

3940
from argparse import RawTextHelpFormatter
41+
from functools import partial
4042

4143
from scilpy.io.utils import (
4244
assert_inputs_exist, assert_output_dirs_exist_and_empty,
4345
add_overwrite_arg, add_verbose_arg)
4446
from scilpy.image.volume_operations import resample_volume
4547

4648
from scilpy.ml.bundleparc.predict import predict
49+
from scilpy.ml.bundleparc.labels import post_process_labels_discrete, \
50+
post_process_labels_mm, post_process_labels_continuous
4751
from scilpy.ml.bundleparc.utils import DEFAULT_BUNDLES, \
48-
download_weights, get_model
52+
download_weights, get_model
4953
from scilpy.ml.utils import get_device, IMPORT_ERROR_MSG
5054
from scilpy import SCILPY_HOME
5155

@@ -62,20 +66,11 @@ def _build_arg_parser():
6266
formatter_class=RawTextHelpFormatter)
6367

6468
parser.add_argument('in_fodf',
65-
help='fODF input.')
69+
help='Input fODF volume in nifti format. ')
6670
parser.add_argument('--out_prefix', default='',
6771
help='Output file prefix. Default is nothing. ')
6872
parser.add_argument('--out_folder', default='bundleparc',
6973
help='Output destination. Default is [%(default)s].')
70-
parser.add_argument('--nb_pts', type=int, default=50,
71-
help='Number of divisions per bundle. '
72-
'Default is [%(default)s].')
73-
parser.add_argument('--min_blob_size', type=int, default=50,
74-
help='Minimum blob size (in voxels) to keep. Smaller '
75-
'blobs will be removed. Default is '
76-
'[%(default)s].')
77-
parser.add_argument('--keep_biggest_blob', action='store_true',
78-
help='If set, only keep the biggest blob predicted.')
7974
parser.add_argument('--half_precision', action='store_true',
8075
help='Use half precision (float16) for inference. '
8176
'This reduces memory usage but may lead to '
@@ -88,6 +83,24 @@ def _build_arg_parser():
8883
'and weights of model. Default is '
8984
'[%(default)s]. If the file does not exist, it '
9085
'will be downloaded.')
86+
parcel_group = parser.add_mutually_exclusive_group()
87+
parcel_group.add_argument('--nb_pts', type=int, default=10,
88+
help='Number of divisions per bundle. Default is'
89+
' [%(default)s].')
90+
parcel_group.add_argument('--mm', type=float,
91+
help='If set, bundles will be split in sections '
92+
'roughly X mm wide.')
93+
parcel_group.add_argument('--continuous', action='store_true',
94+
help='If set, the output label maps will be '
95+
'continuous ∈ [0, 1].')
96+
blob_group = parser.add_mutually_exclusive_group()
97+
blob_group.add_argument('--min_blob_size', type=int, default=50,
98+
help='Minimum blob size (in voxels) to keep. '
99+
'Smaller blobs will be removed. Default is '
100+
'[%(default)s].')
101+
blob_group.add_argument('--keep_biggest_blob', action='store_true',
102+
help='Only keep the biggest blob predicted.')
103+
91104
add_overwrite_arg(parser)
92105
add_verbose_arg(parser)
93106

@@ -99,7 +112,7 @@ def main():
99112
parser = _build_arg_parser()
100113
args = parser.parse_args()
101114

102-
assert_inputs_exist(parser, [args.in_fodf])
115+
assert_inputs_exist(parser, [args.in_fodf], [])
103116
assert_output_dirs_exist_and_empty(parser, args, args.out_folder,
104117
create_dir=True)
105118

@@ -135,19 +148,33 @@ def main():
135148
interp='lin',
136149
enforce_dimensions=False)
137150

151+
# Get the voxel size of the input fODF after resampling
152+
# Presuming isotropic resampling
153+
voxel_size = np.mean(resampled_img.header.get_zooms()[:3])
154+
155+
# Get the label function to use for post-processing
156+
if args.continuous:
157+
label_function = post_process_labels_continuous
158+
elif args.mm is not None:
159+
label_function = partial(post_process_labels_mm, args.mm, voxel_size)
160+
else:
161+
label_function = partial(post_process_labels_discrete,
162+
args.nb_pts)
163+
138164
# Predict label maps. `predict` is a generator
139165
# yielding one label map per bundle and its name.
140166
for y_hat_label, b_name in predict(
141-
model, resampled_img.get_fdata(dtype=np.float32), n_coefs, args.nb_pts,
142-
args.bundles, args.min_blob_size, args.keep_biggest_blob,
143-
args.half_precision, logging.getLogger().getEffectiveLevel() <
144-
logging.WARNING
167+
model, resampled_img.get_fdata(dtype=np.float32), n_coefs,
168+
label_function, DEFAULT_BUNDLES, args.keep_biggest_blob,
169+
args.half_precision,
170+
logging.getLogger().getEffectiveLevel() < logging.WARNING
145171
):
146172

147173
# Format the output as a nifti image
148174
label_img = nib.Nifti1Image(y_hat_label,
149175
resampled_img.affine,
150-
resampled_img.header, dtype=np.uint16)
176+
resampled_img.header,
177+
dtype=y_hat_label.dtype)
151178

152179
# Resampling volume to fit the original image size
153180
resampled_label = resample_volume(label_img, ref_img=None,

src/scilpy/cli/tests/test_fodf_bundleparc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,23 @@ def test_execution_invalid_bundle(script_runner, monkeypatch):
5353
ret = script_runner.run(['scil_fodf_bundleparc', in_fodf,
5454
'-f', '--bundles', 'CC'])
5555
assert not ret.success
56+
57+
58+
@pytest.mark.ml
def test_execution_mm(script_runner, monkeypatch):
    """Run BundleParc with --mm sectioning on one bundle; must succeed."""
    in_fodf = os.path.join(SCILPY_HOME, 'tracking', 'fodf.nii.gz')

    # Use the list-argument form and the console entry point (no .py
    # suffix) for consistency with the other tests in this file.
    ret = script_runner.run(['scil_fodf_bundleparc', in_fodf,
                             '--mm', '10',
                             '--bundles', 'IFO_right', '-f'])
    assert ret.success
66+
67+
68+
@pytest.mark.ml
def test_execution_cont(script_runner, monkeypatch):
    """Run BundleParc with --continuous label maps on one bundle;
    must succeed."""
    in_fodf = os.path.join(SCILPY_HOME, 'tracking', 'fodf.nii.gz')

    # Use the list-argument form and the console entry point (no .py
    # suffix) for consistency with the other tests in this file.
    ret = script_runner.run(['scil_fodf_bundleparc', in_fodf,
                             '--continuous',
                             '--bundles', 'IFO_right', '-f'])
    assert ret.success

src/scilpy/ml/bundleparc/labels.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import logging
2+
import numpy as np
3+
4+
from dipy.tracking.streamline import length, set_number_of_points
5+
6+
from scilpy.tractograms.streamline_operations import smooth_line_gaussian
7+
8+
9+
def post_process_labels_discrete(
    nb_labels, bundle_label, bundle_mask, bundle_name
):
    """ Mask and discretize the predicted labels of a bundle. Continuous
    values are mapped uniformly onto integers in [1, nb_labels]; voxels
    outside the bundle mask are set to 0.

    Parameters
    ----------
    nb_labels : int
        Number of labels to discretize to.
    bundle_label : np.ndarray
        Predicted continuous labels for the bundle.
    bundle_mask : np.ndarray
        Binary mask of the bundle.
    bundle_name : str
        Name of the bundle, used for logging.

    Returns
    -------
    bundle_label : np.ndarray
        Predicted labels for the bundle.
    """

    in_bundle = bundle_mask.astype(bool)

    # scilpy/MI-Brain convention: uint8 for binary masks,
    # uint16 for multi-label maps.
    dtype = np.uint8 if nb_labels <= 1 else np.uint16

    # Bin the continuous labels of the in-mask voxels into
    # 1..nb_labels, then zero out everything outside the mask.
    bundle_label[in_bundle] = np.ceil(bundle_label[in_bundle] * nb_labels)
    bundle_label[np.logical_not(in_bundle)] = 0

    return bundle_label.astype(dtype)
45+
46+
47+
def post_process_labels_mm(
    labels_mm, voxel_size, bundle_label, bundle_mask, bundle_name
):
    """ Discretize the labels and apply a mask to the bundle. Labels are
    discretized to integers so that each section is roughly `labels_mm` mm
    long. To do so, the barycenter of each label is computed to form a
    centroid streamline. Then, the centroid is resampled to have a number of
    points such that the step-size is roughly `labels_mm` mm. Finally, the
    labels are reassigned to the closest point in the resampled centroid.

    Parameters
    ----------
    labels_mm : float
        Length of each section in mm.
    voxel_size : float or np.ndarray
        Voxel size of the bundle image. The CLI passes the scalar mean of
        the header zooms (isotropic assumption).
    bundle_label : np.ndarray
        Predicted continuous labels for the bundle.
    bundle_mask : np.ndarray
        Binary mask of the bundle.
    bundle_name : str
        Name of the bundle, used for logging.

    Returns
    -------
    bundle_label : np.ndarray
        Predicted labels for the bundle.
    """

    # Label masking: zero out predictions outside the bundle mask.
    bundle_label[~bundle_mask.astype(bool)] = 0

    # Build a fine reference parcellation to anchor the barycenters.
    # NOTE(review): 50 appears to be the model's native number of
    # divisions — TODO confirm against the model definition.
    ref_labels = np.ceil(bundle_label * 50)
    unique = np.unique(ref_labels)

    # Get the 3D coordinates (voxel space) of the barycenter of each
    # label; unique[0] is the background label 0 and is skipped.
    barycenters = np.zeros((len(unique) - 1, 3), dtype=np.float32)
    for i, label in enumerate(unique[1:]):
        coords = np.argwhere(ref_labels == label)
        barycenters[i] = np.mean(coords, axis=0)

    # Form the barycenters into a single streamline and smooth it to
    # reduce the jaggedness of per-label barycenters.
    centroid = np.asarray(barycenters)
    centroid = smooth_line_gaussian(centroid, 5)

    # Resampling: compute the centroid length in mm (voxel coordinates
    # scaled by the voxel size).
    c_length = length(centroid * voxel_size)
    # Calculate the number of points to resample to so that consecutive
    # points are roughly `labels_mm` apart.
    nb_points = np.round(c_length / labels_mm).astype(int)
    if nb_points < 2:
        logging.warning(f"{bundle_name} is shorter than the section length.")
        nb_points = 2

    # Resample the centroid to have `nb_points` points.
    # Adding 2 points so they can be excluded from the labels. Sort of a
    # reverse signpost problem. Otherwise, the first and last labels
    # and no other would be assigned to the first and last point of the
    # centroid.
    resampled_centroid = set_number_of_points(centroid, nb_points + 2)

    # Re-discretizing the labels based on the resampled centroid
    discrete_labels = np.zeros_like(bundle_label, dtype=np.float32)
    for i, label in enumerate(unique[1:]):
        # Find the closest label in the resampled centroid
        c = centroid[i]
        # Calculate the distances from the centroid to the resampled
        # centroid. Exclude the first and last points of the resampled
        # centroid (see above).
        distances = np.linalg.norm(
            c - resampled_centroid[None, 1:-1], axis=-1)
        # Get the index of the closest label
        closest_index = np.argmin(distances)
        # Assign the label to the closest index in the resampled
        # centroid (+1 keeps 0 reserved for background).
        discrete_labels[ref_labels == label] = closest_index + 1

    # Determine the output type based on the number of labels.
    # NOTE(review): nb_points is clamped to >= 2 above, so this always
    # selects uint16 — the uint8 branch is unreachable.
    out_type = np.uint16 if nb_points > 1 else np.uint8

    return discrete_labels.astype(out_type)
126+
127+
128+
def post_process_labels_continuous(
    bundle_label, bundle_mask, bundle_name
):
    """ Keep the raw continuous labels, only zeroing out voxels that fall
    outside the bundle mask. No discretization is performed.

    Parameters
    ----------
    bundle_label : np.ndarray
        Predicted continuous labels for the bundle.
    bundle_mask : np.ndarray
        Binary mask of the bundle.
    bundle_name : str
        Name of the bundle, used for logging.

    Returns
    -------
    bundle_label : np.ndarray
        Predicted labels for the bundle.
    """

    # Continuous labels are kept as plain floats (float64) since no
    # integer discretization takes place.
    outside = np.logical_not(bundle_mask.astype(bool))

    # Label masking: background voxels get 0.
    bundle_label[outside] = 0

    return bundle_label.astype(float)

0 commit comments

Comments
 (0)