Skip to content

Commit 92f81b1

Browse files
committed
Various documentation and formatting changes as well as optimizations during review
1 parent 6a74b0e commit 92f81b1

20 files changed

+415
-448
lines changed

CorpusCallosum/README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Corpus Callosum Pipeline
22

33
A deep learning-based pipeline for automated segmentation, analysis, and shape analysis of the corpus callosum in brain MRI scans.
4-
Also segments the fornix, localizes the AC and PC and standardizes the orientation of the brain.
4+
Also segments the fornix, localizes the anterior and posterior commissure (AC and PC) and standardizes the orientation of the brain.
55

66
## Overview
77

@@ -15,7 +15,9 @@ This pipeline combines localization and segmentation deep learning models to:
1515

1616
## Quickstart
1717

18-
``` python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose ```
18+
```bash
19+
python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose
20+
` ``
1921
2022
Gives all standard outputs. Then corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json`, including 100 thickness measurements and areas of sub-segments.
2123
Visualization will be placed in `/path/to/fastsurfer/output/qc_snapshots`. For more detailed info see the following sections.

CorpusCallosum/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,8 @@
1313
# limitations under the License.
1414

1515
__all__ = [
16-
"config",
1716
"data",
18-
"localization",
1917
"segmentation",
2018
"transforms",
2119
"utils",
22-
"visualization",
2320
]

CorpusCallosum/cc_visualization.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import argparse
2+
import sys
23
from pathlib import Path
4+
from typing import Literal
35

46
import numpy as np
57

@@ -9,8 +11,8 @@
911
from CorpusCallosum.shape.cc_mesh import CC_Mesh
1012

1113

12-
def options_parse() -> argparse.Namespace:
13-
"""Parse command line arguments for the visualization pipeline."""
14+
def make_parser() -> argparse.ArgumentParser:
15+
"""Create a command line parser for the visualization pipeline."""
1416
parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.")
1517
parser.add_argument("--contours", type=str, required=False, help="Path to contours.txt file", default=None)
1618
parser.add_argument("--thickness", type=str, required=True, help="Path to thickness_values.txt file")
@@ -38,12 +40,17 @@ def options_parse() -> argparse.Namespace:
3840
nargs=2,
3941
default=None,
4042
metavar=("MIN", "MAX"),
41-
help="Optional fixed range for the colorbar (min max)",
43+
required=False,
44+
help="Specify the range for the colorbar (2 values: min max). Defaults to automatic choice.",
4245
)
4346
parser.add_argument("--legend", type=str, default="Thickness (mm)", help="Legend for the colorbar")
4447
parser.add_argument("--twoD", action="store_true", help="Generate 2D visualization instead of 3D mesh")
4548

46-
args = parser.parse_args()
49+
return parser
50+
51+
def options_parse() -> argparse.Namespace:
52+
"""Parse command line arguments for the pipeline."""
53+
args = make_parser().parse_args()
4754

4855
# Create output directory if it doesn't exist
4956
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
@@ -62,7 +69,7 @@ def main(
6269
color_range: tuple[float, float] | None = None,
6370
legend: str | None = None,
6471
twoD: bool = False,
65-
) -> None:
72+
) -> Literal[0] | str:
6673
"""Main function to visualize corpus callosum from template files.
6774
6875
This function loads contours and thickness values from template files,
@@ -146,21 +153,24 @@ def main(
146153
cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk"))
147154
cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf"))
148155
cc_mesh.write_overlay(str(output_dir / "cc_mesh_overlay.curv"))
149-
cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png"))
150-
156+
try:
157+
cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png"))
158+
except RuntimeError:
159+
return ("The cc_visualization script requires whippersnappy>=1.3.1 to makes screenshots, install with "
160+
"`pip install whippersnappy>=1.3.1` !")
161+
return 0
151162

152163
if __name__ == "__main__":
153-
options = options_parse()
154-
main_args = {
155-
"contours_path": options.contours,
156-
"thickness_path": options.thickness,
157-
"measurement_points_path": options.measurement_points,
158-
"output_dir": options.output_dir,
159-
"resolution": options.resolution,
160-
"smoothing_window": options.smoothing_window,
161-
"colormap": options.colormap,
162-
"color_range": options.color_range,
163-
"legend": options.legend,
164-
"twoD": options.twoD,
165-
}
166-
main(**main_args)
164+
options = make_parser().parse_args()
165+
sys.exit(main(
166+
contours_path=options.contours,
167+
thickness_path=options.thickness,
168+
measurement_points_path=options.measurement_points,
169+
output_dir=options.output_dir,
170+
resolution=options.resolution,
171+
smooth_iterations=options.smooth_iterations,
172+
colormap=options.colormap,
173+
color_range=options.color_range,
174+
legend=options.legend,
175+
twoD=options.twoD,
176+
))

CorpusCallosum/data/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
from pathlib import Path
1717

1818
### Constants
19-
WEIGHTS_PATH = Path(__file__).parent.parent / "weights"
19+
WEIGHTS_PATH = Path(__file__).parents[1] / "weights"
2020
FSAVERAGE_CENTROIDS_PATH = Path(__file__).parent / "fsaverage_centroids.json"
2121
FSAVERAGE_DATA_PATH = Path(__file__).parent / "fsaverage_data.json" # Contains both affine and header
2222
FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space
2323
CC_LABEL = 192 # Label value for corpus callosum in segmentation
2424
FORNIX_LABEL = 250 # Label value for fornix in segmentation
25-
SUBSEGEMNT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation
25+
SUBSEGMENT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation
2626

2727

2828
STANDARD_OUTPUT_PATHS = {

CorpusCallosum/data/fsaverage_cc_template.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import numpy as np
2020
from scipy import ndimage
2121

22+
from CorpusCallosum.data import constants
2223
from CorpusCallosum.shape.cc_postprocessing import process_slice
24+
from FastSurferCNN.utils.brainvolstats import mask_in_array
2325

2426

2527
def smooth_contour(contour: tuple[np.ndarray, np.ndarray], window_size: int = 5) -> tuple[np.ndarray, np.ndarray]:
@@ -93,19 +95,15 @@ def load_fsaverage_cc_template() -> tuple[
9395

9496
fsaverage_seg_path = freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz'
9597
fsaverage_seg = nib.load(fsaverage_seg_path)
96-
segmentation = fsaverage_seg.get_fdata()
98+
segmentation = np.asarray(fsaverage_seg.dataobj)
9799

98100
PC = np.array([131, 99])
99101
AC = np.array([135, 130])
100102

101103

102104
midslice = segmentation.shape[0]//2 +1
103105

104-
cc_mask = segmentation[midslice] == 251
105-
cc_mask |= segmentation[midslice] == 252
106-
cc_mask |= segmentation[midslice] == 253
107-
cc_mask |= segmentation[midslice] == 254
108-
cc_mask |= segmentation[midslice] == 255
106+
cc_mask = mask_in_array(segmentation[midslice], constants.SUBSEGMENT_LABELS)
109107

110108
# Smooth the CC mask to reduce noise and irregularities
111109

@@ -120,8 +118,7 @@ def load_fsaverage_cc_template() -> tuple[
120118
cc_mask_smoothed = cc_mask_smoothed > 0.5
121119

122120
# Use the smoothed mask for further processing
123-
cc_mask = cc_mask_smoothed.astype(int)
124-
cc_mask[cc_mask > 0] = 192
121+
cc_mask = cc_mask_smoothed.astype(int) * 192
125122

126123
(_, contour_with_thickness, anterior_endpoint_idx,
127124
posterior_endpoint_idx) = process_slice(segmentation=cc_mask[None],

CorpusCallosum/data/generate_fsaverage_centroids.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,4 @@ def main() -> None:
163163

164164

165165
if __name__ == "__main__":
166-
main()
166+
main()

CorpusCallosum/data/read_write.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import multiprocessing
1717
from pathlib import Path
18+
from typing import overload
1819

1920
import nibabel as nib
2021
import numpy as np
@@ -52,8 +53,15 @@ def run_in_background(function: callable, debug: bool = False, *args, **kwargs)
5253
return process
5354

5455

56+
@overload
57+
def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: None = None) -> dict[int, np.ndarray]:
58+
...
5559

56-
def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None) -> dict[int, np.ndarray]:
60+
@overload
61+
def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int]) -> tuple[dict[int, np.ndarray], list[int]]:
62+
...
63+
64+
def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None):
5765
"""Get centroids of segmentation labels in RAS coordinates.
5866
5967
Parameters
@@ -187,12 +195,7 @@ def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray
187195
centroids_data = json.load(f)
188196

189197
# Convert string keys back to integers and lists back to numpy arrays
190-
centroids = {}
191-
for label_str, centroid_list in centroids_data.items():
192-
label_id = int(label_str)
193-
centroids[label_id] = np.array(centroid_list)
194-
195-
return centroids
198+
return {int(label): np.array(centroid) for label, centroid in centroids_data.items()}
196199

197200

198201
def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray:
@@ -270,8 +273,8 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda
270273
if "header" not in data:
271274
raise ValueError("Required field 'header' missing from data file")
272275

273-
header_fields = ["dims", "delta", "Mdc", "Pxyz_c"]
274-
for field in header_fields:
276+
required_header_fields = ["dims", "delta", "Mdc", "Pxyz_c"]
277+
for field in required_header_fields:
275278
if field not in data["header"]:
276279
raise ValueError(f"Required header field missing: {field}")
277280

0 commit comments

Comments
 (0)