@@ -197,32 +197,44 @@ def test_mirror_labels_invalid_axis(self):
197197 mirror_labels (image , axis = 2 )
198198
199199 def test_keep_slice_range_by_area_basic (self ):
200- """Keep slices between minimum and maximum area along default axis """
200+ """Keep label content between minimum and maximum area, preserving shape """
201201 volume = np .zeros ((5 , 4 , 4 ), dtype = np .int32 )
202- volume [0 , 0 , 0 ] = 1 # area 1
202+ volume [0 , 0 , 0 ] = 1 # area 1 (min)
203203 volume [1 , :2 , :2 ] = 1 # area 4
204204 volume [2 , :3 , :3 ] = 1 # area 9 (max)
205205 volume [3 , :1 , :3 ] = 1 # area 3
206206 volume [4 , :2 , :1 ] = 1 # area 2
207207
208208 result = keep_slice_range_by_area (volume )
209209
210- assert result .shape == (3 , 4 , 4 )
211- np .testing .assert_array_equal (result , volume [0 :3 ])
210+ # Shape should be preserved
211+ assert result .shape == (5 , 4 , 4 )
212+ # Content between min (slice 0) and max (slice 2) should be kept
213+ np .testing .assert_array_equal (result [0 :3 ], volume [0 :3 ])
214+ # Content after max should be zeroed
215+ np .testing .assert_array_equal (result [3 :], np .zeros ((2 , 4 , 4 )))
212216
213217 def test_keep_slice_range_by_area_with_axis (self ):
214- """Axis parameter allows trimming along any dimension"""
215- base = np .zeros ((5 , 4 , 4 ), dtype = np .uint16 )
216- base [0 , 0 , 0 ] = 1
217- base [1 , :2 , :2 ] = 1
218- base [2 , :3 , :3 ] = 1
219- reordered = base .transpose (1 , 0 , 2 )
220-
221- result = keep_slice_range_by_area (reordered , axis = 1 )
222-
223- expected = reordered [:, 0 :3 , :]
224- assert result .shape == expected .shape
225- np .testing .assert_array_equal (result , expected )
218+ """Axis parameter allows zeroing content along any dimension while preserving shape"""
219+ # Create volume with different areas along axis 1
220+ volume = np .zeros ((4 , 5 , 3 ), dtype = np .uint16 )
221+ volume [:2 , 0 , :2 ] = 1 # slice 0: area = 2*2 = 4
222+ volume [:, 1 , :] = 1 # slice 1: area = 4*3 = 12 (max)
223+ volume [:3 , 2 , :] = 1 # slice 2: area = 3*3 = 9
224+ volume [:2 , 3 , :2 ] = 1 # slice 3: area = 2*2 = 4
225+ volume [0 , 4 , 0 ] = 1 # slice 4: area = 1 (min)
226+
227+ result = keep_slice_range_by_area (volume , axis = 1 )
228+
229+ # Shape should be preserved
230+ assert result .shape == volume .shape
231+ # Min area is at slice 4, max area is at slice 1, so range is 1-4 (inclusive)
232+ # Slice 0 should be zeroed (before the range)
233+ np .testing .assert_array_equal (
234+ result [:, 0 , :], np .zeros ((4 , 3 ), dtype = np .uint16 )
235+ )
236+ # Slices 1-4 should be kept
237+ np .testing .assert_array_equal (result [:, 1 :5 , :], volume [:, 1 :5 , :])
226238
227239 def test_keep_slice_range_by_area_uniform (self ):
228240 """Uniform area returns the original volume"""
@@ -232,6 +244,42 @@ def test_keep_slice_range_by_area_uniform(self):
232244
233245 np .testing .assert_array_equal (result , volume )
234246
247+ def test_keep_slice_range_by_area_shape_preserved (self ):
248+ """Verify that output shape matches input shape (critical for image-label alignment)"""
249+ # Simulate a label volume with 100 z-slices where labels exist in slices 20-80
250+ volume = np .zeros ((100 , 50 , 50 ), dtype = np .uint32 )
251+ volume [20 , :10 , :10 ] = 1 # Sparse content at slice 20 (min area)
252+ for i in range (21 , 80 ):
253+ volume [i , :30 , :30 ] = i # Denser content in middle slices
254+ volume [79 , :, :] = 100 # Maximum content at slice 79 (max area)
255+ # Slices 0-19 and 80-99 should be empty and get zeroed
256+
257+ result = keep_slice_range_by_area (volume , axis = 0 )
258+
259+ # Critical: shape must be preserved to maintain alignment with image data
260+ assert result .shape == (
261+ 100 ,
262+ 50 ,
263+ 50 ,
264+ ), "Output shape must match input shape"
265+
266+ # Slices before min (0-19) should be zeroed
267+ assert np .all (
268+ result [:20 ] == 0
269+ ), "Slices before min-area slice should be zeroed"
270+
271+ # Slices between min and max (20-79) should be preserved
272+ np .testing .assert_array_equal (
273+ result [20 :80 ],
274+ volume [20 :80 ],
275+ err_msg = "Label content in range should be preserved" ,
276+ )
277+
278+ # Slices after max (80-99) should be zeroed
279+ assert np .all (
280+ result [80 :] == 0
281+ ), "Slices after max-area slice should be zeroed"
282+
235283 def test_keep_slice_range_by_area_invalid_dims (self ):
236284 """At least 3 dimensions are required"""
237285 image = np .ones ((4 , 4 ), dtype = np .uint8 )
0 commit comments