Skip to content

Commit 2131b52

Browse files
committed
perf: batch profile interpolation and reuse stack-wide orientation map
Pre-compute the median-filtered orientation angle map once per stack in analyze_sarcomere_vectors and pass it through to get_sarcomere_vectors via a new precomputed_angle_map kwarg, avoiding redundant per-frame median filtering. In get_sarcomeres_length_orientation, add a uniform-length fast path that builds the Akima/linear interpolator once over the full (N, L) batch instead of once per profile. Akima construction was the dominant cost in analyze_sarcomere_vectors (~37% on real movies); the per-profile fallback is preserved for ragged inputs from external callers.
1 parent 2789d90 commit 2131b52

3 files changed

Lines changed: 152 additions & 29 deletions

File tree

sarcasm/structure.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,24 @@ def analyze_sarcomere_vectors(self, frames: Union[str, int, List[int], np.ndarra
816816
n_frames = len(z_bands)
817817
pixelsize = self.metadata.pixelsize
818818

819+
# Check pixelsize is not None
820+
if pixelsize is None:
821+
raise ValueError("Pixel size is not available. Please provide pixelsize during initialization.")
822+
823+
# Pre-compute the orientation angle map for the entire stack once.
824+
# Inside ``get_sarcomere_vectors`` this call is the second-largest
825+
# per-frame cost (median filter on a disk footprint); batching it here
826+
# lets the filter run over the full (N, 2, H, W) tensor and avoids
827+
# redundant work when ``precomputed_angle_map`` is passed below.
828+
radius_pixels = max(int(round(median_filter_radius / pixelsize, 0)), 1)
829+
angle_maps = Utils.get_orientation_angle_map(
830+
orientation_field, use_median_filter=True, radius=radius_pixels,
831+
)
832+
# ``get_orientation_angle_map`` squeezes a single-frame stack down to
833+
# (H, W); re-expand so ``angle_maps[i]`` is always valid.
834+
if angle_maps.ndim == 2:
835+
angle_maps = angle_maps[np.newaxis, ...]
836+
819837
# create empty arrays
820838
def none_lists():
821839
return [None] * self.metadata.n_stack
@@ -829,10 +847,6 @@ def nan_arrays():
829847
sarcomere_orientation_mean, sarcomere_orientation_std = nan_arrays(), nan_arrays()
830848
n_vectors, n_mbands, oop, sarcomere_area, sarcomere_area_ratio, score_thresholds = (nan_arrays() for _ in range(6))
831849

832-
# Check pixelsize is not None
833-
if pixelsize is None:
834-
raise ValueError("Pixel size is not available. Please provide pixelsize during initialization.")
835-
836850
# iterate images
837851
logger.info('Starting sarcomere length and orientation analysis...')
838852
for i, (frame_i, zbands_i, mbands_i, orientation_field_i, sarcomere_mask_i) in enumerate(
@@ -852,7 +866,8 @@ def nan_arrays():
852866
linewidth=linewidth,
853867
interpolation_method=interpolation_method,
854868
peak_prominence=peak_prominence,
855-
peak_algorithm=peak_algorithm)
869+
peak_algorithm=peak_algorithm,
870+
precomputed_angle_map=angle_maps[i])
856871

857872
# write in list
858873
n_vectors[frame_i] = len(sarcomere_length_vectors_i)

sarcasm/structure_modules/sarcomere_vectors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def get_sarcomere_vectors(
161161
interpolation_method: str = 'linear',
162162
peak_prominence: float = 0.5,
163163
peak_algorithm: str = 'default',
164+
precomputed_angle_map: Union[np.ndarray, None] = None,
164165
) -> Tuple[Union[np.ndarray, List], Union[np.ndarray, List], Union[np.ndarray, List],
165166
Union[np.ndarray, List], Union[np.ndarray, List], Union[np.ndarray, List], Union[np.ndarray, List]]:
166167
"""
@@ -221,7 +222,10 @@ def get_sarcomere_vectors(
221222
mbands_skel = skeletonize(mbands, method='lee')
222223

223224
# calculate and preprocess orientation map
224-
orientation = Utils.get_orientation_angle_map(orientation_field, use_median_filter=True, radius=radius_pixels)
225+
if precomputed_angle_map is not None:
226+
orientation = precomputed_angle_map
227+
else:
228+
orientation = Utils.get_orientation_angle_map(orientation_field, use_median_filter=True, radius=radius_pixels)
225229

226230
# label mbands
227231
midline_labels, n_mbands = ndimage.label(mbands_skel,

sarcasm/utils.py

Lines changed: 127 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,27 @@
4141

4242

4343

44+
def _batched_linear_interp(
45+
x: np.ndarray, y: np.ndarray, x_new: np.ndarray,
46+
) -> np.ndarray:
47+
"""Batched linear interpolation equivalent to calling ``np.interp`` once per
48+
row of ``y``. ``x`` is 1-D (shared), ``y`` has shape ``(N, L)``, ``x_new``
49+
is 1-D. Returns ``(N, len(x_new))``.
50+
51+
Uses ``searchsorted`` to compute indices once, then a vectorised two-tap
52+
linear combination. Matches ``np.interp`` behaviour including constant
53+
extrapolation at the boundaries.
54+
"""
55+
L = x.shape[0]
56+
idx = np.clip(np.searchsorted(x, x_new, side='right') - 1, 0, L - 2)
57+
x0 = x[idx]
58+
x1 = x[idx + 1]
59+
w = (x_new - x0) / (x1 - x0)
60+
# Clamp for exact endpoint / out-of-range behaviour to match np.interp.
61+
w = np.clip(w, 0.0, 1.0)
62+
return y[:, idx] * (1.0 - w) + y[:, idx + 1] * w
63+
64+
4465
class Utils:
4566
""" Miscellaneous utility functions """
4667

@@ -753,56 +774,143 @@ def process_profiles_batch(
753774
n_profiles = len(profiles)
754775
sarcomere_lengths = np.empty(n_profiles, dtype=np.float64)
755776
center_offsets = np.empty(n_profiles, dtype=np.float64)
756-
777+
778+
if n_profiles == 0:
779+
return sarcomere_lengths, center_offsets
780+
757781
# Pre-compute constants
758782
min_dist_pixel = int(np.round(min_dist / pixelsize, 0))
759783
width_pixels = int(np.round(width / pixelsize, 0))
760784
window_size = width_pixels * max(interp_factor, 1)
761-
785+
786+
# Fast path: all profiles share the same length (typical case when the
787+
# caller uses ``fast_profile_lines`` with uniform endpoints). Build the
788+
# interpolator **once** over the full (N, L) batch instead of once per
789+
# profile — Akima construction was the dominant cost (~37% of
790+
# ``analyze_sarcomere_vectors`` end-to-end on real movies).
791+
lengths0 = len(profiles[0])
792+
uniform_length = all(len(p) == lengths0 for p in profiles)
793+
794+
if uniform_length and lengths0 >= 2:
795+
L = lengths0
796+
# Preserve the caller's dtype through normalization so the batched
797+
# path matches per-profile numerics bit-for-bit.
798+
y_mat = np.stack(profiles, axis=0)
799+
if y_mat.ndim != 2:
800+
y_mat = y_mat.reshape(n_profiles, L)
801+
pmin = y_mat.min(axis=1)
802+
pmax = y_mat.max(axis=1)
803+
flat_mask = pmax <= pmin
804+
denom = np.where(flat_mask, y_mat.dtype.type(1.0), pmax - pmin)
805+
y_norm = (y_mat - pmin[:, None]) / denom[:, None]
806+
807+
pos_array = np.arange(L) * pixelsize
808+
if interp_factor >= 1:
809+
L_up = L * interp_factor
810+
x_interp = np.linspace(pos_array[0], pos_array[-1], num=L_up)
811+
if interpolation_method == 'akima':
812+
# One construct on the whole batch (scipy supports axis=).
813+
itp = Akima1DInterpolator(pos_array, y_norm, axis=1)
814+
y_up = itp(x_interp)
815+
else:
816+
# Batched linear interpolation — equivalent to per-row
817+
# ``np.interp`` but without the Python loop.
818+
y_up = _batched_linear_interp(pos_array, y_norm, x_interp)
819+
actual_interp_factor = interp_factor
820+
else:
821+
L_up = L
822+
y_up = y_norm
823+
x_interp = pos_array
824+
actual_interp_factor = 1
825+
826+
peak_distance = max(1, min_dist_pixel * actual_interp_factor)
827+
center = (pos_array[-1] + pos_array[0]) * 0.5
828+
829+
for i in range(n_profiles):
830+
if flat_mask[i]:
831+
sarcomere_lengths[i] = np.nan
832+
center_offsets[i] = np.nan
833+
continue
834+
835+
y_interp = y_up[i]
836+
peaks_idx, _ = find_peaks(
837+
y_interp, height=thres, distance=peak_distance,
838+
prominence=prominence,
839+
)
840+
if len(peaks_idx) < 2:
841+
sarcomere_lengths[i] = np.nan
842+
center_offsets[i] = np.nan
843+
continue
844+
845+
peaks = np.empty(len(peaks_idx), dtype=np.float64)
846+
for j, idx in enumerate(peaks_idx):
847+
start = max(0, idx - window_size)
848+
end = min(L_up, idx + window_size + 1)
849+
x_window = x_interp[start:end]
850+
y_window = y_interp[start:end] - y_interp[start:end].min()
851+
y_sum = y_window.sum()
852+
if y_sum > 0:
853+
peaks[j] = np.dot(x_window, y_window) / y_sum
854+
else:
855+
peaks[j] = x_interp[idx]
856+
857+
left_mask = peaks < center
858+
right_mask = ~left_mask
859+
if not (left_mask.any() and right_mask.any()):
860+
sarcomere_lengths[i] = np.nan
861+
center_offsets[i] = np.nan
862+
continue
863+
left_peak = peaks[left_mask][-1]
864+
right_peak = peaks[right_mask][0]
865+
slen_profile = right_peak - left_peak
866+
if slen_lims[0] <= slen_profile <= slen_lims[1]:
867+
sarcomere_lengths[i] = slen_profile
868+
center_offsets[i] = (left_peak + right_peak) * 0.5 - center
869+
else:
870+
sarcomere_lengths[i] = np.nan
871+
center_offsets[i] = np.nan
872+
873+
return sarcomere_lengths, center_offsets
874+
875+
# Fallback: variable-length profiles — retain the original per-profile
876+
# path. Rare in production since the common caller uses uniform
877+
# endpoints, but needed for external callers that pass in ragged lists.
762878
for i, profile in enumerate(profiles):
763-
# Normalize profile to [0,1] range
764879
pmin = profile.min()
765880
pmax = profile.max()
766881
if pmax == pmin:
767882
sarcomere_lengths[i] = np.nan
768883
center_offsets[i] = np.nan
769884
continue
770-
885+
771886
profile_norm = (profile - pmin) / (pmax - pmin)
772-
773-
# Create position array
774887
pos_array = np.arange(len(profile)) * pixelsize
775-
888+
776889
if interp_factor >= 1:
777-
# Use selected interpolation method
778890
x_interp = np.linspace(pos_array[0], pos_array[-1],
779891
num=len(profile) * interp_factor)
780892
if interpolation_method == 'akima':
781-
# Akima interpolation for smoother profiles (slower)
782893
interp_func = Akima1DInterpolator(pos_array, profile_norm)
783894
y_interp = interp_func(x_interp)
784895
else:
785-
# Linear interpolation (faster, default)
786896
y_interp = np.interp(x_interp, pos_array, profile_norm)
787897
actual_interp_factor = interp_factor
788898
else:
789899
y_interp = profile_norm
790900
x_interp = pos_array
791901
actual_interp_factor = 1
792-
793-
# Find peaks (ensure distance >= 1)
902+
794903
peak_distance = max(1, min_dist_pixel * actual_interp_factor)
795904
peaks_idx, _ = find_peaks(y_interp,
796905
height=thres,
797906
distance=peak_distance,
798907
prominence=prominence)
799-
908+
800909
if len(peaks_idx) < 2:
801910
sarcomere_lengths[i] = np.nan
802911
center_offsets[i] = np.nan
803912
continue
804-
805-
# Calculate refined peak positions using center of mass
913+
806914
peaks = np.empty(len(peaks_idx), dtype=np.float64)
807915
for j, idx in enumerate(peaks_idx):
808916
start = max(0, idx - window_size)
@@ -815,31 +923,27 @@ def process_profiles_batch(
815923
peaks[j] = np.dot(x_window, y_window) / y_sum
816924
else:
817925
peaks[j] = x_interp[idx]
818-
926+
819927
center = (pos_array[-1] + pos_array[0]) * 0.5
820-
821-
# Split peaks into left and right of center
822928
left_mask = peaks < center
823929
right_mask = peaks >= center
824-
825930
if not (left_mask.any() and right_mask.any()):
826931
sarcomere_lengths[i] = np.nan
827932
center_offsets[i] = np.nan
828933
continue
829-
830-
# Take rightmost peak from left side and leftmost peak from right side
934+
831935
left_peak = peaks[left_mask][-1]
832936
right_peak = peaks[right_mask][0]
833937
slen_profile = right_peak - left_peak
834938
center_offset = (left_peak + right_peak) * 0.5 - center
835-
939+
836940
if slen_lims[0] <= slen_profile <= slen_lims[1]:
837941
sarcomere_lengths[i] = slen_profile
838942
center_offsets[i] = center_offset
839943
else:
840944
sarcomere_lengths[i] = np.nan
841945
center_offsets[i] = np.nan
842-
946+
843947
return sarcomere_lengths, center_offsets
844948

845949
@staticmethod

0 commit comments

Comments
 (0)