Skip to content

Commit e2421a6

Browse files
Update the segmentation settings (#84)
* Update SGN model * Update segmentation setting logic
1 parent bc80409 commit e2421a6

File tree

3 files changed

+65
-36
lines changed

3 files changed

+65
-36
lines changed

flamingo_tools/model_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,51 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
117117
return model
118118

119119

120+
def get_default_segmentation_settings(model_type: str) -> Dict[str, Union[str, float]]:
121+
"""Get the default settings for instance segmentation post-processing for a given model.
122+
123+
Args:
124+
model_type: The model. One of 'SGN', 'SGN-lowres', 'IHC', 'IHC-lowres'.
125+
126+
Returns:
127+
Dictionary with the default segmentation settings.
128+
"""
129+
all_default_kwargs = {
130+
"SGN": {
131+
"center_distance_threshold": 0.4,
132+
"boundary_distance_threshold": 0.5,
133+
"fg_threshold": 0.5,
134+
"distance_smoothing": 0.0,
135+
"seg_class": "sgn",
136+
},
137+
"SGN-lowres": {
138+
"center_distance_threshold": None,
139+
"boundary_distance_threshold": 0.5,
140+
"fg_threshold": 0.5,
141+
"distance_smoothing": 0.0,
142+
"seg_class": "sgn_low",
143+
},
144+
"IHC": {
145+
"center_distance_threshold": 0.5,
146+
"boundary_distance_threshold": 0.6,
147+
"fg_threshold": 0.5,
148+
"distance_smoothing": 0.6,
149+
"seg_class": "ihc",
150+
},
151+
"IHC-lowres": {
152+
"center_distance_threshold": 0.5,
153+
"boundary_distance_threshold": 0.6,
154+
"fg_threshold": 0.5,
155+
"distance_smoothing": 0.6,
156+
"seg_class": "ihc",
157+
},
158+
}
159+
if model_type not in all_default_kwargs:
160+
raise ValueError(f"Invalid model: {model_type}. Choose one of {list(all_default_kwargs.keys())}.")
161+
default_kwargs = all_default_kwargs[model_type]
162+
return default_kwargs
163+
164+
120165
def get_default_tiling() -> Dict[str, Dict[str, int]]:
121166
"""Determine the tile shape and halo depending on the available VRAM.
122167

flamingo_tools/plugin/segmentation_widget.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,27 @@
99

1010
from .base_widget import BaseWidget
1111
from .util import _load_custom_model, _available_devices, _get_current_tiling
12-
from ..model_utils import get_model, get_model_registry, get_device, get_default_tiling
12+
from ..model_utils import (
13+
get_model, get_model_registry, get_device, get_default_tiling, get_default_segmentation_settings
14+
)
1315

1416

15-
# TODO Expose segmentation kwargs.
16-
def _run_segmentation(image, model, model_type, tiling, device):
17+
def _run_segmentation(image, model, model_type, tiling, device, min_size):
1718
block_shape = [tiling["tile"][ax] for ax in "zyx"]
1819
halo = [tiling["halo"][ax] for ax in "zyx"]
1920
prediction = predict_with_halo(
2021
image, model, gpu_ids=[device], block_shape=block_shape, halo=halo,
2122
tqdm_desc="Run prediction"
2223
)
24+
settings = get_default_segmentation_settings(model_type)
25+
foreground_threshold = settings.pop("fg_threshold", 0.5)
26+
settings.pop("seg_class", None)
27+
settings = {name: 1.0 if val is None else val for name, val in settings.items()}
2328
foreground_map, center_distances, boundary_distances = prediction
2429
segmentation = watershed_from_center_and_boundary_distances(
2530
center_distances, boundary_distances, foreground_map,
26-
center_distance_threshold=0.5,
27-
boundary_distance_threshold=0.5,
28-
foreground_threshold=0.5,
29-
distance_smoothing=1.6,
30-
min_size=100,
31+
min_size=min_size, foreground_threshold=foreground_threshold,
32+
**settings,
3133
)
3234
return segmentation
3335

@@ -110,7 +112,9 @@ def on_predict(self):
110112

111113
# Get the current tiling.
112114
self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
113-
segmentation = _run_segmentation(image, model=model, model_type=model_type, tiling=self.tiling, device=device)
115+
segmentation = _run_segmentation(
116+
image, model=model, model_type=model_type, tiling=self.tiling, device=device, min_size=self.min_size
117+
)
114118

115119
self.viewer.add_labels(segmentation, name=model_type)
116120
show_info(f"INFO: Segmentation of {model_type} added to layers.")
@@ -120,6 +124,11 @@ def _create_settings_widget(self):
120124
# setting_values.setToolTip(get_tooltip("embedding", "settings"))
121125
setting_values.setLayout(QVBoxLayout())
122126

127+
# Create UI for the min-size parameter.
128+
self.min_size = 100
129+
self.min_size_menu, layout = self._add_int_param("min_size", self.min_size, 0, 10000)
130+
setting_values.layout().addLayout(layout)
131+
123132
# Create UI for the device.
124133
device = "auto"
125134
device_options = ["auto"] + _available_devices()

flamingo_tools/segmentation/cli.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .unet_prediction import run_unet_prediction
66
from .synapse_detection import marker_detection
7-
from ..model_utils import get_model_path
7+
from ..model_utils import get_model_path, get_default_segmentation_settings
88

99

1010
def _get_model_path(model_type, checkpoint_path=None):
@@ -49,32 +49,7 @@ def _convert_argval(value):
4949

5050

5151
def _parse_segmentation_kwargs(extra_kwargs, model_type):
52-
if model_type == "SGN":
53-
default_kwargs = {
54-
"center_distance_threshold": 0.4,
55-
"boundary_distance_threshold": 0.5,
56-
"fg_threshold": 0.5,
57-
"distance_smoothing": 0.0,
58-
"seg_class": "sgn",
59-
}
60-
elif model_type == "SGN-lowres":
61-
default_kwargs = {
62-
"center_distance_threshold": None,
63-
"boundary_distance_threshold": 0.5,
64-
"fg_threshold": 0.5,
65-
"distance_smoothing": 0.0,
66-
"seg_class": "sgn_low",
67-
}
68-
else:
69-
assert model_type.startswith("IHC")
70-
default_kwargs = {
71-
"center_distance_threshold": 0.5,
72-
"boundary_distance_threshold": 0.6,
73-
"fg_threshold": 0.5,
74-
"distance_smoothing": 0.6,
75-
"seg_class": "ihc",
76-
}
77-
52+
default_kwargs = get_default_segmentation_settings(model_type)
7853
kwargs = _parse_kwargs(extra_kwargs, **default_kwargs)
7954
return kwargs
8055

0 commit comments

Comments
 (0)