Skip to content

Commit 3d3eb6a

Browse files
ClePoldkuegler
authored andcommitted
improved contour extraction for thin CC and surface coordinates
1 parent 3cf2420 commit 3d3eb6a

File tree

5 files changed

+378
-65
lines changed

5 files changed

+378
-65
lines changed

CorpusCallosum/fastsurfer_cc.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def options_parse() -> argparse.Namespace:
100100
parser.add_argument(
101101
"--contour_smoothing",
102102
type=float,
103-
default=1.0,
104-
help="Gaussian sigma for smoothing during contour detection. Default is 1.0, higher values mean a smoother"
103+
default=5,
104+
help="Window size for smoothing during contour detection. Default is 5, higher values mean a smoother"
105105
"outline, at the cost of precision.",
106106
)
107107
parser.add_argument(
@@ -390,7 +390,7 @@ def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, sl
390390
or np.any(cc_volume_mask[:, :, 0])
391391
or np.any(cc_volume_mask[:, :, -1])
392392
):
393-
print("Warning: CC volume mask touches the edge of the segmentation field-of-view, CC might be truncated")
393+
print("Warning: CC voume mask touches the edge of the segmentation field-of-view, CC might be truncated")
394394

395395
# get voxels that were removed during cleaning
396396
removed_voxels = pre_clean_segmentation != segmentation
@@ -411,7 +411,7 @@ def main(
411411
num_thickness_points: int = 100,
412412
subdivisions: list[float] | None = None,
413413
subdivision_method: str = "shape",
414-
contour_smoothing: float = 1.0,
414+
contour_smoothing: float = 5,
415415
save_template: str | Path | None = None,
416416
cpu: bool = False,
417417
# output paths
@@ -606,7 +606,6 @@ def main(
606606
midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze
607607
)
608608

609-
610609
# calculate affine for segmentation volume
611610
orig_to_seg = np.eye(4)
612611
orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze // 2

CorpusCallosum/segmentation/segmentation_postprocessing.py

Lines changed: 255 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
from scipy import integrate, ndimage
17+
from scipy.spatial.distance import cdist
1718
from skimage.measure import label
1819

1920
import FastSurferCNN.utils.logging as logging
@@ -22,6 +23,205 @@
2223
logger = logging.get_logger(__name__)
2324

2425

26+
def find_component_boundaries(labels_arr: np.ndarray, component_id: int) -> np.ndarray:
27+
"""Find boundary voxels of a connected component.
28+
29+
Args:
30+
labels_arr (np.ndarray): Labeled array from connected components analysis
31+
component_id (int): ID of the component to find boundaries for
32+
33+
Returns:
34+
np.ndarray: Array of boundary coordinates (N, 3)
35+
"""
36+
component_mask = labels_arr == component_id
37+
38+
# Create a structuring element for 6-connectivity (face neighbors only)
39+
struct = ndimage.generate_binary_structure(3, 1)
40+
41+
# Erode the component to find internal voxels
42+
eroded = ndimage.binary_erosion(component_mask, structure=struct)
43+
44+
# Boundary is the difference between original and eroded
45+
boundary = component_mask & ~eroded
46+
47+
return np.array(np.where(boundary)).T
48+
49+
50+
def find_minimal_connection_path(boundary1: np.ndarray, boundary2: np.ndarray,
51+
max_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray] | None:
52+
"""Find the minimal connection path between two component boundaries.
53+
54+
Args:
55+
boundary1 (np.ndarray): Boundary coordinates of first component (N1, 3)
56+
boundary2 (np.ndarray): Boundary coordinates of second component (N2, 3)
57+
max_distance (float): Maximum distance to consider for connection
58+
59+
Returns:
60+
tuple | None: (point1, point2) coordinates of closest points if within max_distance, None otherwise
61+
"""
62+
if len(boundary1) == 0 or len(boundary2) == 0:
63+
return None
64+
65+
# Calculate pairwise distances between all boundary points
66+
distances = cdist(boundary1, boundary2, metric='euclidean')
67+
68+
# Find the minimum distance and corresponding points
69+
min_idx = np.unravel_index(np.argmin(distances), distances.shape)
70+
min_distance = distances[min_idx]
71+
72+
if min_distance <= max_distance:
73+
point1 = boundary1[min_idx[0]]
74+
point2 = boundary2[min_idx[1]]
75+
return point1, point2
76+
77+
return None
78+
79+
80+
def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple[int, int, int]]:
81+
"""Create a line of voxels connecting two points using simplified 3D line algorithm.
82+
83+
Args:
84+
point1 (np.ndarray): Starting point coordinates (3,)
85+
point2 (np.ndarray): Ending point coordinates (3,)
86+
87+
Returns:
88+
list: List of (x, y, z) coordinates forming the connection line
89+
"""
90+
x1, y1, z1 = map(int, point1)
91+
x2, y2, z2 = map(int, point2)
92+
93+
line_points = []
94+
95+
# Calculate the number of steps needed
96+
dx = abs(x2 - x1)
97+
dy = abs(y2 - y1)
98+
dz = abs(z2 - z1)
99+
100+
steps = max(dx, dy, dz)
101+
102+
if steps == 0:
103+
return [(x1, y1, z1)]
104+
105+
# Calculate increments for each dimension
106+
x_inc = (x2 - x1) / steps
107+
y_inc = (y2 - y1) / steps
108+
z_inc = (z2 - z1) / steps
109+
110+
# Generate points along the line
111+
for i in range(steps + 1):
112+
x = int(round(x1 + i * x_inc))
113+
y = int(round(y1 + i * y_inc))
114+
z = int(round(z1 + i * z_inc))
115+
line_points.append((x, y, z))
116+
117+
return line_points
118+
119+
120+
def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> np.ndarray:
121+
"""Connect nearby disconnected components that should be connected.
122+
123+
This function identifies disconnected components in the segmentation and creates
124+
minimal connections between components that are close to each other.
125+
126+
Args:
127+
seg_arr (np.ndarray): Input binary segmentation array
128+
max_connection_distance (float): Maximum distance to connect components
129+
130+
Returns:
131+
np.ndarray: Segmentation array with minimal connections added
132+
"""
133+
134+
# Create a copy to modify
135+
connected_seg = seg_arr.copy()
136+
137+
# Find connected components without dilation first
138+
labels_cc = label(seg_arr, connectivity=3, background=0)
139+
140+
# Get component sizes (excluding background)
141+
bincount = np.bincount(labels_cc.flat)
142+
component_ids = np.where(bincount > 0)[0][1:] # Exclude background (0)
143+
144+
if len(component_ids) <= 1:
145+
return connected_seg # Only one component, no connections needed
146+
147+
# Sort components by size (largest first)
148+
component_sizes = [(comp_id, bincount[comp_id]) for comp_id in component_ids]
149+
component_sizes.sort(key=lambda x: x[1], reverse=True)
150+
151+
# Use the largest component as the reference
152+
main_component_id = component_sizes[0][0]
153+
154+
155+
156+
logger.info(f"Found {len(component_ids)} disconnected components. "
157+
f"Attempting to connect smaller components to main component (size: {component_sizes[0][1]})")
158+
159+
connections_made = 0
160+
161+
# Try to connect each smaller component to the main component
162+
for comp_id, comp_size in component_sizes[1:]:
163+
if comp_size < 5: # Skip very small components (likely noise)
164+
logger.debug(f"Skipping tiny component {comp_id} with size {comp_size}")
165+
continue
166+
167+
# Find boundaries of both components
168+
main_boundary = find_component_boundaries(labels_cc, main_component_id)
169+
comp_boundary = find_component_boundaries(labels_cc, comp_id)
170+
171+
# Find minimal connection path
172+
connection = find_minimal_connection_path(main_boundary, comp_boundary, max_connection_distance)
173+
174+
if connection is not None:
175+
point1, point2 = connection
176+
distance = np.linalg.norm(point2 - point1)
177+
178+
logger.debug(f"Connecting component {comp_id} (size: {comp_size}) to main component. "
179+
f"Distance: {distance:.2f} voxels")
180+
181+
# Create connection line
182+
connection_line = create_connection_line(point1, point2)
183+
184+
# Add connection voxels to the segmentation
185+
# Use the same label as the original segmentation at the connection points
186+
connection_label = seg_arr[point1[0], point1[1], point1[2]] if \
187+
seg_arr[point1[0], point1[1], point1[2]] != 0 else \
188+
seg_arr[point2[0], point2[1], point2[2]]
189+
190+
for x, y, z in connection_line:
191+
if (0 <= x < connected_seg.shape[0] and
192+
0 <= y < connected_seg.shape[1] and
193+
0 <= z < connected_seg.shape[2]):
194+
if connected_seg[x, y, z] == 0: # Only fill empty voxels
195+
connected_seg[x, y, z] = connection_label
196+
197+
connections_made += 1
198+
else:
199+
logger.debug(f"Component {comp_id} (size: {comp_size}) too far from main component")
200+
201+
logger.info(f"Created {connections_made} minimal connections between components")
202+
203+
204+
# Plot components for visualization
205+
# import matplotlib.pyplot as plt
206+
# n_components = len(component_sizes)
207+
# fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5))
208+
# if n_components == 1:
209+
# axes = [axes]
210+
# # Plot each component in a different color
211+
# for i, (comp_id, comp_size) in enumerate(component_sizes):
212+
# component_mask = labels_cc == comp_id
213+
# axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray')
214+
# axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}')
215+
# axes[i].axis('off')
216+
217+
# # Plot the connected segmentation
218+
# axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray')
219+
# axes[-1].set_title('Connected Segmentation')
220+
# axes[-1].axis('off')
221+
# plt.tight_layout()
222+
# plt.show()
223+
224+
return connected_seg
25225

