Skip to content

Commit b82c4eb

Browse files
Ensure valid layers exist for each annotator click menu (#943)
Ensure valid layers exist for each annotator click menu --------- Co-authored-by: Constantin Pape <[email protected]>
1 parent ec8f447 commit b82c4eb

File tree

7 files changed

+257
-86
lines changed

7 files changed

+257
-86
lines changed

micro_sam/sam_annotator/_annotator.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import napari
1+
from typing import Optional, List
2+
23
import numpy as np
34

5+
import napari
46
from qtpy import QtWidgets
57
from magicgui.widgets import Widget, Container, FunctionGui
68

@@ -16,19 +18,39 @@ class _AnnotatorBase(QtWidgets.QScrollArea):
1618
The annotators differ in their data dimensionality and the widgets.
1719
"""
1820

19-
def _create_layers(self):
21+
def _require_layers(self, layer_choices: Optional[List[str]] = None):
22+
23+
# Check whether the image is initialized already. And use the image shape and scale for the layers.
24+
state = AnnotatorState()
25+
shape = self._shape if state.image_shape is None else state.image_shape
26+
2027
# Add the label layers for the current object, the automatic segmentation and the committed segmentation.
21-
dummy_data = np.zeros(self._shape, dtype="uint32")
28+
dummy_data = np.zeros(shape, dtype="uint32")
29+
image_scale = state.image_scale
2230

2331
# Before adding new layers, we always check whether a layer with this name already exists or not.
2432
if "current_object" not in self._viewer.layers:
33+
if layer_choices and "current_object" in layer_choices: # Check at 'commit' call button.
34+
widgets._validation_window_for_missing_layer("current_object")
2535
self._viewer.add_labels(data=dummy_data, name="current_object")
36+
if image_scale is not None:
37+
self.layers["current_objects"].scale = image_scale
38+
2639
if "auto_segmentation" not in self._viewer.layers:
40+
if layer_choices and "auto_segmentation" in layer_choices: # Check at 'commit' call button.
41+
widgets._validation_window_for_missing_layer("auto_segmentation")
2742
self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
43+
if image_scale is not None:
44+
self.layers["auto_segmentation"].scale = image_scale
45+
2846
if "committed_objects" not in self._viewer.layers:
47+
if layer_choices and "committed_objects" in layer_choices: # Check at 'commit' call button.
48+
widgets._validation_window_for_missing_layer("committed_objects")
2949
self._viewer.add_labels(data=dummy_data, name="committed_objects")
3050
# Randomize colors so it is easy to see when object committed.
3151
self._viewer.layers["committed_objects"].new_colormap()
52+
if image_scale is not None:
53+
self.layers["committed_objects"].scale = image_scale
3254

3355
# Add the point layer for point prompts.
3456
self._point_labels = ["positive", "negative"]
@@ -70,7 +92,7 @@ def _create_widgets(self):
7092
# Create the prompt widget. (The same for all plugins.)
7193
self._prompt_widget = widgets.create_prompt_menu(self._point_prompt_layer, self._point_labels)
7294

73-
# Create the dictionray for the widgets and get the widgets of the child plugin.
95+
# Create the dictionary for the widgets and get the widgets of the child plugin.
7496
self._widgets = {"embeddings": self._embedding_widget, "prompts": self._prompt_widget}
7597
self._widgets.update(self._get_widgets())
7698

@@ -131,7 +153,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None:
131153
# Initialize with a dummy shape, which is reset to the correct shape once an image is set.
132154
self._ndim = ndim
133155
self._shape = (256, 256) if ndim == 2 else (16, 256, 256)
134-
self._create_layers()
156+
self._require_layers()
135157

136158
# Create all the widgets and add them to the layout.
137159
self._create_widgets()
@@ -179,6 +201,9 @@ def _update_image(self, segmentation_result=None):
179201
)
180202
self._shape = state.image_shape
181203

204+
# Before we reset the layers, we ensure all expected layers exist.
205+
self._require_layers()
206+
182207
# Update the image scale.
183208
scale = state.image_scale
184209

@@ -187,12 +212,15 @@ def _update_image(self, segmentation_result=None):
187212
self._viewer.layers["current_object"].scale = scale
188213
self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
189214
self._viewer.layers["auto_segmentation"].scale = scale
215+
190216
if segmentation_result is None or segmentation_result is False:
191217
self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
192218
else:
193219
assert segmentation_result.shape == self._shape
194220
self._viewer.layers["committed_objects"].data = segmentation_result
195221
self._viewer.layers["committed_objects"].scale = scale
222+
196223
self._viewer.layers["point_prompts"].scale = scale
197224
self._viewer.layers["prompts"].scale = scale
225+
198226
vutil.clear_annotations(self._viewer, clear_segmentations=False)

micro_sam/sam_annotator/_state.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import torch.nn as nn
1616

17+
import micro_sam
1718
import micro_sam.util as util
1819
from micro_sam.instance_segmentation import AMGBase, get_decoder
1920
from micro_sam.precompute_state import cache_amg_state, cache_is_state
@@ -69,6 +70,9 @@ class AnnotatorState(metaclass=Singleton):
6970
# z-range to limit the data being committed in 3d / tracking.
7071
z_range: Optional[Tuple[int, int]] = None
7172

73+
# annotator_class
74+
annotator: Optional["micro_sam.sam_annotator._annotator._AnnotatorBase"] = None
75+
7276
def initialize_predictor(
7377
self,
7478
image_data,

micro_sam/sam_annotator/_widgets.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
# from napari.qt.threading import thread_worker
3131
from napari.utils import progress
3232

33-
from ._state import AnnotatorState
3433
from . import util as vutil
3534
from ._tooltips import get_tooltip
35+
from ._state import AnnotatorState
3636
from .. import instance_segmentation, util
3737
from ..multi_dimensional_segmentation import (
3838
segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data
@@ -496,8 +496,12 @@ def _mask_matched_objects(seg, prev_seg, preservation_threshold):
496496

497497

498498
def _commit_impl(viewer, layer, preserve_mode, preservation_threshold):
499-
# Check if we have a z_range. If yes, use it to set a bounding box.
500499
state = AnnotatorState()
500+
501+
# Check whether all layers exist as expected or create new ones automatically.
502+
state.annotator._require_layers(layer_choices=[layer, "committed_objects"])
503+
504+
# Check if we have a z_range. If yes, use it to set a bounding box.
501505
if state.z_range is None:
502506
bb = np.s_[:]
503507
else:
@@ -750,6 +754,7 @@ def commit(
750754
commit_path: Select a file path where the committed results and prompts will be saved.
751755
This feature is still experimental.
752756
"""
757+
# Commit the segmentation layer.
753758
_, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold)
754759

