Commit 038500d

talmo and claude authored
feat(model): mask unused_predictions + link-first mask merge (#474 Gap 1) (#478)
Add the segmentation-mask analogue of LabeledFrame.unused_predictions and make the auto-merge cascade honor the from_predicted provenance link for masks, mirroring the pose human-in-the-loop flow.

- LabeledFrame.unused_predicted_masks: reports PredictedSegmentationMask objects with no adopting user mask. A prediction counts as adopted when a user mask in the frame links to it via from_predicted (checked first) or, lacking a link, spatially overlaps it (bbox-centroid within the 5px auto-merge default).
- Auto-merge link-first matching: _resolve_annotation_auto now seeds its match set with from_predicted links (score inf, bypassing the distance threshold) before spatial centroid matching, so an adopted correction replaces its exact source prediction regardless of distance. Implemented generically via _find_annotation_link_matches; only masks carry the link today, so other modalities fall through to unchanged spatial behavior.
- Docs: document the link-first precedence in merging.md and unused_predicted_masks in segmentation.md.

Scope per design discussion: masks only, minimal/additive (no new public matcher abstraction). A follow-up issue tracks generalizing this to an AnnotationMatcher and adding from_predicted to the other dense modalities.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 31f904f commit 038500d

4 files changed

Lines changed: 322 additions & 2 deletions

docs/merging.md

Lines changed: 10 additions & 0 deletions
@@ -613,6 +613,16 @@ same threshold as instance matching (default 5 pixels). Each modality is resolve
 independently — centroids by `(x, y)`, bounding boxes and ROIs by their centroid, and
 masks by the centroid of their bounding box.
 
+For segmentation masks, an explicit provenance link takes precedence over spatial
+matching. If a `UserSegmentationMask` records (via `from_predicted`, set by
+[`PredictedSegmentationMask.to_user()`](model/segmentation.md#adopting-predictions-human-in-the-loop))
+that it was adopted from a `PredictedSegmentationMask` present in the merge, the two are
+paired directly — the user correction replaces its exact source prediction regardless of
+centroid distance — and spatial matching only resolves the remaining, unlinked
+annotations. Other modalities do not yet carry a `from_predicted` link and are matched
+spatially only. To list predicted masks that have not been adopted (by link or spatial
+overlap), use `LabeledFrame.unused_predicted_masks`.
+
 New frames (no matching frame in the target) always copy all annotations from the
 source, regardless of strategy.
 
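Review note: the precedence documented in this hunk (provenance link first, spatial matching only for what is left) can be illustrated with a small standalone sketch. `Mask` and `pair_masks` are hypothetical stand-ins written for this note, not sleap_io classes or functions, and the inclusive `<=` comparison against the 5 px threshold is an assumption.

```python
from dataclasses import dataclass
from math import hypot
from typing import Optional


@dataclass(eq=False)  # eq=False keeps identity semantics, like the id() lookups in the diff
class Mask:
    """Hypothetical stand-in for a segmentation mask (not the sleap_io class)."""

    centroid: tuple[float, float]
    is_predicted: bool
    from_predicted: Optional["Mask"] = None


def pair_masks(user_masks, predicted_masks, threshold=5.0):
    """Pair user masks with predictions: provenance link first, then distance."""
    pairs = []
    taken = set()  # id() of predictions that already have a partner

    # Pass 1: explicit links are honored regardless of centroid distance.
    unlinked = []
    for user in user_masks:
        src = user.from_predicted
        present = src is not None and any(src is p for p in predicted_masks)
        if present and id(src) not in taken:
            pairs.append((user, src))
            taken.add(id(src))
        else:
            unlinked.append(user)

    # Pass 2: nearest remaining prediction within the threshold (assumed inclusive).
    for user in unlinked:
        best, best_dist = None, threshold
        for pred in predicted_masks:
            if id(pred) in taken:
                continue
            dist = hypot(
                user.centroid[0] - pred.centroid[0],
                user.centroid[1] - pred.centroid[1],
            )
            if dist <= best_dist:
                best, best_dist = pred, dist
        if best is not None:
            pairs.append((user, best))
            taken.add(id(best))
    return pairs


# A linked correction that was moved next to a decoy still pairs with its source.
true_src = Mask((103.0, 103.0), True)
decoy = Mask((9.0, 9.0), True)
linked_user = Mask((8.0, 8.0), False, from_predicted=true_src)
unlinked_user = Mask((10.0, 10.0), False)
pairs = pair_masks([linked_user, unlinked_user], [true_src, decoy])
```

Here `linked_user` pairs with `true_src` even though `decoy` is much closer, and `unlinked_user` then pairs with `decoy` spatially. A link whose source is not in `predicted_masks` falls through to the spatial pass.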

docs/model/segmentation.md

Lines changed: 25 additions & 0 deletions
@@ -155,6 +155,31 @@ resolution can happen later at merge time (see [Merging](../merging.md)):
 
 ```
 
+To find predicted masks that have **not** yet been corrected — the segmentation
+analogue of `LabeledFrame.unused_predictions` for poses — use
+`LabeledFrame.unused_predicted_masks`. A `PredictedSegmentationMask` is treated
+as adopted (and excluded) when a `UserSegmentationMask` in the same frame links
+to it via `from_predicted` (checked first), or, lacking a link, spatially
+overlaps it (bbox-centroid within 5 px). This drives the "retrain only what a
+human corrected" workflow:
+
+```pycon
+>>> import numpy as np
+>>> import sleap_io as sio
+>>> video = sio.Video(filename="example.mp4", open_backend=False)
+>>> mask_data = np.zeros((100, 100), dtype=bool)
+>>> mask_data[20:40, 30:60] = True
+>>> pred_a = sio.PredictedSegmentationMask.from_numpy(mask_data, score=0.87)
+>>> pred_b = sio.PredictedSegmentationMask.from_numpy(mask_data, score=0.62)
+>>> pred_b.offset = (500.0, 500.0)  # a separate prediction elsewhere in the frame
+>>> frame = sio.LabeledFrame(video=video, frame_idx=0, masks=[pred_a, pred_b])
+>>> frame.masks.append(pred_a.to_user())  # adopt pred_a, leave pred_b
+>>> unused = frame.unused_predicted_masks  # only the uncorrected prediction
+>>> len(unused), unused[0] is pred_b
+(1, True)
+
+```
+
 ### Multi-resolution masks
 
 Segmentation masks stored at lower resolution — e.g., from a model that
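Review note: the "bbox-centroid within 5 px" test in the added segmentation.md text is simple arithmetic, sketched below in plain Python for the same mask shape as the docs example. `bbox_centroid` is a hypothetical helper written for this note. The half-open bounding box and the `(x, y)` ordering of `offset` are assumptions, not something this diff confirms about sleap_io.

```python
from math import hypot


def bbox_centroid(mask, offset=(0.0, 0.0)):
    """Centroid of the bounding box of the True cells in a 2D boolean grid.

    Assumptions (not confirmed by the diff): the box is half-open,
    [x_min, x_max + 1) by [y_min, y_max + 1), and `offset` is (x, y).
    Raises ValueError for a grid with no True cells.
    """
    rows = [r for r, row in enumerate(mask) if any(row)]
    cols = [c for row in mask for c, v in enumerate(row) if v]
    x0, x1 = min(cols), max(cols) + 1
    y0, y1 = min(rows), max(rows) + 1
    return (offset[0] + (x0 + x1) / 2, offset[1] + (y0 + y1) / 2)


# Same shape as the docs example: rows 20..39 and columns 30..59 are set.
grid = [[20 <= r < 40 and 30 <= c < 60 for c in range(100)] for r in range(100)]

centroid_a = bbox_centroid(grid)                          # mask at the origin
centroid_b = bbox_centroid(grid, offset=(500.0, 500.0))   # same mask, shifted
distance = hypot(centroid_b[0] - centroid_a[0], centroid_b[1] - centroid_a[1])
within_threshold = distance <= 5.0  # 5 px default; inclusive comparison is assumed
```

Because both masks have the same shape, the distance between their centroids equals the offset difference (about 707 px) whatever centroid convention is used, so a mask shifted by `(500.0, 500.0)` cannot be spatially adopted by a user mask left at the origin.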

sleap_io/model/labeled_frame.py

Lines changed: 89 additions & 2 deletions
@@ -95,6 +95,47 @@ def _find_annotation_matches(
     return matches
 
 
+def _find_annotation_link_matches(
+    self_list: list,
+    other_list: list,
+) -> list[tuple[int, int, float]]:
+    """Find user<->predicted matches via the ``from_predicted`` provenance link.
+
+    A match is recorded whenever a user annotation in one list explicitly records
+    (via ``from_predicted``) that it was adopted from a predicted annotation in
+    the other list. These take precedence over spatial matching (score is
+    ``inf``), so an adopted correction resolves against its exact source
+    prediction regardless of centroid distance. Only modalities that carry
+    ``from_predicted`` (segmentation masks) can produce link matches; for every
+    other modality this returns an empty list and merge behavior is unchanged.
+
+    Args:
+        self_list: Annotations from the self frame.
+        other_list: Annotations from the other frame.
+
+    Returns:
+        List of ``(self_idx, other_idx, inf)`` tuples for linked pairs.
+    """
+    matches = []
+    other_id_to_idx = {id(b): j for j, b in enumerate(other_list)}
+    self_id_to_idx = {id(a): i for i, a in enumerate(self_list)}
+    # User annotation in self linked to a predicted annotation in other.
+    for i, a in enumerate(self_list):
+        src = getattr(a, "from_predicted", None)
+        if src is not None:
+            j = other_id_to_idx.get(id(src))
+            if j is not None:
+                matches.append((i, j, float("inf")))
+    # User annotation in other linked to a predicted annotation in self.
+    for j, b in enumerate(other_list):
+        src = getattr(b, "from_predicted", None)
+        if src is not None:
+            i = self_id_to_idx.get(id(src))
+            if i is not None:
+                matches.append((i, j, float("inf")))
+    return matches
+
+
 def _resolve_annotation_auto(
     self_list: list,
     other_list: list,
@@ -124,8 +165,13 @@ def _resolve_annotation_auto(
         if not ann.is_predicted:
             merged.append(ann)
 
-    # 2. Find spatial matches
-    matches = _find_annotation_matches(self_list, other_list, attr, threshold)
+    # 2. Find matches: explicit ``from_predicted`` links first (score ``inf``, so
+    # the greedy pass below prefers them over spatial matches and ignores the
+    # distance threshold), then spatial centroid matches as a fallback. Only
+    # masks carry links today, so other modalities fall straight through to
+    # spatial matching with unchanged behavior.
+    matches = _find_annotation_link_matches(self_list, other_list)
+    matches += _find_annotation_matches(self_list, other_list, attr, threshold)
 
     # 3. Greedy one-to-one matching: sort by score descending, assign each
     # self/other index at most once so no annotation is silently dropped.
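
Review note: the body of step 3 is outside this hunk, so the following is only a sketch of how a greedy one-to-one pass over `(self_idx, other_idx, score)` tuples would let `inf`-scored link matches win. `greedy_assign` is a hypothetical name, the spatial score `0.9` is made up, and the sketch assumes spatial scores are finite with higher meaning better.

```python
def greedy_assign(matches):
    """Pick one-to-one pairs from (self_idx, other_idx, score) candidates."""
    used_self, used_other = set(), set()
    chosen = []
    # Highest score first; float("inf") sorts ahead of every finite score.
    for i, j, _score in sorted(matches, key=lambda m: m[2], reverse=True):
        if i in used_self or j in used_other:
            continue  # each index is assigned at most once
        used_self.add(i)
        used_other.add(j)
        chosen.append((i, j))
    return chosen


# self[0] is the linked source of other[0]; self[1] is a nearby spatial decoy.
link_matches = [(0, 0, float("inf"))]
spatial_matches = [(1, 0, 0.9)]  # illustrative score, not a sleap_io value
result = greedy_assign(link_matches + spatial_matches)
```

With these inputs `result` is `[(0, 0)]`: the link claims `other[0]` first, so the decoy at `self[1]` stays unmatched, which is the situation `test_merge_annotations_auto_masks_link_beats_spatial_decoy` exercises.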
@@ -422,6 +468,47 @@ def unused_predictions(self) -> list[Instance]:
 
         return unused_predictions
 
+    @property
+    def unused_predicted_masks(self) -> list["SegmentationMask"]:
+        """Return predicted masks in this frame not yet adopted by a user mask.
+
+        A `PredictedSegmentationMask` is considered *adopted* (and so excluded
+        from the result) when some `UserSegmentationMask` in the same frame
+        either links to it via `from_predicted` (checked first) or, lacking an
+        explicit link, spatially overlaps it (bbox-centroid distance within 5 px,
+        the auto-merge default). This mirrors the link-first, spatial-fallback
+        precedence used by the auto-merge cascade and supports the
+        "retrain only what a human corrected" workflow.
+
+        This is the segmentation-mask analogue of `unused_predictions` (which
+        covers `PredictedInstance` objects).
+
+        Returns:
+            The `PredictedSegmentationMask` objects with no adopting user mask.
+        """
+        from sleap_io.model.mask import PredictedSegmentationMask
+
+        predicted = [m for m in self.masks if isinstance(m, PredictedSegmentationMask)]
+        if not predicted:
+            return []
+        user_masks = [m for m in self.masks if not m.is_predicted]
+
+        adopted: set[int] = set()
+        # Link-first: predicted masks explicitly adopted via from_predicted.
+        for u in user_masks:
+            src = getattr(u, "from_predicted", None)
+            if src is not None:
+                adopted.add(id(src))
+        # Spatial fallback: a user mask overlaps a still-unadopted prediction.
+        remaining = [m for m in predicted if id(m) not in adopted]
+        if remaining and user_masks:
+            for self_idx, _other_idx, _score in _find_annotation_matches(
+                remaining, user_masks, "masks", 5.0
+            ):
+                adopted.add(id(remaining[self_idx]))
+
+        return [m for m in predicted if id(m) not in adopted]
+
     def remove_predictions(self):
         """Remove all predicted instances and annotations from the frame."""
         from sleap_io.model.bbox import PredictedBoundingBox

tests/model/test_labeled_frame.py

Lines changed: 198 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 from sleap_io import Instance, PredictedInstance, Skeleton, Track, Video
 from sleap_io.model.labeled_frame import LabeledFrame
+from sleap_io.model.mask import PredictedSegmentationMask, UserSegmentationMask
 
 
 def test_labeled_frame():
@@ -187,6 +188,74 @@ def test_labeled_frame_unused_predictions():
     assert (lf2.unused_predictions[0].numpy() == 1).all()
 
 
+def test_unused_predicted_masks_none_when_no_predictions():
+    """A frame with no predicted masks reports no unused predictions."""
+    video = Video("test.mp4")
+    user = UserSegmentationMask.from_numpy(np.ones((5, 5), dtype=bool))
+    lf = LabeledFrame(video=video, frame_idx=0, masks=[user])
+    assert lf.unused_predicted_masks == []
+
+
+def test_unused_predicted_masks_unadopted_reported():
+    """A predicted mask with no adopting user mask is reported as unused."""
+    video = Video("test.mp4")
+    pred = PredictedSegmentationMask.from_numpy(np.ones((5, 5), dtype=bool), score=0.9)
+    lf = LabeledFrame(video=video, frame_idx=0, masks=[pred])
+    assert lf.unused_predicted_masks == [pred]
+
+
+def test_unused_predicted_masks_excludes_linked():
+    """A predicted mask adopted via from_predicted is not reported (link-first)."""
+    video = Video("test.mp4")
+    pred = PredictedSegmentationMask.from_numpy(np.ones((5, 5), dtype=bool), score=0.9)
+    user = pred.to_user()  # sets user.from_predicted = pred
+    lf = LabeledFrame(video=video, frame_idx=0, masks=[pred, user])
+    assert lf.unused_predicted_masks == []
+
+
+def test_unused_predicted_masks_link_overrides_distance():
+    """An explicit link counts as adopted even when the masks are far apart."""
+    video = Video("test.mp4")
+    pred = PredictedSegmentationMask.from_numpy(
+        np.ones((5, 5), dtype=bool), score=0.9, offset=(0.0, 0.0)
+    )
+    user = pred.to_user()
+    # Move the user mask far away; the from_predicted link should still count.
+    user.offset = (500.0, 500.0)
+    lf = LabeledFrame(video=video, frame_idx=0, masks=[pred, user])
+    assert lf.unused_predicted_masks == []
+
+
+def test_unused_predicted_masks_spatial_fallback():
+    """An unlinked user mask overlapping a prediction adopts it spatially."""
+    video = Video("test.mp4")
+    pred = PredictedSegmentationMask.from_numpy(
+        np.ones((10, 10), dtype=bool), score=0.9, offset=(5.0, 5.0)
+    )
+    # Unlinked user mask with an overlapping bbox centroid (within 5 px).
+    user = UserSegmentationMask.from_numpy(
+        np.ones((10, 10), dtype=bool), offset=(6.0, 6.0)
+    )
+    assert user.from_predicted is None
+    lf = LabeledFrame(video=video, frame_idx=0, masks=[pred, user])
+    assert lf.unused_predicted_masks == []
+
+
+def test_unused_predicted_masks_mixed():
+    """Only the prediction without an adopting user mask is reported."""
+    video = Video("test.mp4")
+    adopted = PredictedSegmentationMask.from_numpy(
+        np.ones((5, 5), dtype=bool), score=0.9, offset=(0.0, 0.0)
+    )
+    user = adopted.to_user()
+    # A second prediction far from any user mask remains unused.
+    orphan = PredictedSegmentationMask.from_numpy(
+        np.ones((5, 5), dtype=bool), score=0.8, offset=(500.0, 500.0)
+    )
+    lf = LabeledFrame(video=video, frame_idx=0, masks=[adopted, user, orphan])
+    assert lf.unused_predicted_masks == [orphan]
+
+
 def test_labeled_frame_matches():
     """Test LabeledFrame.matches() method."""
     video1 = Video(filename="test1.mp4")
@@ -1469,6 +1538,135 @@ def test_merge_annotations_auto_masks():
     assert not lf1.masks[0].is_predicted
 
 
+def test_merge_annotations_auto_masks_link_overrides_distance():
+    """from_predicted link replaces the source prediction despite far distance."""
+    video = Video(filename="test.mp4", open_backend=False)
+    mask_data = np.ones((10, 10), dtype=bool)
+    # self holds the prediction; other holds a user correction adopted from it
+    # but moved far away (well beyond the 5 px spatial threshold).
+    self_pred = PredictedSegmentationMask.from_numpy(
+        mask_data, score=0.7, offset=(5.0, 5.0)
+    )
+    other_user = self_pred.to_user()
+    other_user.offset = (500.0, 500.0)
+
+    lf1 = LabeledFrame(video=video, frame_idx=0, masks=[self_pred])
+    lf2 = LabeledFrame(video=video, frame_idx=0, masks=[other_user])
+
+    lf1._merge_annotations(lf2, strategy="auto")
+
+    # Spatial matching alone would keep both (too far apart); the link resolves
+    # them as the same annotation and the user correction wins.
+    assert len(lf1.masks) == 1
+    assert not lf1.masks[0].is_predicted
+
+
+def test_merge_annotations_auto_masks_link_self_side():
+    """from_predicted link is honored when the user correction lives in self."""
+    video = Video(filename="test.mp4", open_backend=False)
+    mask_data = np.ones((10, 10), dtype=bool)
+    # other holds the source prediction; self holds the user correction adopted
+    # from it, moved far away (beyond the spatial threshold).
+    other_pred = PredictedSegmentationMask.from_numpy(
+        mask_data, score=0.7, offset=(5.0, 5.0)
+    )
+    self_user = other_pred.to_user()
+    self_user.offset = (500.0, 500.0)
+
+    lf1 = LabeledFrame(video=video, frame_idx=0, masks=[self_user])
+    lf2 = LabeledFrame(video=video, frame_idx=0, masks=[other_pred])
+
+    lf1._merge_annotations(lf2, strategy="auto")
+
+    # The user correction in self is kept and its linked source prediction from
+    # other is dropped, despite the large spatial distance.
+    assert len(lf1.masks) == 1
+    assert not lf1.masks[0].is_predicted
+
+
+def test_merge_annotations_auto_masks_link_beats_spatial_decoy():
+    """The link pairs with the true source, not a closer spatial decoy."""
+    video = Video(filename="test.mp4", open_backend=False)
+    mask_data = np.ones((6, 6), dtype=bool)
+    # True source the user adopted from, placed far from the user mask.
+    true_src = PredictedSegmentationMask.from_numpy(
+        mask_data, score=0.6, offset=(100.0, 100.0)
+    )
+    # A decoy prediction sitting right on top of the user mask.
+    decoy = PredictedSegmentationMask.from_numpy(
+        mask_data, score=0.9, offset=(6.0, 6.0)
+    )
+    user = true_src.to_user()
+    user.offset = (5.0, 5.0)  # spatially nearest to `decoy`
+
+    lf1 = LabeledFrame(video=video, frame_idx=0, masks=[true_src, decoy])
+    lf2 = LabeledFrame(video=video, frame_idx=0, masks=[user])
+
+    lf1._merge_annotations(lf2, strategy="auto")
+
+    # The user replaces its linked true source; the decoy stays as a prediction.
+    assert sum(not m.is_predicted for m in lf1.masks) == 1
+    remaining_pred = [m for m in lf1.masks if m.is_predicted]
+    assert remaining_pred == [decoy]
+
+
+def test_merge_annotations_auto_masks_link_multiple_pairs():
+    """Independent from_predicted links resolve in both directions in one merge."""
+    video = Video(filename="test.mp4", open_backend=False)
+    mask_data = np.ones((8, 8), dtype=bool)
+    # Two source predictions, each adopted by a user correction in the *other*
+    # frame, with every mask placed far apart so only the links can pair them.
+    self_pred = PredictedSegmentationMask.from_numpy(
+        mask_data, score=0.5, offset=(200.0, 200.0)
+    )
+    other_pred = PredictedSegmentationMask.from_numpy(
+        mask_data, score=0.6, offset=(600.0, 600.0)
+    )
+    self_user = other_pred.to_user()  # self user adopted from other's prediction
+    self_user.offset = (10.0, 10.0)
+    other_user = self_pred.to_user()  # other user adopted from self's prediction
+    other_user.offset = (400.0, 400.0)
+
+    lf1 = LabeledFrame(video=video, frame_idx=0, masks=[self_user, self_pred])
+    lf2 = LabeledFrame(video=video, frame_idx=0, masks=[other_user, other_pred])
+
+    lf1._merge_annotations(lf2, strategy="auto")
+
+    # Both predictions are superseded by their linked corrections; only the two
+    # user masks remain.
+    assert len(lf1.masks) == 2
+    assert all(not m.is_predicted for m in lf1.masks)
+
+
+def test_merge_annotations_auto_masks_link_source_absent():
+    """A from_predicted link to a prediction absent from the merge falls back.
+
+    When the linked source is not present in the opposing frame, no link match is
+    produced (the link cannot be honored) and matching falls back to spatial
+    behavior.
+    """
+    video = Video(filename="test.mp4", open_backend=False)
+    mask_data = np.ones((8, 8), dtype=bool)
+    external = PredictedSegmentationMask.from_numpy(mask_data, score=0.5)
+
+    # self's user links to `external` (not in other); other's user links to
+    # `external` too (not in self). Neither link can resolve to the opposing
+    # frame, and the two user masks are far apart.
+    self_user = external.to_user()
+    self_user.offset = (10.0, 10.0)
+    other_user = external.to_user()
+    other_user.offset = (900.0, 900.0)
+
+    lf1 = LabeledFrame(video=video, frame_idx=0, masks=[self_user])
+    lf2 = LabeledFrame(video=video, frame_idx=0, masks=[other_user])
+
+    lf1._merge_annotations(lf2, strategy="auto")
+
+    # Unresolvable links + far apart → both user masks are kept.
+    assert len(lf1.masks) == 2
+    assert all(not m.is_predicted for m in lf1.masks)
+
+
 def test_merge_annotations_update_tracks_cascades():
     """Update_tracks updates annotation tracks from spatially matched other."""
     from sleap_io.model.centroid import UserCentroid
