Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions micro_sam/sam_annotator/_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,37 @@ class _AnnotatorBase(QtWidgets.QScrollArea):
The annotators differ in their data dimensionality and the widgets.
"""

def _create_layers(self):
def _require_layers(self, layer_choice=None):

# Check whether the image is initialized already. And use the image shape and scale for the layers.
state = AnnotatorState()
shape = self._shape if state.image_shape is None else state.image_shape

# Add the label layers for the current object, the automatic segmentation and the committed segmentation.
dummy_data = np.zeros(self._shape, dtype="uint32")
dummy_data = np.zeros(shape, dtype="uint32")
image_scale = state.image_scale

# Before adding new layers, we always check whether a layer with this name already exists or not.
if "current_object" not in self._viewer.layers:
if "current_object" == layer_choice: # Check at 'commit' call button.
widgets._validation_window_for_missing_layer(layer_choice)
self._viewer.add_labels(data=dummy_data, name="current_object")
if image_scale is not None:
self.layers["current_objects"].scale = image_scale

if "auto_segmentation" not in self._viewer.layers:
if "auto_segmentation" == layer_choice: # Check at 'commit' call button situation.
widgets._validation_window_for_missing_layer(layer_choice)
self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
if image_scale is not None:
self.layers["auto_segmentation"].scale = image_scale

if "committed_objects" not in self._viewer.layers:
self._viewer.add_labels(data=dummy_data, name="committed_objects")
# Randomize colors so it is easy to see when object committed.
self._viewer.layers["committed_objects"].new_colormap()
if image_scale is not None:
self.layers["committed_objects"].scale = image_scale

# Add the point layer for point prompts.
self._point_labels = ["positive", "negative"]
Expand Down Expand Up @@ -131,7 +149,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None:
# Initialize with a dummy shape, which is reset to the correct shape once an image is set.
self._ndim = ndim
self._shape = (256, 256) if ndim == 2 else (16, 256, 256)
self._create_layers()
self._require_layers()

# Create all the widgets and add them to the layout.
self._create_widgets()
Expand Down Expand Up @@ -179,6 +197,9 @@ def _update_image(self, segmentation_result=None):
)
self._shape = state.image_shape

# Before we reset the layers, we ensure all expected layers exist.
self._require_layers()

# Update the image scale.
scale = state.image_scale

Expand All @@ -187,12 +208,15 @@ def _update_image(self, segmentation_result=None):
self._viewer.layers["current_object"].scale = scale
self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
self._viewer.layers["auto_segmentation"].scale = scale

if segmentation_result is None or segmentation_result is False:
self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
else:
assert segmentation_result.shape == self._shape
self._viewer.layers["committed_objects"].data = segmentation_result
self._viewer.layers["committed_objects"].scale = scale

self._viewer.layers["point_prompts"].scale = scale
self._viewer.layers["prompts"].scale = scale

vutil.clear_annotations(self._viewer, clear_segmentations=False)
4 changes: 4 additions & 0 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch.nn as nn

import micro_sam
import micro_sam.util as util
from micro_sam.instance_segmentation import AMGBase, get_decoder
from micro_sam.precompute_state import cache_amg_state, cache_is_state
Expand Down Expand Up @@ -69,6 +70,9 @@ class AnnotatorState(metaclass=Singleton):
# z-range to limit the data being committed in 3d / tracking.
z_range: Optional[Tuple[int, int]] = None

# annotator_class
annotator: Optional["micro_sam.sam_annotator._annotator._AnnotatorBase"] = None

def initialize_predictor(
self,
image_data,
Expand Down
49 changes: 38 additions & 11 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
# from napari.qt.threading import thread_worker
from napari.utils import progress

from ._state import AnnotatorState
from . import util as vutil
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .. import instance_segmentation, util
from ..multi_dimensional_segmentation import (
segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data
Expand Down Expand Up @@ -509,6 +509,11 @@ def commit(
commit_path: Select a file path where the committed results and prompts will be saved.
This feature is still experimental.
"""

# Check whether all layers exist as expected or create new ones automatically.
state = AnnotatorState()
state.annotator._require_layers(layer_choice=layer)

_, seg, mask, bb = _commit_impl(viewer, layer, preserve_committed)

