Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions micro_sam/sam_annotator/_annotator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import napari
from typing import Optional, List

import numpy as np

import napari
from qtpy import QtWidgets
from magicgui.widgets import Widget, Container, FunctionGui

Expand All @@ -16,19 +18,39 @@ class _AnnotatorBase(QtWidgets.QScrollArea):
The annotators differ in their data dimensionality and the widgets.
"""

def _create_layers(self):
def _require_layers(self, layer_choices: Optional[List[str]] = None):

# Check whether the image is initialized already. And use the image shape and scale for the layers.
state = AnnotatorState()
shape = self._shape if state.image_shape is None else state.image_shape

# Add the label layers for the current object, the automatic segmentation and the committed segmentation.
dummy_data = np.zeros(self._shape, dtype="uint32")
dummy_data = np.zeros(shape, dtype="uint32")
image_scale = state.image_scale

# Before adding new layers, we always check whether a layer with this name already exists or not.
if "current_object" not in self._viewer.layers:
if layer_choices and "current_object" in layer_choices: # Check at 'commit' call button.
widgets._validation_window_for_missing_layer("current_object")
self._viewer.add_labels(data=dummy_data, name="current_object")
if image_scale is not None:
self.layers["current_objects"].scale = image_scale

if "auto_segmentation" not in self._viewer.layers:
if layer_choices and "auto_segmentation" in layer_choices: # Check at 'commit' call button.
widgets._validation_window_for_missing_layer("auto_segmentation")
self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
if image_scale is not None:
self.layers["auto_segmentation"].scale = image_scale

if "committed_objects" not in self._viewer.layers:
if layer_choices and "committed_objects" in layer_choices: # Check at 'commit' call button.
widgets._validation_window_for_missing_layer("committed_objects")
self._viewer.add_labels(data=dummy_data, name="committed_objects")
# Randomize colors so it is easy to see when object committed.
self._viewer.layers["committed_objects"].new_colormap()
if image_scale is not None:
self.layers["committed_objects"].scale = image_scale

# Add the point layer for point prompts.
self._point_labels = ["positive", "negative"]
Expand Down Expand Up @@ -70,7 +92,7 @@ def _create_widgets(self):
# Create the prompt widget. (The same for all plugins.)
self._prompt_widget = widgets.create_prompt_menu(self._point_prompt_layer, self._point_labels)

# Create the dictionray for the widgets and get the widgets of the child plugin.
# Create the dictionary for the widgets and get the widgets of the child plugin.
self._widgets = {"embeddings": self._embedding_widget, "prompts": self._prompt_widget}
self._widgets.update(self._get_widgets())

Expand Down Expand Up @@ -131,7 +153,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None:
# Initialize with a dummy shape, which is reset to the correct shape once an image is set.
self._ndim = ndim
self._shape = (256, 256) if ndim == 2 else (16, 256, 256)
self._create_layers()
self._require_layers()

# Create all the widgets and add them to the layout.
self._create_widgets()
Expand Down Expand Up @@ -179,6 +201,9 @@ def _update_image(self, segmentation_result=None):
)
self._shape = state.image_shape

# Before we reset the layers, we ensure all expected layers exist.
self._require_layers()

# Update the image scale.
scale = state.image_scale

Expand All @@ -187,12 +212,15 @@ def _update_image(self, segmentation_result=None):
self._viewer.layers["current_object"].scale = scale
self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
self._viewer.layers["auto_segmentation"].scale = scale

if segmentation_result is None or segmentation_result is False:
self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
else:
assert segmentation_result.shape == self._shape
self._viewer.layers["committed_objects"].data = segmentation_result
self._viewer.layers["committed_objects"].scale = scale

self._viewer.layers["point_prompts"].scale = scale
self._viewer.layers["prompts"].scale = scale

vutil.clear_annotations(self._viewer, clear_segmentations=False)
4 changes: 4 additions & 0 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch.nn as nn

import micro_sam
import micro_sam.util as util
from micro_sam.instance_segmentation import AMGBase, get_decoder
from micro_sam.precompute_state import cache_amg_state, cache_is_state
Expand Down Expand Up @@ -69,6 +70,9 @@ class AnnotatorState(metaclass=Singleton):
# z-range to limit the data being committed in 3d / tracking.
z_range: Optional[Tuple[int, int]] = None

# annotator_class
annotator: Optional["micro_sam.sam_annotator._annotator._AnnotatorBase"] = None

def initialize_predictor(
self,
image_data,
Expand Down
47 changes: 36 additions & 11 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
# from napari.qt.threading import thread_worker
from napari.utils import progress

from ._state import AnnotatorState
from . import util as vutil
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .. import instance_segmentation, util
from ..multi_dimensional_segmentation import (
segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data
Expand Down Expand Up @@ -496,8 +496,12 @@ def _mask_matched_objects(seg, prev_seg, preservation_threshold):


def _commit_impl(viewer, layer, preserve_mode, preservation_threshold):
# Check if we have a z_range. If yes, use it to set a bounding box.
state = AnnotatorState()

# Check whether all layers exist as expected or create new ones automatically.
state.annotator._require_layers(layer_choices=[layer, "committed_objects"])

# Check if we have a z_range. If yes, use it to set a bounding box.
if state.z_range is None:
bb = np.s_[:]
else:
Expand Down Expand Up @@ -750,6 +754,7 @@ def commit(
commit_path: Select a file path where the committed results and prompts will be saved.
This feature is still experimental.
"""
# Commit the segmentation layer.
_, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold)

