Skip to content

Commit d9c4a5b

Browse files
committed
Add comprehensive test suite for convpaint functionality
Tests cover: - Function registration and parameters - Parameter validation (model_path, downsample, output_type) - Semantic to instance segmentation conversion (2D/3D) - Background label removal - Environment manager module detection
1 parent 55562d9 commit d9c4a5b

1 file changed

Lines changed: 217 additions & 0 deletions

File tree

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# src/napari_tmidas/_tests/test_convpaint.py
2+
import numpy as np
3+
import pytest
4+
5+
from napari_tmidas._registry import BatchProcessingRegistry
6+
from napari_tmidas.processing_functions import discover_and_load_processing_functions
7+
8+
9+
class TestConvpaintPrediction:
10+
"""Test convpaint prediction functionality."""
11+
12+
def test_convpaint_function_registered(self):
13+
"""Test that convpaint function is registered."""
14+
# Ensure processing functions are loaded
15+
discover_and_load_processing_functions()
16+
17+
functions = BatchProcessingRegistry.list_functions()
18+
assert "Convpaint Prediction" in functions
19+
20+
def test_convpaint_parameters(self):
21+
"""Test convpaint function has correct parameters."""
22+
# Ensure processing functions are loaded
23+
discover_and_load_processing_functions()
24+
25+
func_info = BatchProcessingRegistry.get_function_info(
26+
"Convpaint Prediction"
27+
)
28+
assert func_info is not None
29+
30+
params = func_info["parameters"]
31+
assert "model_path" in params
32+
assert "image_downsample" in params
33+
assert "output_type" in params
34+
assert "background_label" in params
35+
assert "use_cpu" in params
36+
assert "force_dedicated_env" in params
37+
# Check parameter defaults
38+
assert params["image_downsample"]["default"] == 2
39+
assert params["output_type"]["default"] == "semantic"
40+
assert params["background_label"]["default"] == 1
41+
assert params["use_cpu"]["default"] is False
42+
43+
def test_convpaint_output_type_options(self):
44+
"""Test output_type has correct options."""
45+
# Ensure processing functions are loaded
46+
discover_and_load_processing_functions()
47+
48+
func_info = BatchProcessingRegistry.get_function_info(
49+
"Convpaint Prediction"
50+
)
51+
params = func_info["parameters"]
52+
assert "options" in params["output_type"]
53+
assert params["output_type"]["options"] == ["semantic", "instance"]
54+
55+
def test_convpaint_missing_model_path(self):
56+
"""Test that missing model_path raises ValueError."""
57+
from napari_tmidas.processing_functions.convpaint_prediction import (
58+
convpaint_predict,
59+
)
60+
61+
image = np.random.randint(0, 255, (100, 100), dtype=np.uint8)
62+
63+
with pytest.raises(ValueError, match="model_path"):
64+
convpaint_predict(image, model_path="")
65+
66+
def test_convpaint_invalid_model_path(self):
67+
"""Test that invalid model_path raises ValueError."""
68+
from napari_tmidas.processing_functions.convpaint_prediction import (
69+
convpaint_predict,
70+
)
71+
72+
image = np.random.randint(0, 255, (100, 100), dtype=np.uint8)
73+
74+
with pytest.raises(ValueError, match="not found"):
75+
convpaint_predict(image, model_path="/nonexistent/model.pkl")
76+
77+
def test_convpaint_invalid_output_type(self):
78+
"""Test that invalid output_type raises ValueError."""
79+
from napari_tmidas.processing_functions.convpaint_prediction import (
80+
convpaint_predict,
81+
)
82+
83+
image = np.random.randint(0, 255, (100, 100), dtype=np.uint8)
84+
85+
# Create a temporary model file
86+
import tempfile
87+
88+
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
89+
model_path = f.name
90+
91+
try:
92+
with pytest.raises(ValueError, match="output_type"):
93+
convpaint_predict(
94+
image, model_path=model_path, output_type="invalid"
95+
)
96+
finally:
97+
import os
98+
99+
os.unlink(model_path)
100+
101+
def test_convpaint_invalid_downsample(self):
102+
"""Test that invalid image_downsample raises ValueError."""
103+
from napari_tmidas.processing_functions.convpaint_prediction import (
104+
convpaint_predict,
105+
)
106+
107+
image = np.random.randint(0, 255, (100, 100), dtype=np.uint8)
108+
109+
# Create a temporary model file
110+
import tempfile
111+
112+
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
113+
model_path = f.name
114+
115+
try:
116+
with pytest.raises(ValueError, match="image_downsample"):
117+
convpaint_predict(
118+
image, model_path=model_path, image_downsample=0
119+
)
120+
finally:
121+
import os
122+
123+
os.unlink(model_path)
124+
125+
def test_semantic_to_instance_conversion_2d(self):
126+
"""Test semantic to instance conversion for 2D images."""
127+
from napari_tmidas.processing_functions.convpaint_prediction import (
128+
_convert_semantic_to_instance,
129+
)
130+
131+
# Create a simple 2D semantic mask with two classes
132+
image = np.zeros((50, 50), dtype=np.uint8)
133+
image[10:20, 10:20] = 1 # Class 1 object
134+
image[30:40, 30:40] = 2 # Class 2 object
135+
136+
result = _convert_semantic_to_instance(image)
137+
138+
# Should have 2 unique labels plus background
139+
unique_labels = np.unique(result)
140+
assert len(unique_labels) == 3 # 0 (background), 1, 2
141+
assert 0 in unique_labels
142+
143+
def test_semantic_to_instance_conversion_3d(self):
144+
"""Test semantic to instance conversion for 3D images."""
145+
from napari_tmidas.processing_functions.convpaint_prediction import (
146+
_convert_semantic_to_instance,
147+
)
148+
149+
# Create a simple 3D semantic mask (small Z stack)
150+
image = np.zeros((5, 50, 50), dtype=np.uint8)
151+
image[1:3, 10:20, 10:20] = 1 # 3D object class 1
152+
image[2:4, 30:40, 30:40] = 2 # 3D object class 2
153+
154+
result = _convert_semantic_to_instance(image)
155+
156+
# Should process as 3D volume
157+
assert result.shape == image.shape
158+
unique_labels = np.unique(result)
159+
assert len(unique_labels) >= 2 # At least background and 1+ objects
160+
161+
def test_background_label_removal(self):
162+
"""Test that background label is correctly removed."""
163+
from napari_tmidas.processing_functions.convpaint_prediction import (
164+
_convert_semantic_to_instance,
165+
)
166+
167+
# Create semantic mask with specific background label
168+
image = np.ones((50, 50), dtype=np.uint8) # Background = 1
169+
image[10:20, 10:20] = 2 # Foreground = 2
170+
image[30:40, 30:40] = 3 # Foreground = 3
171+
172+
# Simulate background removal (this happens in main function)
173+
image[image == 1] = 0
174+
175+
result = _convert_semantic_to_instance(image)
176+
177+
# Background should be 0
178+
assert result[0, 0] == 0
179+
assert result[49, 49] == 0
180+
# Objects should have non-zero labels
181+
assert result[15, 15] > 0
182+
assert result[35, 35] > 0
183+
184+
185+
class TestConvpaintEnvManager:
186+
"""Test convpaint environment manager."""
187+
188+
def test_env_manager_exists(self):
189+
"""Test that environment manager module exists."""
190+
from napari_tmidas.processing_functions import convpaint_env_manager
191+
192+
assert convpaint_env_manager is not None
193+
194+
def test_env_manager_functions(self):
195+
"""Test that required functions exist."""
196+
from napari_tmidas.processing_functions.convpaint_env_manager import (
197+
create_convpaint_env,
198+
get_env_python_path,
199+
is_convpaint_installed,
200+
is_env_created,
201+
)
202+
203+
# Just check they exist and are callable
204+
assert callable(is_convpaint_installed)
205+
assert callable(is_env_created)
206+
assert callable(get_env_python_path)
207+
assert callable(create_convpaint_env)
208+
209+
def test_convpaint_not_installed_initially(self):
210+
"""Test convpaint detection."""
211+
from napari_tmidas.processing_functions.convpaint_env_manager import (
212+
is_convpaint_installed,
213+
)
214+
215+
# This will return True or False depending on environment
216+
result = is_convpaint_installed()
217+
assert isinstance(result, bool)

0 commit comments

Comments
 (0)