99
1010from .base_widget import BaseWidget
1111from .util import _load_custom_model , _available_devices , _get_current_tiling
12- from ..model_utils import get_model , get_model_registry , get_device , get_default_tiling
12+ from ..model_utils import (
13+ get_model , get_model_registry , get_device , get_default_tiling , get_default_segmentation_settings
14+ )
1315
1416
15- # TODO Expose segmentation kwargs.
16- def _run_segmentation (image , model , model_type , tiling , device ):
17+ def _run_segmentation (image , model , model_type , tiling , device , min_size ):
1718 block_shape = [tiling ["tile" ][ax ] for ax in "zyx" ]
1819 halo = [tiling ["halo" ][ax ] for ax in "zyx" ]
1920 prediction = predict_with_halo (
2021 image , model , gpu_ids = [device ], block_shape = block_shape , halo = halo ,
2122 tqdm_desc = "Run prediction"
2223 )
24+ settings = get_default_segmentation_settings (model_type )
25+ foreground_threshold = settings .pop ("fg_threshold" , 0.5 )
26+ settings .pop ("seg_class" , None )
27+ settings = {name : 1.0 if val is None else val for name , val in settings .items ()}
2328 foreground_map , center_distances , boundary_distances = prediction
2429 segmentation = watershed_from_center_and_boundary_distances (
2530 center_distances , boundary_distances , foreground_map ,
26- center_distance_threshold = 0.5 ,
27- boundary_distance_threshold = 0.5 ,
28- foreground_threshold = 0.5 ,
29- distance_smoothing = 1.6 ,
30- min_size = 100 ,
31+ min_size = min_size , foreground_threshold = foreground_threshold ,
32+ ** settings ,
3133 )
3234 return segmentation
3335
@@ -110,7 +112,9 @@ def on_predict(self):
110112
111113 # Get the current tiling.
112114 self .tiling = _get_current_tiling (self .tiling , self .default_tiling , model_type )
113- segmentation = _run_segmentation (image , model = model , model_type = model_type , tiling = self .tiling , device = device )
115+ segmentation = _run_segmentation (
116+ image , model = model , model_type = model_type , tiling = self .tiling , device = device , min_size = self .min_size
117+ )
114118
115119 self .viewer .add_labels (segmentation , name = model_type )
116120 show_info (f"INFO: Segmentation of { model_type } added to layers." )
@@ -120,6 +124,11 @@ def _create_settings_widget(self):
120124 # setting_values.setToolTip(get_tooltip("embedding", "settings"))
121125 setting_values .setLayout (QVBoxLayout ())
122126
127+ # Create UI for the min-size parameter.
128+ self .min_size = 100
129+ self .min_size_menu , layout = self ._add_int_param ("min_size" , self .min_size , 0 , 10000 )
130+ setting_values .layout ().addLayout (layout )
131+
123132 # Create UI for the device.
124133 device = "auto"
125134 device_options = ["auto" ] + _available_devices ()
0 commit comments