if commit_path is not None:
Expand Down Expand Up @@ -964,12 +969,27 @@ def _validate_embeddings(viewer: "napari.viewer.Viewer"):
# return False


def _validate_prompts(viewer: "napari.viewer.Viewer") -> bool:
if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
msg = "No prompts were given. Please provide prompts to run interactive segmentation."
return _generate_message("error", msg)
def _validation_window_for_missing_layer(layer_choice):
if layer_choice == "committed_objects":
msg = "The 'committed_objects' layer to commit masks is missing. Please try to commit again."
else:
return False
msg = f"The '{layer_choice}' layer to commit is missing. Please re-annotate and try again."

return _generate_message(message_type="error", message=msg)


def _validate_layers(viewer: "napari.viewer.Viewer", automatic_segmentation: bool = False) -> bool:
# Check whether all layers exist as expected or create new ones automatically.
state = AnnotatorState()
state.annotator._require_layers()

if not automatic_segmentation:
# Check prompts layer.
if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
msg = "No prompts were given. Please provide prompts to run interactive segmentation."
return _generate_message("error", msg)
else:
return False


@magic_factory(call_button="Segment Object [S]")
Expand All @@ -982,7 +1002,7 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:
"""
if _validate_embeddings(viewer):
return None
if _validate_prompts(viewer):
if _validate_layers(viewer):
return None

shape = viewer.layers["current_object"].data.shape
Expand Down Expand Up @@ -1016,7 +1036,7 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None:
"""
if _validate_embeddings(viewer):
return None
if _validate_prompts(viewer):
if _validate_layers(viewer):
return None

shape = viewer.layers["current_object"].data.shape[1:]
Expand Down Expand Up @@ -1057,8 +1077,9 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None:
"""
if _validate_embeddings(viewer):
return None
if _validate_prompts(viewer):
if _validate_layers(viewer):
return None

state = AnnotatorState()
shape = state.image_shape[1:]
position = viewer.dims.point
Expand Down Expand Up @@ -1626,8 +1647,9 @@ def update_segmentation(seg):
def __call__(self):
if _validate_embeddings(self._viewer):
return None
if _validate_prompts(self._viewer):
if _validate_layers(self._viewer):
return None

if self.tracking:
return self._run_tracking()
else:
Expand Down Expand Up @@ -1889,6 +1911,9 @@ def update_segmentation(seg):
self._viewer.layers["auto_segmentation"].data[i] = seg
self._viewer.layers["auto_segmentation"].refresh()

# Validate all layers.
_validate_layers(self._viewer, automatic_segmentation=True)

seg = seg_impl()
update_segmentation(seg)
# worker = seg_impl()
Expand Down
13 changes: 11 additions & 2 deletions micro_sam/sam_annotator/annotator_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,18 @@ def _get_widgets(self):
"clear": widgets.clear(),
}

def __init__(self, viewer: "napari.viewer.Viewer") -> None:
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
super().__init__(viewer=viewer, ndim=2)

# Set the expected annotator class to the state.
state = AnnotatorState()

# Reset the state.
if reset_state:
state.reset_state()

state.annotator = self


def annotator_2d(
image: np.ndarray,
Expand Down Expand Up @@ -85,7 +94,7 @@ def annotator_2d(
viewer = napari.Viewer()

viewer.add_image(image, name="image")
annotator = Annotator2d(viewer)
annotator = Annotator2d(viewer, reset_state=False)

# Trigger layer update of the annotator so that layers have the correct shape.
# And initialize the 'committed_objects' with the segmentation result if it was given.
Expand Down
13 changes: 11 additions & 2 deletions micro_sam/sam_annotator/annotator_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,19 @@ def _get_widgets(self):
"clear": widgets.clear_volume(),
}

def __init__(self, viewer: "napari.viewer.Viewer") -> None:
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
self._with_decoder = AnnotatorState().decoder is not None
super().__init__(viewer=viewer, ndim=3)

# Set the expected annotator class to the state.
state = AnnotatorState()

# Reset the state.
if reset_state:
state.reset_state()

state.annotator = self

def _update_image(self, segmentation_result=None):
super()._update_image(segmentation_result=segmentation_result)
# Load the amg state from the embedding path.
Expand Down Expand Up @@ -95,7 +104,7 @@ def annotator_3d(
viewer = napari.Viewer()

viewer.add_image(image, name="image")
annotator = Annotator3d(viewer)
annotator = Annotator3d(viewer, reset_state=False)

# Trigger layer update of the annotator so that layers have the correct shape.
# And initialize the 'committed_objects' with the segmentation result if it was given.
Expand Down
Loading