Skip to content

Commit a968529

Browse files
authored
handle overlapped segmentations more gracefully (#849)
1 parent e66ccf5 commit a968529

2 files changed

Lines changed: 99 additions & 68 deletions

File tree

src/valor_lite/semantic_segmentation/annotation.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from dataclasses import dataclass, field
23

34
import numpy as np
@@ -79,30 +80,37 @@ def __post_init__(self):
7980
)
8081
self.size = self.shape[0] * self.shape[1]
8182

82-
mask_accumulation = None
83-
for groundtruth in self.groundtruths:
84-
if self.shape != groundtruth.mask.shape:
85-
raise ValueError(
86-
f"ground truth masks for datum '{self.uid}' should have shape '{self.shape}'. Received mask with shape '{groundtruth.mask.shape}'"
87-
)
88-
89-
if mask_accumulation is None:
90-
mask_accumulation = groundtruth.mask.copy()
91-
elif np.logical_and(mask_accumulation, groundtruth.mask).any():
92-
raise ValueError("ground truth masks cannot overlap")
93-
else:
94-
mask_accumulation = mask_accumulation | groundtruth.mask
83+
self._validate_bitmasks(self.groundtruths, "ground truth")
84+
self._validate_bitmasks(self.predictions, "prediction")
9585

86+
def _validate_bitmasks(self, bitmasks: list[Bitmask], key: str):
9687
mask_accumulation = None
97-
for prediction in self.predictions:
98-
if self.shape != prediction.mask.shape:
88+
mask_overlap_accumulation = None
89+
for idx, bitmask in enumerate(bitmasks):
90+
if not isinstance(bitmask, Bitmask):
91+
raise ValueError(f"expected 'Bitmask', got '{bitmask}'")
92+
if self.shape != bitmask.mask.shape:
9993
raise ValueError(
100-
f"prediction masks for datum '{self.uid}' should have shape '{self.shape}'. Received mask with shape '{prediction.mask.shape}'"
94+
f"{key} masks for datum '{self.uid}' should have shape '{self.shape}'. Received mask with shape '{bitmask.mask.shape}'"
10195
)
10296

10397
if mask_accumulation is None:
104-
mask_accumulation = prediction.mask.copy()
105-
elif np.logical_and(mask_accumulation, prediction.mask).any():
106-
raise ValueError("prediction masks cannot overlap")
98+
mask_accumulation = bitmask.mask.copy()
99+
mask_overlap_accumulation = np.zeros_like(mask_accumulation)
100+
elif np.logical_and(mask_accumulation, bitmask.mask).any():
101+
mask_overlap = np.logical_and(mask_accumulation, bitmask.mask)
102+
bitmasks[idx].mask[mask_overlap] = False
103+
mask_overlap_accumulation = (
104+
mask_overlap_accumulation | mask_overlap
105+
)
107106
else:
108-
mask_accumulation = mask_accumulation | prediction.mask
107+
mask_accumulation = mask_accumulation | bitmask.mask
108+
if (
109+
mask_overlap_accumulation is not None
110+
and mask_overlap_accumulation.any()
111+
):
112+
count = mask_overlap_accumulation.sum()
113+
total = mask_overlap_accumulation.size
114+
warnings.warn(
115+
f"{key} masks for datum '{self.uid}' had {count} / {total} pixels overlapped."
116+
)

tests/semantic_segmentation/test_annotation.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -105,54 +105,6 @@ def test_segmentation():
105105
)
106106
assert "Received mask with shape '(1, 2)'" in str(e)
107107

