Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions micro_sam/sam_annotator/_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def _update_image(self, segmentation_result=None):
)
self._shape = state.image_shape

# Before we reset the layers, we ensure all expected layers exist.
self._create_layers()

# Update the image scale.
scale = state.image_scale

Expand All @@ -187,12 +190,15 @@ def _update_image(self, segmentation_result=None):
self._viewer.layers["current_object"].scale = scale
self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
self._viewer.layers["auto_segmentation"].scale = scale

if segmentation_result is None or segmentation_result is False:
self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
else:
assert segmentation_result.shape == self._shape
self._viewer.layers["committed_objects"].data = segmentation_result
self._viewer.layers["committed_objects"].scale = scale

self._viewer.layers["point_prompts"].scale = scale
self._viewer.layers["prompts"].scale = scale

vutil.clear_annotations(self._viewer, clear_segmentations=False)
71 changes: 70 additions & 1 deletion micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
# from napari.qt.threading import thread_worker
from napari.utils import progress

from ._state import AnnotatorState
from . import util as vutil
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .. import instance_segmentation, util
from ..multi_dimensional_segmentation import (
segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data
Expand Down Expand Up @@ -509,6 +509,10 @@ def commit(
commit_path: Select a file path where the committed results and prompts will be saved.
This feature is still experimental.
"""

# Validate all layers.
viewer = _validate_layers(viewer, layer)

_, seg, mask, bb = _commit_impl(viewer, layer, preserve_committed)

if commit_path is not None:
Expand Down Expand Up @@ -715,7 +719,69 @@ def _validate_embeddings(viewer: "napari.viewer.Viewer"):
# return False


def _validate_layers(viewer, layer_choice=None):

# Let's find the first image layer to use as our reference for getting the shape.
image_layers = [layer for layer in viewer.layers if isinstance(layer, napari.layers.Image)]
if image_layers:
_shape = image_layers[0].data.shape
else:
raise RuntimeError("Seems like there is no image available for segmentation.")

# Add the label layers for the current object, the automatic segmentation and the committed segmentation.
dummy_data = np.zeros(_shape, dtype="uint32")

def _validation_window_for_missing_layer():
return _generate_message(
message_type="error",
message=f"The '{layer_choice}' layer to commit is missing. Please re-annotate and try again."
)

# Validate whether the layers pre-exist as expected or not. Otherwise, create them!
if "current_object" not in viewer.layers:
if "current_object" == layer_choice:
_validation_window_for_missing_layer()
viewer.add_labels(data=dummy_data, name="current_object")

if "auto_segmentation" not in viewer.layers:
if "auto_segmentation" == layer_choice:
_validation_window_for_missing_layer()
viewer.add_labels(data=dummy_data, name="auto_segmentation")

if "committed_objects" not in viewer.layers:
viewer.add_labels(data=dummy_data, name="committed_objects")
# Randomize colors so it is easy to see when object committed.
viewer.layers["committed_objects"].new_colormap()

if "point_prompts" not in viewer.layers:
_point_labels = ["positive", "negative"]
_point_prompt_layer = viewer.add_points(
name="point_prompts",
property_choices={"label": _point_labels},
border_color="label",
border_color_cycle=vutil.LABEL_COLOR_CYCLE,
symbol="o",
face_color="transparent",
border_width=0.5,
size=12,
ndim=viewer.dims.ndim,
)
_point_prompt_layer.border_color_mode = "cycle"

if "prompts" not in viewer.layers:
# Add the shape layer for box and other shape prompts.
viewer.add_shapes(
face_color="transparent", edge_color="green", edge_width=4, name="prompts", ndim=viewer.dims.ndim,
)

return viewer


def _validate_prompts(viewer: "napari.viewer.Viewer") -> bool:

# Check whether all layers exist as expected or create new ones automatically.
viewer = _validate_layers(viewer)

if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
msg = "No prompts were given. Please provide prompts to run interactive segmentation."
return _generate_message("error", msg)
Expand Down Expand Up @@ -1749,6 +1815,9 @@ def update_segmentation(seg):
self._viewer.layers["auto_segmentation"].data[i] = seg
self._viewer.layers["auto_segmentation"].refresh()

# Validate all layers.
_validate_layers(self._viewer)

seg = seg_impl()
update_segmentation(seg)
# worker = seg_impl()
Expand Down
File renamed without changes.
Loading