755760
if commit_path is not None:
@@ -964,12 +969,27 @@ def _validate_embeddings(viewer: "napari.viewer.Viewer"):
964969
# return False
965970

966971

967-
def _validate_prompts(viewer: "napari.viewer.Viewer") -> bool:
968-
if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
969-
msg = "No prompts were given. Please provide prompts to run interactive segmentation."
970-
return _generate_message("error", msg)
972+
def _validation_window_for_missing_layer(layer_choice):
973+
if layer_choice == "committed_objects":
974+
msg = "The 'committed_objects' layer to commit masks is missing. Please try to commit again."
971975
else:
972-
return False
976+
msg = f"The '{layer_choice}' layer to commit is missing. Please re-annotate and try again."
977+
978+
return _generate_message(message_type="error", message=msg)
979+
980+
981+
def _validate_layers(viewer: "napari.viewer.Viewer", automatic_segmentation: bool = False) -> bool:
982+
# Check whether all layers exist as expected or create new ones automatically.
983+
state = AnnotatorState()
984+
state.annotator._require_layers()
985+
986+
if not automatic_segmentation:
987+
# Check prompts layer.
988+
if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
989+
msg = "No prompts were given. Please provide prompts to run interactive segmentation."
990+
return _generate_message("error", msg)
991+
else:
992+
return False
973993

