Skip to content

Commit dc65c15

Browse files
committed
add back filter post init validator
1 parent 3e895e4 commit dc65c15

5 files changed

Lines changed: 123 additions & 2 deletions

File tree

src/valor_lite/object_detection/manager.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,21 @@ class Filter:
9999
mask_predictions: NDArray[np.bool_]
100100
metadata: Metadata
101101

102+
def __post_init__(self):
103+
# validate datums mask
104+
if not self.mask_datums.any():
105+
raise EmptyFilterError("filter removes all datums")
106+
107+
# validate annotation masks
108+
no_gts = self.mask_groundtruths.all()
109+
no_pds = self.mask_predictions.all()
110+
if no_gts and no_pds:
111+
raise EmptyFilterError("filter removes all annotations")
112+
elif no_gts:
113+
warnings.warn("filter removes all ground truths")
114+
elif no_pds:
115+
warnings.warn("filter removes all predictions")
116+
102117

103118
class Evaluator:
104119
"""

src/valor_lite/semantic_segmentation/manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ class Filter:
7171
label_mask: NDArray[np.bool_]
7272
metadata: Metadata
7373

74+
def __post_init__(self):
75+
# validate datum mask
76+
if not self.datum_mask.any():
77+
raise EmptyFilterError("filter removes all datums")
78+
79+
# validate label mask
80+
if self.label_mask.all():
81+
raise EmptyFilterError("filter removes all labels")
82+
7483

7584
class Evaluator:
7685
"""

tests/classification/test_filtering.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
import numpy as np
55
import pytest
66

7-
from valor_lite.classification import Classification, DataLoader, MetricType
7+
from valor_lite.classification import (
8+
Classification,
9+
DataLoader,
10+
Filter,
11+
Metadata,
12+
MetricType,
13+
)
814
from valor_lite.exceptions import EmptyFilterError
915

1016

@@ -890,3 +896,23 @@ def test_filtering_six_classifications_by_indices(
890896
assert m in expected_metrics
891897
for m in expected_metrics:
892898
assert m in actual_metrics
899+
900+
901+
def test_filter_object():
902+
903+
# check that no datums are defined
904+
with pytest.raises(EmptyFilterError) as e:
905+
Filter(
906+
datum_mask=np.array([False, False, False]),
907+
valid_label_indices=np.array([0, 1, 2]),
908+
metadata=Metadata(),
909+
)
910+
assert "filter removes all datums" in str(e)
911+
912+
# check that no labels are defined
913+
with pytest.raises(EmptyFilterError) as e:
914+
Filter(
915+
datum_mask=np.array([True, False, False]),
916+
valid_label_indices=np.array([]),
917+
metadata=Metadata(),
918+
)

tests/object_detection/test_filtering.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
BoundingBox,
1010
DataLoader,
1111
Detection,
12+
Filter,
13+
Metadata,
1214
MetricType,
1315
)
1416

@@ -780,3 +782,47 @@ def test_filtering_invalid_indices(
780782
with pytest.raises(ValueError) as e:
781783
evaluator.create_filter(labels=np.array([1000]))
782784
assert "cannot exceed total number of labels" in str(e)
785+
786+
787+
def test_filter_object():
788+
789+
mask = np.array([True, False, False])
790+
true_mask = np.array([True, True, True])
791+
false_mask = ~true_mask
792+
793+
# check that no datums are defined
794+
with pytest.raises(EmptyFilterError) as e:
795+
Filter(
796+
mask_datums=false_mask,
797+
mask_groundtruths=mask,
798+
mask_predictions=mask,
799+
metadata=Metadata(),
800+
)
801+
assert "filter removes all datums" in str(e)
802+
803+
# check that no annotations are defined
804+
with pytest.raises(EmptyFilterError) as e:
805+
Filter(
806+
mask_datums=mask,
807+
mask_groundtruths=true_mask,
808+
mask_predictions=true_mask,
809+
metadata=Metadata(),
810+
)
811+
812+
# check that no ground truths are defined
813+
with pytest.warns(UserWarning):
814+
Filter(
815+
mask_datums=mask,
816+
mask_groundtruths=true_mask,
817+
mask_predictions=mask,
818+
metadata=Metadata(),
819+
)
820+
821+
# check that no predictions are defined
822+
with pytest.warns(UserWarning):
823+
Filter(
824+
mask_datums=mask,
825+
mask_groundtruths=mask,
826+
mask_predictions=true_mask,
827+
metadata=Metadata(),
828+
)

tests/semantic_segmentation/test_filtering.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import pytest
33

44
from valor_lite.exceptions import EmptyFilterError
5-
from valor_lite.semantic_segmentation import DataLoader, Segmentation
5+
from valor_lite.semantic_segmentation import (
6+
DataLoader,
7+
Filter,
8+
Metadata,
9+
Segmentation,
10+
)
611

712

813
def test_filtering(segmentations_from_boxes: list[Segmentation]):
@@ -182,3 +187,23 @@ def test_filtering_invalid_indices(
182187
with pytest.raises(ValueError) as e:
183188
evaluator.create_filter(labels=np.array([1000]))
184189
assert "cannot exceed total number of labels" in str(e)
190+
191+
192+
def test_filter_object():
193+
194+
mask = np.array([True, False, False])
195+
true_mask = np.array([True, True, True])
196+
false_mask = ~true_mask
197+
198+
# check that no datums are defined
199+
with pytest.raises(EmptyFilterError) as e:
200+
Filter(datum_mask=false_mask, label_mask=mask, metadata=Metadata())
201+
assert "filter removes all datums" in str(e)
202+
203+
# check that no labels are defined
204+
with pytest.raises(EmptyFilterError) as e:
205+
Filter(
206+
datum_mask=mask,
207+
label_mask=true_mask,
208+
metadata=Metadata(),
209+
)

0 commit comments

Comments
 (0)