Skip to content

Commit 7f358fe

Browse files
committed
fixed bug in keep_slice_range_by_area
1 parent 9b48aa2 commit 7f358fe

2 files changed

Lines changed: 86 additions & 25 deletions

File tree

src/napari_tmidas/_tests/test_processing_basic.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,32 +197,44 @@ def test_mirror_labels_invalid_axis(self):
197197
mirror_labels(image, axis=2)
198198

199199
def test_keep_slice_range_by_area_basic(self):
200-
"""Keep slices between minimum and maximum area along default axis"""
200+
"""Keep label content between minimum and maximum area, preserving shape"""
201201
volume = np.zeros((5, 4, 4), dtype=np.int32)
202-
volume[0, 0, 0] = 1 # area 1
202+
volume[0, 0, 0] = 1 # area 1 (min)
203203
volume[1, :2, :2] = 1 # area 4
204204
volume[2, :3, :3] = 1 # area 9 (max)
205205
volume[3, :1, :3] = 1 # area 3
206206
volume[4, :2, :1] = 1 # area 2
207207

208208
result = keep_slice_range_by_area(volume)
209209

210-
assert result.shape == (3, 4, 4)
211-
np.testing.assert_array_equal(result, volume[0:3])
210+
# Shape should be preserved
211+
assert result.shape == (5, 4, 4)
212+
# Content between min (slice 0) and max (slice 2) should be kept
213+
np.testing.assert_array_equal(result[0:3], volume[0:3])
214+
# Content after max should be zeroed
215+
np.testing.assert_array_equal(result[3:], np.zeros((2, 4, 4)))
212216

213217
def test_keep_slice_range_by_area_with_axis(self):
214-
"""Axis parameter allows trimming along any dimension"""
215-
base = np.zeros((5, 4, 4), dtype=np.uint16)
216-
base[0, 0, 0] = 1
217-
base[1, :2, :2] = 1
218-
base[2, :3, :3] = 1
219-
reordered = base.transpose(1, 0, 2)
220-
221-
result = keep_slice_range_by_area(reordered, axis=1)
222-
223-
expected = reordered[:, 0:3, :]
224-
assert result.shape == expected.shape
225-
np.testing.assert_array_equal(result, expected)
218+
"""Axis parameter allows zeroing content along any dimension while preserving shape"""
219+
# Create volume with different areas along axis 1
220+
volume = np.zeros((4, 5, 3), dtype=np.uint16)
221+
volume[:2, 0, :2] = 1 # slice 0: area = 2*2 = 4
222+
volume[:, 1, :] = 1 # slice 1: area = 4*3 = 12 (max)
223+
volume[:3, 2, :] = 1 # slice 2: area = 3*3 = 9
224+
volume[:2, 3, :2] = 1 # slice 3: area = 2*2 = 4
225+
volume[0, 4, 0] = 1 # slice 4: area = 1 (min)
226+
227+
result = keep_slice_range_by_area(volume, axis=1)
228+
229+
# Shape should be preserved
230+
assert result.shape == volume.shape
231+
# Min area is at slice 4, max area is at slice 1, so range is 1-4 (inclusive)
232+
# Slice 0 should be zeroed (before the range)
233+
np.testing.assert_array_equal(
234+
result[:, 0, :], np.zeros((4, 3), dtype=np.uint16)
235+
)
236+
# Slices 1-4 should be kept
237+
np.testing.assert_array_equal(result[:, 1:5, :], volume[:, 1:5, :])
226238

227239
def test_keep_slice_range_by_area_uniform(self):
228240
"""Uniform area returns the original volume"""
@@ -232,6 +244,42 @@ def test_keep_slice_range_by_area_uniform(self):
232244

233245
np.testing.assert_array_equal(result, volume)
234246