974994

975995
@magic_factory(call_button="Segment Object [S]")
@@ -982,7 +1002,7 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:
9821002
"""
9831003
if _validate_embeddings(viewer):
9841004
return None
985-
if _validate_prompts(viewer):
1005+
if _validate_layers(viewer):
9861006
return None
9871007

9881008
shape = viewer.layers["current_object"].data.shape
@@ -1016,7 +1036,7 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None:
10161036
"""
10171037
if _validate_embeddings(viewer):
10181038
return None
1019-
if _validate_prompts(viewer):
1039+
if _validate_layers(viewer):
10201040
return None
10211041

10221042
shape = viewer.layers["current_object"].data.shape[1:]
@@ -1057,8 +1077,9 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None:
10571077
"""
10581078
if _validate_embeddings(viewer):
10591079
return None
1060-
if _validate_prompts(viewer):
1080+
if _validate_layers(viewer):
10611081
return None
1082+
10621083
state = AnnotatorState()
10631084
shape = state.image_shape[1:]
10641085
position = viewer.dims.point
@@ -1626,8 +1647,9 @@ def update_segmentation(seg):
16261647
def __call__(self):
16271648
if _validate_embeddings(self._viewer):
16281649
return None
1629-
if _validate_prompts(self._viewer):
1650+
if _validate_layers(self._viewer):
16301651
return None
1652+
16311653
if self.tracking:
16321654
return self._run_tracking()
16331655
else:
@@ -1889,6 +1911,9 @@ def update_segmentation(seg):
18891911
self._viewer.layers["auto_segmentation"].data[i] = seg
18901912
self._viewer.layers["auto_segmentation"].refresh()
18911913

1914+
# Validate all layers.
1915+
_validate_layers(self._viewer, automatic_segmentation=True)
1916+
18921917
seg = seg_impl()
18931918
update_segmentation(seg)
18941919
# worker = seg_impl()

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,18 @@ def _get_widgets(self):
2424
"clear": widgets.clear(),
2525
}
2626

27-
def __init__(self, viewer: "napari.viewer.Viewer") -> None:
27+
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
2828
super().__init__(viewer=viewer, ndim=2)
2929

30+
# Set the expected annotator class to the state.
31+
state = AnnotatorState()
32+
33+
# Reset the state.
34+
if reset_state:
35+
state.reset_state()
36+
37+
state.annotator = self
38+
3039

3140
def annotator_2d(
3241
image: np.ndarray,
@@ -85,7 +94,7 @@ def annotator_2d(
8594
viewer = napari.Viewer()
8695

8796
viewer.add_image(image, name="image")
88-
annotator = Annotator2d(viewer)
97+
annotator = Annotator2d(viewer, reset_state=False)
8998

9099
# Trigger layer update of the annotator so that layers have the correct shape.
91100
# And initialize the 'committed_objects' with the segmentation result if it was given.

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,19 @@ def _get_widgets(self):
2424
"clear": widgets.clear_volume(),
2525
}
2626

27-
def __init__(self, viewer: "napari.viewer.Viewer") -> None:
27+
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
2828
self._with_decoder = AnnotatorState().decoder is not None
2929
super().__init__(viewer=viewer, ndim=3)
3030

31+
# Set the expected annotator class to the state.
32+
state = AnnotatorState()
33+
34+
# Reset the state.
35+
if reset_state:
36+
state.reset_state()
37+
38+
state.annotator = self
39+
3140
def _update_image(self, segmentation_result=None):
3241
super()._update_image(segmentation_result=segmentation_result)
3342
# Load the amg state from the embedding path.
@@ -95,7 +104,7 @@ def annotator_3d(
95104
viewer = napari.Viewer()
96105

97106
viewer.add_image(image, name="image")
98-
annotator = Annotator3d(viewer)
107+
annotator = Annotator3d(viewer, reset_state=False)
99108

100109
# Trigger layer update of the annotator so that layers have the correct shape.
101110
# And initialize the 'committed_objects' with the segmentation result if it was given.

0 commit comments

Comments
 (0)