108-
# test ground truths cannot overlap
109-
with pytest.raises(ValueError) as e:
110-
Segmentation(
111-
uid="uid",
112-
groundtruths=[
113-
Bitmask(
114-
mask=np.array([[True, True, True]]),
115-
label="label1",
116-
),
117-
Bitmask(
118-
mask=np.array([[False, False, True]]),
119-
label="label2",
120-
),
121-
],
122-
predictions=[
123-
Bitmask(
124-
mask=np.array([[True, False, False]]),
125-
label="label",
126-
)
127-
],
128-
shape=(1, 3),
129-
)
130-
assert "ground truth masks cannot overlap" in str(e)
131-
132-
# test predictions cannot overlap
133-
with pytest.raises(ValueError) as e:
134-
Segmentation(
135-
uid="uid",
136-
groundtruths=[
137-
Bitmask(
138-
mask=np.array([[True, True, True]]),
139-
label="label1",
140-
),
141-
],
142-
predictions=[
143-
Bitmask(
144-
mask=np.array([[True, False, True]]),
145-
label="label",
146-
),
147-
Bitmask(
148-
mask=np.array([[False, False, True]]),
149-
label="label2",
150-
),
151-
],
152-
shape=(1, 3),
153-
)
154-
assert "prediction masks cannot overlap" in str(e)
155-
156108
# allow missing ground truths
157109
Segmentation(
158110
uid="uid",
@@ -179,6 +131,22 @@ def test_segmentation():
179131
shape=(1, 2),
180132
)
181133

134+
# wrong annotation type
135+
with pytest.raises(ValueError):
136+
Segmentation(
137+
uid="uid",
138+
groundtruths=[{"a": 1}], # type: ignore - testing
139+
predictions=[],
140+
shape=(1, 2),
141+
)
142+
with pytest.raises(ValueError):
143+
Segmentation(
144+
uid="uid",
145+
groundtruths=[],
146+
predictions=[{"a": 1}], # type: ignore - testing
147+
shape=(1, 2),
148+
)
149+
182150

183151
def test_segmentation_shape():
184152
Segmentation(uid="uid", groundtruths=[], predictions=[], shape=(1, 1))
@@ -197,3 +165,58 @@ def test_segmentation_shape():
197165
Segmentation(
198166
uid="uid", groundtruths=[], predictions=[], shape=(-100, 100)
199167
)
168+
169+
170+
def _create_overlapped_masks() -> tuple[Bitmask, Bitmask]:
171+
mask0 = np.zeros((100, 100), dtype=np.bool_)
172+
mask0[:50, :] = True
173+
mask1 = np.ones((100, 100), dtype=np.bool_)
174+
bitmask0 = Bitmask(mask=mask0, label="dog")
175+
bitmask1 = Bitmask(mask=mask1, label="cat")
176+
return bitmask0, bitmask1
177+
178+
179+
def test_segmentations_overlap():
180+
bitmask0, bitmask1 = _create_overlapped_masks()
181+
Segmentation(
182+
uid="uid123",
183+
groundtruths=[bitmask0],
184+
predictions=[bitmask1],
185+
shape=(100, 100),
186+
)
187+
assert bitmask0.mask.sum() == 5000
188+
assert bitmask1.mask.sum() == 10000
189+
190+
bitmask0, bitmask1 = _create_overlapped_masks()
191+
with pytest.warns(UserWarning) as e:
192+
Segmentation(
193+
uid="uid123",
194+
groundtruths=[bitmask0, bitmask1],
195+
predictions=[],
196+
shape=(100, 100),
197+
)
198+
assert (
199+
str(e._list[0].message)
200+
== "ground truth masks for datum 'uid123' had 5000 / 10000 pixels overlapped."
201+
)
202+
assert bitmask0.mask.sum() == 5000
203+
assert (
204+
bitmask1.mask.sum() == 5000
205+
) # overlapped pixels omitted from second mask
206+
207+
bitmask0, bitmask1 = _create_overlapped_masks()
208+
with pytest.warns(UserWarning) as e:
209+
Segmentation(
210+
uid="uid123",
211+
groundtruths=[],
212+
predictions=[bitmask0, bitmask1],
213+
shape=(100, 100),
214+
)
215+
assert (
216+
str(e._list[0].message)
217+
== "prediction masks for datum 'uid123' had 5000 / 10000 pixels overlapped."
218+
)
219+
assert bitmask0.mask.sum() == 5000
220+
assert (
221+
bitmask1.mask.sum() == 5000
222+
) # overlapped pixels omitted from second mask

0 commit comments

Comments
 (0)