if commit_path is not None:
Expand Down Expand Up @@ -551,6 +556,10 @@ def commit_track(
# Commit the segmentation layer.
id_offset, seg, mask, bb = _commit_impl(viewer, layer, preserve_committed)

# Check whether all layers exist as expected or create new ones automatically.
if _validate_layers(viewer):
return None

# Update the lineages.
state = AnnotatorState()
lineage = state.lineage
Expand Down Expand Up @@ -715,12 +724,25 @@ def _validate_embeddings(viewer: "napari.viewer.Viewer"):
# return False


def _validate_prompts(viewer: "napari.viewer.Viewer") -> bool:
if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
msg = "No prompts were given. Please provide prompts to run interactive segmentation."
return _generate_message("error", msg)
else:
return False
def _validation_window_for_missing_layer(layer_choice):
return _generate_message(
message_type="error",
message=f"The '{layer_choice}' layer to commit is missing. Please re-annotate and try again."
)


def _validate_layers(viewer: "napari.viewer.Viewer", automatic_segmentation: bool = False) -> bool:
# Check whether all layers exist as expected or create new ones automatically.
state = AnnotatorState()
state.annotator._require_layers()

if not automatic_segmentation:
# Check prompts layer.
if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
msg = "No prompts were given. Please provide prompts to run interactive segmentation."
return _generate_message("error", msg)
else:
return False


@magic_factory(call_button="Segment Object [S]")
Expand All @@ -733,7 +755,7 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:
"""
if _validate_embeddings(viewer):
return None
if _validate_prompts(viewer):
if _validate_layers(viewer):
return None

shape = viewer.layers["current_object"].data.shape
Expand Down Expand Up @@ -767,7 +789,7 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None:
"""
if _validate_embeddings(viewer):
return None
if _validate_prompts(viewer):
if _validate_layers(viewer):
return None

shape = viewer.layers["current_object"].data.shape[1:]
Expand Down Expand Up @@ -808,8 +830,9 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None:
"""
if _validate_embeddings(viewer):
return None
if _validate_prompts(viewer):
if _validate_layers(viewer):
return None

state = AnnotatorState()
shape = state.image_shape[1:]
position = viewer.dims.point
Expand Down Expand Up @@ -1486,8 +1509,9 @@ def update_segmentation(seg):
def __call__(self):
if _validate_embeddings(self._viewer):
return None
if _validate_prompts(self._viewer):
if _validate_layers(self._viewer):
return None

if self.tracking:
return self._run_tracking()
else:
Expand Down Expand Up @@ -1749,6 +1773,9 @@ def update_segmentation(seg):
self._viewer.layers["auto_segmentation"].data[i] = seg
self._viewer.layers["auto_segmentation"].refresh()

# Validate all layers.
_validate_layers(self._viewer, automatic_segmentation=True)

seg = seg_impl()
update_segmentation(seg)
# worker = seg_impl()
Expand Down
13 changes: 11 additions & 2 deletions micro_sam/sam_annotator/annotator_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,18 @@ def _get_widgets(self):
"clear": widgets.clear(),
}

def __init__(self, viewer: "napari.viewer.Viewer") -> None:
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
super().__init__(viewer=viewer, ndim=2)

# Set the expected annotator class to the state.
state = AnnotatorState()

# Reset the state.
if reset_state:
state.reset_state()

state.annotator = self


def annotator_2d(
image: np.ndarray,
Expand Down Expand Up @@ -85,7 +94,7 @@ def annotator_2d(
viewer = napari.Viewer()

viewer.add_image(image, name="image")
annotator = Annotator2d(viewer)
annotator = Annotator2d(viewer, reset_state=False)

# Trigger layer update of the annotator so that layers have the correct shape.
# And initialize the 'committed_objects' with the segmentation result if it was given.
Expand Down
13 changes: 11 additions & 2 deletions micro_sam/sam_annotator/annotator_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,19 @@ def _get_widgets(self):
"clear": widgets.clear_volume(),
}

def __init__(self, viewer: "napari.viewer.Viewer") -> None:
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
self._with_decoder = AnnotatorState().decoder is not None
super().__init__(viewer=viewer, ndim=3)

# Set the expected annotator class to the state.
state = AnnotatorState()

# Reset the state.
if reset_state:
state.reset_state()

state.annotator = self

def _update_image(self, segmentation_result=None):
super()._update_image(segmentation_result=segmentation_result)
# Load the amg state from the embedding path.
Expand Down Expand Up @@ -94,7 +103,7 @@ def annotator_3d(
viewer = napari.Viewer()

viewer.add_image(image, name="image")
annotator = Annotator3d(viewer)
annotator = Annotator3d(viewer, reset_state=False)

# Trigger layer update of the annotator so that layers have the correct shape.
# And initialize the 'committed_objects' with the segmentation result if it was given.
Expand Down
62 changes: 52 additions & 10 deletions micro_sam/sam_annotator/annotator_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,43 @@ class AnnotatorTracking(_AnnotatorBase):
# The tracking annotator needs different settings for the prompt layers
# to support the additional tracking state.
# That's why we over-ride this function.
def _create_layers(self):
def _require_layers(self, layer_choice=None):

# Check whether the image is initialized already. And use the image shape and scale for the layers.
state = AnnotatorState()
shape = self._shape if state.image_shape is None else state.image_shape

# Add the label layers for the current object, the automatic segmentation and the committed segmentation.
dummy_data = np.zeros(shape, dtype="uint32")
image_scale = state.image_scale

# Before adding new layers, we always check whether a layer with this name already exists or not.
if "current_object" not in self._viewer.layers:
if "current_object" == layer_choice: # Check at 'commit' call button.
widgets._validation_window_for_missing_layer(layer_choice)
self._viewer.add_labels(data=dummy_data, name="current_object")
if image_scale is not None:
self.layers["current_objects"].scale = image_scale

if "auto_segmentation" not in self._viewer.layers:
if "auto_segmentation" == layer_choice: # Check at 'commit' call button situation.
widgets._validation_window_for_missing_layer(layer_choice)
self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
if image_scale is not None:
self.layers["auto_segmentation"].scale = image_scale

if "committed_objects" not in self._viewer.layers:
self._viewer.add_labels(data=dummy_data, name="committed_objects")
# Randomize colors so it is easy to see when object committed.
self._viewer.layers["committed_objects"].new_colormap()
if image_scale is not None:
self.layers["committed_objects"].scale = image_scale

# Add the point prompts layer.
# NOTE: The lines below ensure that there is no existing 'point_prompts' layer with same name, and remove them.
if "point_prompts" in self._viewer.layers:
self._viewer.remove(self._viewer.layers["point_prompts"])

self._point_labels = ["positive", "negative"]
self._track_state_labels = ["track", "division"]

Expand All @@ -124,6 +160,11 @@ def _create_layers(self):
self._point_prompt_layer.border_color_mode = "cycle"
self._point_prompt_layer.face_color_mode = "cycle"

# Add the point prompts layer.
# NOTE: The lines below ensure that there is no existing 'prompts' layer with same name, and remove them.
if "prompts" in self._viewer.layers:
self._viewer.remove(self._viewer.layers["prompts"])

# Using the box layer to set divisions currently doesn't work.
# That's why some of the code below is commented out.
self._box_prompt_layer = self._viewer.add_shapes(
Expand All @@ -139,14 +180,6 @@ def _create_layers(self):
)
# self._box_prompt_layer.edge_color_mode = "cycle"

# Add the label layers for the current object, the automatic segmentation and the committed segmentation.
dummy_data = np.zeros(self._shape, dtype="uint32")
self._viewer.add_labels(data=dummy_data, name="current_object")
self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
self._viewer.add_labels(data=dummy_data, name="committed_objects")
# Randomize colors so it is easy to see when object committed.
self._viewer.layers["committed_objects"].new_colormap()

def _get_widgets(self):
state = AnnotatorState()
# Create the tracking state menu.
Expand All @@ -165,14 +198,23 @@ def _get_widgets(self):
"clear": widgets.clear_track(),
}

def __init__(self, viewer: "napari.viewer.Viewer") -> None:
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
# Initialize the state for tracking.
self._init_track_state()
self._with_decoder = AnnotatorState().decoder is not None
super().__init__(viewer=viewer, ndim=3)
# Go to t=0.
self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])

# Set the expected annotator class to the state.
state = AnnotatorState()

# Reset the state.
if reset_state:
state.reset_state()

state.annotator = self

def _init_track_state(self):
state = AnnotatorState()
state.current_track_id = 1
Expand Down
File renamed without changes.