Skip to content

Commit ac5a7a4

Browse files
committed
fix few issues
1 parent 3268fbb commit ac5a7a4

8 files changed

Lines changed: 80 additions & 87 deletions

File tree

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
</p>
1010
<br />
1111

12-
<p align="center">A Python package for calculating 3D cardiomyocyte orientations in heart images.</p>
12+
<p align="center">A Python package to quantify and visualize 3D cardiomyocyte orientation in heart imaging datasets</p>
1313

1414
[![CI](https://github.com/JosephBrunet/cardiotensor/actions/workflows/ci.yml/badge.svg)](https://github.com/JosephBrunet/cardiotensor/actions/workflows/ci.yml)
1515
[![Doc](https://img.shields.io/badge/docs-dev-blue.svg)](https://JosephBrunet.github.io/cardiotensor/)
@@ -18,11 +18,6 @@
1818
[![Python Version](https://img.shields.io/pypi/pyversions/cardiotensor.svg)](https://pypi.python.org/pypi/cardiotensor)
1919
[![PyPI version](https://img.shields.io/pypi/v/cardiotensor.svg)](https://pypi.org/project/cardiotensor/)
2020

21-
<p align="center">
22-
<img src="https://github.com/JosephBrunet/cardiotensor/raw/main/assets/images/result_HA_slice.jpeg" alt="Example Slice" style="max-width: 80%">
23-
<br>
24-
<em>Helical angle map of a heart scanned using synchrtron X-ray imaging.</em>
25-
</p>
2621

2722
## Introduction
2823

@@ -51,6 +46,11 @@ cardiotensor's documentation is available at [josephbrunet.fr/cardiotensor/](htt
5146

5247
Have a look at our [simple example](https://www.josephbrunet.fr/cardiotensor/getting-started/examples/) that runs you through all the commands of the package
5348

49+
<p align="center">
50+
<img src="https://github.com/JosephBrunet/cardiotensor/raw/main/assets/images/result_HA_slice.jpeg" alt="Example Slice" style="max-width: 70%">
51+
<br>
52+
<em>Helical angle map of a heart scanned using synchrtron X-ray imaging.</em>
53+
</p>
5454

5555
## More Information
5656

src/cardiotensor/orientation/orientation_computation_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,6 @@ def rotate_vectors_to_new_axis(
355355
# Reshape back to the original shape
356356
rotated_vecs = rotated_vecs.reshape(vector_field_slice.shape)
357357

358-
# print(f"Rotation matrix:\n{rotation_matrix}")
359-
360358
return rotated_vecs
361359

362360

@@ -454,15 +452,16 @@ def plot_images(
454452
ax[0, 0].legend(loc="upper right")
455453

456454
# Helix Image
457-
tmp = ax[0, 1].imshow(img_helix, cmap=orig_map)
455+
tmp = ax[0, 1].imshow(img_helix, cmap=orig_map, vmin=-90, vmax=90)
458456
ax[0, 1].set_title("Helix Angle")
459457

460458
# Intrusion Image
461-
ax[1, 0].imshow(img_intrusion, cmap=orig_map)
459+
ax[1, 0].imshow(img_intrusion, cmap=orig_map, vmin=-90, vmax=90)
462460
ax[1, 0].set_title("Intrusion Angle")
463461

464462
# FA Image
465-
fa_plot = ax[1, 1].imshow(img_FA, cmap="inferno")
463+
fa_plot = ax[1, 1].imshow(img_FA, cmap="inferno", vmin=0, vmax=1)
464+
466465
ax[1, 1].set_title("Fractional Anisotropy")
467466

468467
# Add colorbars for relevant subplots
@@ -598,6 +597,7 @@ def write_img_rgb(
598597

599598
print(f"Writing image to {output_path}")
600599
if OUTPUT_FORMAT == "jp2":
600+
ratio_compression = 10
601601
glymur.Jp2k(
602602
output_path,
603603
data=img,

src/cardiotensor/orientation/orientation_computation_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def compute_orientation(
192192

193193
mask = mask_reader.load_volume(
194194
start_index_padded, end_index_padded, unbinned_shape=data_reader.shape
195-
).astype("float32")
195+
)
196196

197197
assert mask.shape == volume.shape, (
198198
f"Mask shape {mask.shape} does not match volume shape {volume.shape}"

src/cardiotensor/scripts/generate_streamlines.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
from cardiotensor.utils.downsampling import downsample_vector_volume, downsample_volume
2323
from cardiotensor.utils.utils import read_conf_file
2424

25-
from memory_profiler import profile
2625

27-
@profile
2826
def script():
2927
parser = argparse.ArgumentParser(
3028
description="Trace streamlines from a 3D vector field and save to .npz"
@@ -149,12 +147,15 @@ def script():
149147
sys.exit(1)
150148

151149
print("Ensuring Z-components are positive...")
152-
neg_mask = vector_field[2] < 0
150+
neg_mask = vector_field[0] < 0
153151
vector_field[:, neg_mask] *= -1
154152
del neg_mask
155153

156154
if MASK_PATH:
157155
print("Applying mask from config...")
156+
157+
print(f"MASK_PATH: {MASK_PATH}\nstart_binned: {start_binned}, end_binned: {end_binned}", vec_reader.shape[1:])
158+
158159
mask_reader = DataReader(MASK_PATH)
159160

160161
# Load the corresponding mask volume, resampled to match vector field shape

src/cardiotensor/scripts/visualize_vector_field.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from cardiotensor.utils.utils import read_conf_file
1212

1313

14+
1415
def script():
1516
parser = argparse.ArgumentParser(
1617
description="Plot 3D vector field using FURY from configuration file."
@@ -85,17 +86,20 @@ def script():
8586

8687
print("Loading vector field...")
8788
vec_reader = DataReader(vec_load_dir)
88-
vector_slices = vec_reader.load_volume(
89+
vector_field = vec_reader.load_volume(
8990
start_index=start_binned, end_index=end_binned
9091
)
92+
93+
print("Ensuring Z-components are positive...")
94+
neg_mask = vector_field[0] > 0 # Identify where Z component is negative
95+
vector_field[:, neg_mask] *= -1 # Flip the entire vector at that location
96+
del neg_mask
97+
98+
9199

92100
# If your vector_field is in shape (3, Z, Y, X), convert it:
93-
if vector_slices.shape[0] == 3:
94-
vector_field = np.moveaxis(vector_slices, 0, -1)
95-
96-
print("Ensuring Z-components are positive...")
97-
neg_mask = vector_field[2] < 0
98-
vector_field[:, neg_mask] *= -1
101+
if vector_field.shape[0] == 3:
102+
vector_field = np.moveaxis(vector_field, 0, -1)
99103

100104
if MASK_PATH:
101105
print("Applying mask from config...")

src/cardiotensor/tractography/generate_streamlines.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,20 @@ def trace_streamline(
100100
max_steps: int | None = 1000,
101101
angle_threshold: float = 60.0,
102102
eps: float = 1e-10,
103+
direction: int = 1,
103104
) -> list[tuple[float, float, float]]:
104105
"""
105106
Trace one streamline from `start_pt` (z,y,x) in the continuous vector_field.
106107
- Interpolate & normalize each sub‐step
107-
- Move forward by `step_length` voxels each step (Euler or RK4)
108-
- Stop if turning angle > angle_threshold or out of bounds or `vec` too small.
109-
- If max_steps is None, trace until a stopping condition is hit (no hard limit).
108+
- Move forward by `step_length` voxels each step using RK4
109+
- `direction` = +1 (default) or -1 to reverse integration direction
110110
"""
111111
Z, Y, X = vector_field.shape[1:]
112112
coords: list[tuple[float, float, float]] = [
113113
(float(start_pt[0]), float(start_pt[1]), float(start_pt[2]))
114114
]
115115
current_pt = np.array(start_pt, dtype=np.float64)
116-
prev_dir: np.ndarray | None = None # previous unit vector
116+
prev_dir: np.ndarray | None = None # previous unit vector
117117

118118
def interp_unit(pt: np.ndarray) -> np.ndarray | None:
119119
"""Return a normalized direction vector at fractional pt, or None if invalid."""
@@ -123,7 +123,7 @@ def interp_unit(pt: np.ndarray) -> np.ndarray | None:
123123
norm = np.linalg.norm(vec)
124124
if norm < eps:
125125
return None
126-
return np.array([vec[2], vec[1], vec[0]]) / norm # flip to (z,y,x) order
126+
return np.array([vec[2], vec[1], vec[0]]) / norm * direction # flip to (z,y,x)
127127

128128
step_count = 0
129129
while max_steps is None or step_count < max_steps:
@@ -134,9 +134,6 @@ def interp_unit(pt: np.ndarray) -> np.ndarray | None:
134134
if fa_value < fa_threshold:
135135
break
136136

137-
# ------
138-
# Runge-Kutta 4th order integration
139-
140137
k1 = interp_unit(current_pt)
141138
if k1 is None:
142139
break
@@ -168,8 +165,6 @@ def interp_unit(pt: np.ndarray) -> np.ndarray | None:
168165
next_pt = current_pt + increment
169166
next_dir = k1
170167

171-
# ------
172-
173168
zn, yn, xn = next_pt
174169
if not (0 <= zn < Z and 0 <= yn < Y and 0 <= xn < X):
175170
break
@@ -181,6 +176,7 @@ def interp_unit(pt: np.ndarray) -> np.ndarray | None:
181176
return coords
182177

183178

179+
184180
def generate_streamlines_from_vector_field(
185181
vector_field: np.ndarray,
186182
seed_points: np.ndarray,
@@ -214,18 +210,20 @@ def generate_streamlines_from_vector_field(
214210
step_length=step_length,
215211
max_steps=max_steps,
216212
angle_threshold=angle_threshold,
213+
direction=1,
217214
)
218215

219216
# Backward tracing if enabled
220217
if bidirectional:
221218
backward_pts = trace_streamline(
222219
start_pt=start,
223-
vector_field=-vector_field,
220+
vector_field=vector_field,
224221
fa_volume=fa_volume,
225222
fa_threshold=fa_threshold,
226223
step_length=step_length,
227224
max_steps=max_steps,
228225
angle_threshold=angle_threshold,
226+
direction=-1,
229227
)
230228
# Remove duplicate seed point and reverse
231229
backward_pts = backward_pts[::-1][:-1] if len(backward_pts) > 1 else []

src/cardiotensor/utils/DataReader.py

Lines changed: 34 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from alive_progress import alive_bar
1111
from scipy.ndimage import zoom
1212

13-
1413
class DataReader:
1514
def __init__(self, path: str | Path):
1615
"""
@@ -109,81 +108,65 @@ def load_volume(
109108
unbinned_shape: tuple[int, int, int] | None = None,
110109
) -> np.ndarray:
111110
"""
112-
Loads the volume data based on the detected volume type.
111+
Loads the volume and resizes it to unbinned_shape if provided, using fast scipy.ndimage.zoom.
113112
114113
Args:
115114
start_index (int): Start index for slicing (for stacks).
116115
end_index (int): End index for slicing (for stacks). If None, loads the entire stack.
117-
unbinned_shape (tuple): Shape of the volume without downsampling. Default is None (no binning).
116+
unbinned_shape (tuple): Desired shape (Z, Y, X). If None, no resizing is done.
118117
119118
Returns:
120-
np.ndarray: Loaded (and possibly resized) volume data.
119+
np.ndarray: Loaded volume.
121120
"""
122-
123121
if end_index is None:
124122
end_index = self.shape[0]
125123

126-
# compute how much we’ll need to zoom later
127-
binning_factor = 1.0
128-
if unbinned_shape is not None:
129-
binning_factor = unbinned_shape[0] / self.shape[0]
130-
print(f"Mask bining factor: {binning_factor}")
124+
# Decide if resize is needed
125+
need_resize = False
126+
if unbinned_shape is not None and self.shape != unbinned_shape:
127+
need_resize = True
128+
zoom_factors = tuple(u / s for u, s in zip(unbinned_shape, self.shape))
129+
print(f"Zoom factors: {zoom_factors}")
130+
else:
131+
zoom_factors = (1.0, 1.0, 1.0)
131132

132-
# if we’re going to zoom, adjust the requested slice range to include padding
133-
if binning_factor != 1.0:
133+
# Optional padding if resizing
134+
if need_resize:
134135
start_index_ini, end_index_ini = start_index, end_index
135-
start_index = int(start_index_ini / binning_factor) - 1
136+
start_index = int(start_index_ini / zoom_factors[0]) - 1
136137
start_index = max(start_index, 0)
137-
end_index = int(end_index_ini / binning_factor) + 1
138+
end_index = int(end_index_ini / zoom_factors[0]) + 1
138139
end_index = min(end_index, self.shape[0])
139-
print(
140-
f"Mask start index padded: {start_index} - Mask end index padded: {end_index}"
141-
)
140+
print(f"Volume start index padded: {start_index} - end: {end_index}")
142141

143-
# load the raw volume
142+
# Load volume from stack or mhd
144143
if not self.volume_info["stack"]:
145-
# single-file (e.g. .mhd)
146144
if self.volume_info["type"] == "mhd":
147145
volume, _ = _load_raw_data_with_mhd(self.path)
148146
volume = volume[start_index:end_index, :, :]
149147
else:
150-
# image stack
151148
volume = self._load_image_stack(
152149
self.volume_info["file_list"], start_index, end_index
153150
)
154151

155-
# if we need to upsample + crop + resize each slice back to unbinned_shape:
156-
if binning_factor != 1.0 and unbinned_shape is not None:
157-
print("Resizing mask")
158-
# 1) zoom the 3D block
159-
volume = zoom(volume, zoom=binning_factor, order=0)
160-
161-
# 2) figure out where our original slice window ended up
162-
start_up = int(abs(start_index * binning_factor - start_index_ini))
163-
end_up = start_up + (end_index_ini - start_index_ini)
164-
start_up = max(start_up, 0)
165-
end_up = min(end_up, volume.shape[0])
166-
167-
# 3) crop to just those slices
168-
volume = volume[start_up:end_up, :, :]
169-
170-
# 4) allocate exactly the target (unbinned) shape and resize each slice
171-
volume_resized = np.empty(
172-
(volume.shape[0], unbinned_shape[1], unbinned_shape[2]),
173-
dtype=volume.dtype,
174-
)
175-
for i in range(volume.shape[0]):
176-
volume_resized[i] = cv2.resize(
177-
volume[i],
178-
(unbinned_shape[2], unbinned_shape[1]),
179-
interpolation=cv2.INTER_LINEAR,
180-
)
181-
182-
# 5) replace volume with the resized version
183-
volume = volume_resized
184-
152+
if need_resize:
153+
print("Resizing with scipy.ndimage.zoom...")
154+
155+
# Ensure float32 for better memory and speed
156+
volume = volume.astype(np.float32)
157+
volume = zoom(volume, zoom=zoom_factors, order=0) # Nearest-neighbor for mask
158+
159+
# Final crop to original range
160+
crop_start = int(abs(start_index * zoom_factors[0] - start_index_ini))
161+
crop_end = crop_start + (end_index_ini - start_index_ini)
162+
crop_start = max(crop_start, 0)
163+
crop_end = min(crop_end, volume.shape[0])
164+
165+
volume = volume[crop_start:crop_end, :, :]
166+
185167
return volume
186168

169+
187170
def _custom_image_reader(self, file_path: Path) -> np.ndarray:
188171
"""
189172
Reads an image from the given file path.
@@ -380,3 +363,4 @@ def _load_raw_data_with_mhd(
380363
# End 3D fix
381364

382365
return (data, meta_dict)
366+

src/cardiotensor/utils/plot_vector_field.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,27 @@ def plot_vector_field_with_fury(
2626
print("Creating coordinate grid...")
2727
zz, yy, xx = np.mgrid[0:Z:stride, 0:Y:stride, 0:X:stride]
2828
coords = np.stack((zz, yy, xx), axis=-1)
29-
vectors = vector_field[0:Z:stride, 0:Y:stride, 0:X:stride]
29+
vector_field = vector_field[0:Z:stride, 0:Y:stride, 0:X:stride]
3030

3131
# Flatten coordinates and vectors
3232
coords_flat = coords.reshape(-1, 3)
33-
vectors_flat = vectors.reshape(-1, 3)
33+
vectors_flat = vector_field.reshape(-1, 3)
34+
del vector_field
3435

3536
print("Extracting and filtering vectors...")
3637
norms = np.linalg.norm(vectors_flat, axis=1)
3738
valid_mask = norms > 0
3839

3940
centers = coords_flat[valid_mask] * voxel_size
4041
directions = vectors_flat[valid_mask] / norms[valid_mask, None]
42+
del coords_flat, vectors_flat, norms
4143

4244
print("Generating colors...")
4345
if ha_volume is not None:
4446
ha_sub = ha_volume[0:Z:stride, 0:Y:stride, 0:X:stride]
4547
ha_flat = ha_sub.reshape(-1)
4648
ha_values = ha_flat[valid_mask]
47-
color_array = colormap.create_colormap(ha_values, name="plasma", auto=True)
49+
color_array = colormap.create_colormap(ha_values, name="hsv", auto=True)
4850
else:
4951
color_array = np.tile([1.0, 0.0, 0.0], (centers.shape[0], 1))
5052

@@ -54,8 +56,12 @@ def plot_vector_field_with_fury(
5456

5557
print("Creating arrow actor...")
5658
arrow_actor = actor.arrow(
57-
centers, directions, colors=color_array, scales=10 * size_arrow
59+
centers,
60+
directions,
61+
colors=color_array,
62+
scales=10 * size_arrow,
5863
)
64+
5965
scene = window.Scene()
6066
scene.add(arrow_actor)
6167

0 commit comments

Comments
 (0)