26226

27227
def get_cc_volume_voxel(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float:
@@ -142,72 +342,94 @@ def get_cc_volume_contour(desired_width_mm: int, cc_contours: list[np.ndarray],
142342
return integrate.simpson(areas, x=measurement_points)
143343

144344

145-
def get_largest_cc(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
146-
"""Get largest connected component from a binary segmentation array.
345+
def get_largest_cc(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray]:
346+
"""Get largest connected component from a binary segmentation array with minimal connections.
147347
148-
This function takes a binary segmentation array, dilates it, finds connected components,
149-
and returns the largest component (excluding background) along with its mask.
348+
This function takes a binary segmentation array, attempts to connect nearby disconnected
349+
components that should be connected, then finds the largest connected component.
350+
It first tries to establish minimal connections between close components before
351+
falling back to dilation if no connections are made.
150352
151353
Args:
152354
seg_arr (np.ndarray): Input binary segmentation array
355+
max_connection_distance (float): Maximum distance to connect components (default: 3.0)
153356
154357
Returns:
155358
tuple: A tuple containing:
156359
- clean_seg (np.ndarray): Segmentation array with only the largest connected component
157360
- largest_cc (np.ndarray): Binary mask of the largest connected component
158361
"""
159-
# generate dilatation structure
160-
struct1 = ndimage.generate_binary_structure(3, 3)
161-
# Dilate prediction
162-
mask = ndimage.binary_dilation(seg_arr, structure=struct1, iterations=1, ).astype(np.uint8)
163-
# Get connected components
362+
# First attempt: try to connect nearby components with minimal connections
363+
connected_seg = connect_nearby_components(seg_arr, max_connection_distance)
364+
365+
# Check if connections were successful by comparing connectivity
366+
original_labels = label(seg_arr, connectivity=3, background=0)
367+
connected_labels = label(connected_seg, connectivity=3, background=0)
368+
369+
original_components = len(np.unique(original_labels)) - 1 # Exclude background
370+
connected_components = len(np.unique(connected_labels)) - 1 # Exclude background
371+
372+
if connected_components < original_components:
373+
logger.info(f"Successfully reduced components from {original_components} to {connected_components} "
374+
"using minimal connections")
375+
mask = connected_seg
376+
# else:
377+
# logger.info("No connections made, falling back to dilation approach")
378+
# # Fallback: use the original dilation approach
379+
# struct1 = ndimage.generate_binary_structure(3, 3)
380+
# mask = ndimage.binary_dilation(seg_arr, structure=struct1, iterations=1).astype(np.uint8)
381+
382+
# Get connected components from the processed mask
164383
labels_cc = label(mask, connectivity=3, background=0)
165-
# Get componnets count
384+
385+
# Get component counts
166386
bincount = np.bincount(labels_cc.flat)
167-
# Get background label, assumption that background is the biggest connected component
387+
388+
# Get background label (assumed to be the largest component)
168389
background = np.argmax(bincount)
169390
bincount[background] = -1
391+
170392
# Get largest connected component
171393
largest_cc = labels_cc == np.argmax(bincount)
172-
# Apply mask
173-
clean_seg = seg_arr * largest_cc
174394

175-
return clean_seg,largest_cc
395+
return largest_cc
176396

177-
def clean_cc_segmentation(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
397+
def clean_cc_segmentation(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray]:
178398
"""Clean corpus callosum segmentation by removing non-connected components.
179399
180400
This function processes a segmentation array to clean up the corpus callosum (CC)
181401
by removing non-connected components. It first isolates the CC (label 192),
182-
removes non-connected components, then adds the fornix (label 250), and
183-
finally removes non-connected components from the combined CC and fornix.
402+
attempts to connect nearby disconnected components, then adds the fornix (label 250),
403+
and finally removes non-connected components from the combined CC and fornix.
184404
185405
Args:
186406
seg_arr (np.ndarray): Input segmentation array with CC (192) and fornix (250) labels
407+
max_connection_distance (float): Maximum distance to connect components (default: 3.0)
187408
188409
Returns:
189410
tuple: A tuple containing:
190411
- clean_seg (np.ndarray): Cleaned segmentation array with only the largest
191412
connected component of CC and fornix
192413
- mask (np.ndarray): Binary mask of the largest connected component
193414
"""
194-
#Remove non connected components from the CC alone
195-
clean_seg = np.zeros_like(seg_arr)
196-
clean_seg[seg_arr == CC_LABEL] = CC_LABEL
197-
clean_seg,_ = get_largest_cc(clean_seg)
415+
# Remove non connected components from the CC alone, with minimal connections
416+
cc_seg = np.zeros_like(seg_arr)
417+
cc_seg[seg_arr == CC_LABEL] = CC_LABEL
198418

199-
#Add fornix to the CC labels
200-
clean_seg[seg_arr == FORNIX_LABEL] = FORNIX_LABEL
419+
cc_label_cleaned = np.zeros_like(cc_seg)
420+
for i in range(cc_seg.shape[0]):
421+
cc_label_cleaned[i] = get_largest_cc(cc_seg[None,i], max_connection_distance)
422+
# import matplotlib.pyplot as plt
423+
# fig, ax = plt.subplots(1,3)
424+
# ax[0].imshow(cc_seg[i])
425+
# ax[1].imshow(mask[i])
426+
# ax[2].imshow(cc_seg[i] - mask[i]*CC_LABEL) # difference between pre and post clean
427+
# plt.show()
201428

202-
#Remove non connected components from CC & Fornix
203-
clean_seg, mask = get_largest_cc(clean_seg)
204429

205-
unique_labels = np.unique(clean_seg)
430+
# Add fornix to the CC labels
431+
clean_seg = np.zeros_like(seg_arr)
432+
clean_seg[cc_label_cleaned > 0] = CC_LABEL
433+
clean_seg[seg_arr == FORNIX_LABEL] = FORNIX_LABEL
206434

207-
if 250 not in unique_labels:
208-
clean_seg[seg_arr == 250] = 250
209-
mask[seg_arr == 250] = True
210-
if 192 not in unique_labels:
211-
clean_seg[seg_arr == 192] = 192
212-
mask[seg_arr == 192] = True
213-
return clean_seg, mask
435+
return clean_seg, cc_label_cleaned > 0

0 commit comments

Comments
 (0)