|
| 1 | +# src/napari_tmidas/_tests/test_convpaint.py |
| 2 | +import numpy as np |
| 3 | +import pytest |
| 4 | + |
| 5 | +from napari_tmidas._registry import BatchProcessingRegistry |
| 6 | +from napari_tmidas.processing_functions import discover_and_load_processing_functions |
| 7 | + |
| 8 | + |
| 9 | +class TestConvpaintPrediction: |
| 10 | + """Test convpaint prediction functionality.""" |
| 11 | + |
| 12 | + def test_convpaint_function_registered(self): |
| 13 | + """Test that convpaint function is registered.""" |
| 14 | + # Ensure processing functions are loaded |
| 15 | + discover_and_load_processing_functions() |
| 16 | + |
| 17 | + functions = BatchProcessingRegistry.list_functions() |
| 18 | + assert "Convpaint Prediction" in functions |
| 19 | + |
| 20 | + def test_convpaint_parameters(self): |
| 21 | + """Test convpaint function has correct parameters.""" |
| 22 | + # Ensure processing functions are loaded |
| 23 | + discover_and_load_processing_functions() |
| 24 | + |
| 25 | + func_info = BatchProcessingRegistry.get_function_info( |
| 26 | + "Convpaint Prediction" |
| 27 | + ) |
| 28 | + assert func_info is not None |
| 29 | + |
| 30 | + params = func_info["parameters"] |
| 31 | + assert "model_path" in params |
| 32 | + assert "image_downsample" in params |
| 33 | + assert "output_type" in params |
| 34 | + assert "background_label" in params |
| 35 | + assert "use_cpu" in params |
| 36 | + assert "force_dedicated_env" in params |
| 37 | + # Check parameter defaults |
| 38 | + assert params["image_downsample"]["default"] == 2 |
| 39 | + assert params["output_type"]["default"] == "semantic" |
| 40 | + assert params["background_label"]["default"] == 1 |
| 41 | + assert params["use_cpu"]["default"] is False |
| 42 | + |
| 43 | + def test_convpaint_output_type_options(self): |
| 44 | + """Test output_type has correct options.""" |
| 45 | + # Ensure processing functions are loaded |
| 46 | + discover_and_load_processing_functions() |
| 47 | + |
| 48 | + func_info = BatchProcessingRegistry.get_function_info( |
| 49 | + "Convpaint Prediction" |
| 50 | + ) |
| 51 | + params = func_info["parameters"] |
| 52 | + assert "options" in params["output_type"] |
| 53 | + assert params["output_type"]["options"] == ["semantic", "instance"] |
| 54 | + |
| 55 | + def test_convpaint_missing_model_path(self): |
| 56 | + """Test that missing model_path raises ValueError.""" |
| 57 | + from napari_tmidas.processing_functions.convpaint_prediction import ( |
| 58 | + convpaint_predict, |
| 59 | + ) |
| 60 | + |
| 61 | + image = np.random.randint(0, 255, (100, 100), dtype=np.uint8) |
| 62 | + |
| 63 | + with pytest.raises(ValueError, match="model_path"): |
| 64 | + convpaint_predict(image, model_path="") |
| 65 | + |
| 66 | + def test_convpaint_invalid_model_path(self): |
| 67 | + """Test that invalid model_path raises ValueError.""" |
| 68 | + from napari_tmidas.processing_functions.convpaint_prediction import ( |
| 69 | + convpaint_predict, |
| 70 | + ) |
| 71 | + |
| 72 | + image = np.random.randint(0, 255, (100, 100), dtype=np.uint8) |
| 73 | + |
| 74 | + with pytest.raises(ValueError, match="not found"): |
| 75 | + convpaint_predict(image, model_path="/nonexistent/model.pkl") |
| 76 | + |
| 77 | + def test_convpaint_invalid_output_type(self): |
| 78 | + """Test that invalid output_type raises ValueError.""" |
| 79 | + from napari_tmidas.processing_functions.convpaint_prediction import ( |
| 80 | + convpaint_predict, |
| 81 | + ) |
| 82 | + |
| 83 | + image = np.random.randint(0, 255, (100, 100), dtype=np.uint8) |
| 84 | + |
| 85 | + # Create a temporary model file |
| 86 | + import tempfile |
| 87 | + |
| 88 | + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f: |
| 89 | + model_path = f.name |
| 90 | + |
| 91 | + try: |
| 92 | + with pytest.raises(ValueError, match="output_type"): |
| 93 | + convpaint_predict( |
| 94 | + image, model_path=model_path, output_type="invalid" |
| 95 | + ) |
| 96 | + finally: |
| 97 | + import os |
| 98 | + |
| 99 | + os.unlink(model_path) |
| 100 | + |
| 101 | + def test_convpaint_invalid_downsample(self): |
| 102 | + """Test that invalid image_downsample raises ValueError.""" |
| 103 | + from napari_tmidas.processing_functions.convpaint_prediction import ( |
| 104 | + convpaint_predict, |
| 105 | + ) |
| 106 | + |
| 107 | + image = np.random.randint(0, 255, (100, 100), dtype=np.uint8) |
| 108 | + |
| 109 | + # Create a temporary model file |
| 110 | + import tempfile |
| 111 | + |
| 112 | + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f: |
| 113 | + model_path = f.name |
| 114 | + |
| 115 | + try: |
| 116 | + with pytest.raises(ValueError, match="image_downsample"): |
| 117 | + convpaint_predict( |
| 118 | + image, model_path=model_path, image_downsample=0 |
| 119 | + ) |
| 120 | + finally: |
| 121 | + import os |
| 122 | + |
| 123 | + os.unlink(model_path) |
| 124 | + |
| 125 | + def test_semantic_to_instance_conversion_2d(self): |
| 126 | + """Test semantic to instance conversion for 2D images.""" |
| 127 | + from napari_tmidas.processing_functions.convpaint_prediction import ( |
| 128 | + _convert_semantic_to_instance, |
| 129 | + ) |
| 130 | + |
| 131 | + # Create a simple 2D semantic mask with two classes |
| 132 | + image = np.zeros((50, 50), dtype=np.uint8) |
| 133 | + image[10:20, 10:20] = 1 # Class 1 object |
| 134 | + image[30:40, 30:40] = 2 # Class 2 object |
| 135 | + |
| 136 | + result = _convert_semantic_to_instance(image) |
| 137 | + |
| 138 | + # Should have 2 unique labels plus background |
| 139 | + unique_labels = np.unique(result) |
| 140 | + assert len(unique_labels) == 3 # 0 (background), 1, 2 |
| 141 | + assert 0 in unique_labels |
| 142 | + |
| 143 | + def test_semantic_to_instance_conversion_3d(self): |
| 144 | + """Test semantic to instance conversion for 3D images.""" |
| 145 | + from napari_tmidas.processing_functions.convpaint_prediction import ( |
| 146 | + _convert_semantic_to_instance, |
| 147 | + ) |
| 148 | + |
| 149 | + # Create a simple 3D semantic mask (small Z stack) |
| 150 | + image = np.zeros((5, 50, 50), dtype=np.uint8) |
| 151 | + image[1:3, 10:20, 10:20] = 1 # 3D object class 1 |
| 152 | + image[2:4, 30:40, 30:40] = 2 # 3D object class 2 |
| 153 | + |
| 154 | + result = _convert_semantic_to_instance(image) |
| 155 | + |
| 156 | + # Should process as 3D volume |
| 157 | + assert result.shape == image.shape |
| 158 | + unique_labels = np.unique(result) |
| 159 | + assert len(unique_labels) >= 2 # At least background and 1+ objects |
| 160 | + |
| 161 | + def test_background_label_removal(self): |
| 162 | + """Test that background label is correctly removed.""" |
| 163 | + from napari_tmidas.processing_functions.convpaint_prediction import ( |
| 164 | + _convert_semantic_to_instance, |
| 165 | + ) |
| 166 | + |
| 167 | + # Create semantic mask with specific background label |
| 168 | + image = np.ones((50, 50), dtype=np.uint8) # Background = 1 |
| 169 | + image[10:20, 10:20] = 2 # Foreground = 2 |
| 170 | + image[30:40, 30:40] = 3 # Foreground = 3 |
| 171 | + |
| 172 | + # Simulate background removal (this happens in main function) |
| 173 | + image[image == 1] = 0 |
| 174 | + |
| 175 | + result = _convert_semantic_to_instance(image) |
| 176 | + |
| 177 | + # Background should be 0 |
| 178 | + assert result[0, 0] == 0 |
| 179 | + assert result[49, 49] == 0 |
| 180 | + # Objects should have non-zero labels |
| 181 | + assert result[15, 15] > 0 |
| 182 | + assert result[35, 35] > 0 |
| 183 | + |
| 184 | + |
| 185 | +class TestConvpaintEnvManager: |
| 186 | + """Test convpaint environment manager.""" |
| 187 | + |
| 188 | + def test_env_manager_exists(self): |
| 189 | + """Test that environment manager module exists.""" |
| 190 | + from napari_tmidas.processing_functions import convpaint_env_manager |
| 191 | + |
| 192 | + assert convpaint_env_manager is not None |
| 193 | + |
| 194 | + def test_env_manager_functions(self): |
| 195 | + """Test that required functions exist.""" |
| 196 | + from napari_tmidas.processing_functions.convpaint_env_manager import ( |
| 197 | + create_convpaint_env, |
| 198 | + get_env_python_path, |
| 199 | + is_convpaint_installed, |
| 200 | + is_env_created, |
| 201 | + ) |
| 202 | + |
| 203 | + # Just check they exist and are callable |
| 204 | + assert callable(is_convpaint_installed) |
| 205 | + assert callable(is_env_created) |
| 206 | + assert callable(get_env_python_path) |
| 207 | + assert callable(create_convpaint_env) |
| 208 | + |
| 209 | + def test_convpaint_not_installed_initially(self): |
| 210 | + """Test convpaint detection.""" |
| 211 | + from napari_tmidas.processing_functions.convpaint_env_manager import ( |
| 212 | + is_convpaint_installed, |
| 213 | + ) |
| 214 | + |
| 215 | + # This will return True or False depending on environment |
| 216 | + result = is_convpaint_installed() |
| 217 | + assert isinstance(result, bool) |
0 commit comments