247+
def test_keep_slice_range_by_area_shape_preserved(self):
248+
"""Verify that output shape matches input shape (critical for image-label alignment)"""
249+
# Simulate a label volume with 100 z-slices where labels exist in slices 20-80
250+
volume = np.zeros((100, 50, 50), dtype=np.uint32)
251+
volume[20, :10, :10] = 1 # Sparse content at slice 20 (min area)
252+
for i in range(21, 80):
253+
volume[i, :30, :30] = i # Denser content in middle slices
254+
volume[79, :, :] = 100 # Maximum content at slice 79 (max area)
255+
# Slices 0-19 and 80-99 should be empty and get zeroed
256+
257+
result = keep_slice_range_by_area(volume, axis=0)
258+
259+
# Critical: shape must be preserved to maintain alignment with image data
260+
assert result.shape == (
261+
100,
262+
50,
263+
50,
264+
), "Output shape must match input shape"
265+
266+
# Slices before min (0-19) should be zeroed
267+
assert np.all(
268+
result[:20] == 0
269+
), "Slices before min-area slice should be zeroed"
270+
271+
# Slices between min and max (20-79) should be preserved
272+
np.testing.assert_array_equal(
273+
result[20:80],
274+
volume[20:80],
275+
err_msg="Label content in range should be preserved",
276+
)
277+
278+
# Slices after max (80-99) should be zeroed
279+
assert np.all(
280+
result[80:] == 0
281+
), "Slices after max-area slice should be zeroed"
282+
235283
def test_keep_slice_range_by_area_invalid_dims(self):
236284
"""At least 3 dimensions are required"""
237285
image = np.ones((4, 4), dtype=np.uint8)

src/napari_tmidas/processing_functions/basic.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _load_label_file(path: str) -> np.ndarray:
304304
@BatchProcessingRegistry.register(
305305
name="Keep Slice Range by Area",
306306
suffix="_area_range",
307-
description="Keep only slices between the minimum and maximum non-zero area along the chosen axis",
307+
description="Zero out label content outside the min/max area slice range (preserves image shape for alignment)",
308308
parameters={
309309
"axis": {
310310
"type": int,
@@ -314,12 +314,12 @@ def _load_label_file(path: str) -> np.ndarray:
314314
},
315315
)
316316
def keep_slice_range_by_area(image: np.ndarray, axis: int = 0) -> np.ndarray:
317-
"""Return only the slices between the minimum-area and maximum-area slices (inclusive).
317+
"""Keep label content only between the minimum-area and maximum-area slices (inclusive).
318318
319319
The per-slice area is measured as the number of non-zero pixels in the slice. When all slices
320-
share the same area, the original volume is returned unchanged. Useful for trimming empty
321-
leading/trailing slices while preserving the region between the smallest and largest occupied
322-
slices.
320+
share the same area, the original volume is returned unchanged. This function preserves the
321+
original image dimensions but zeros out label content outside the detected range, ensuring
322+
alignment with corresponding image data is maintained.
323323
324324
Parameters
325325
----------
@@ -331,7 +331,8 @@ def keep_slice_range_by_area(image: np.ndarray, axis: int = 0) -> np.ndarray:
331331
Returns
332332
-------
333333
numpy.ndarray
334-
Volume containing only the slices between the minimum and maximum area slices (inclusive).
334+
Volume with the same shape as input, but with label content zeroed outside the
335+
minimum and maximum area slice range (inclusive).
335336
"""
336337

337338
if image.ndim < 3:
@@ -368,10 +369,22 @@ def keep_slice_range_by_area(image: np.ndarray, axis: int = 0) -> np.ndarray:
368369
start = min(min_idx, max_idx)
369370
end = max(min_idx, max_idx)
370371

371-
slicer = [slice(None)] * image.ndim
372-
slicer[axis] = slice(start, end + 1)
372+
# Create a copy of the full image to preserve shape
373+
result = image.copy()
373374

374-
return image[tuple(slicer)].copy()
375+
# Zero out slices before the start
376+
if start > 0:
377+
before_slicer = [slice(None)] * image.ndim
378+
before_slicer[axis] = slice(0, start)
379+
result[tuple(before_slicer)] = 0
380+
381+
# Zero out slices after the end
382+
if end < image.shape[axis] - 1:
383+
after_slicer = [slice(None)] * image.ndim
384+
after_slicer[axis] = slice(end + 1, None)
385+
result[tuple(after_slicer)] = 0
386+
387+
return result
375388

376389

377390
@BatchProcessingRegistry.register(

0 commit comments

Comments
 (0)