Skip to content

Commit ce7f940

Browse files
committed
added filtering tests
1 parent af31bf4 commit ce7f940

4 files changed

Lines changed: 121 additions & 8 deletions

File tree

src/valor_lite/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def convert_type_mapping_to_schema(
6060
"""
6161
if not type_mapping:
6262
return []
63-
return [(k, v.to_arrow()) for k, v in type_mapping.items()]
63+
return [(k, DataType(v).to_arrow()) for k, v in type_mapping.items()]
6464

6565

6666
class CacheReader:

src/valor_lite/semantic_segmentation/loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,8 @@ def filter(
350350
pairs[~mask_valid_gt, 1] = -1
351351
pairs[~mask_valid_pd, 2] = -1
352352

353+
print(pairs)
354+
353355
for idx, col in enumerate(columns):
354356
tbl = tbl.set_column(
355357
tbl.schema.names.index(col), col, pa.array(pairs[:, idx])

tests/semantic_segmentation/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ def _generate_boolean_mask(
1818
return Bitmask(
1919
mask=mask,
2020
label=label,
21+
metadata={
22+
"gt_xmin": xmin,
23+
"pd_xmin": xmin,
24+
},
2125
)
2226

2327

tests/semantic_segmentation/test_filtering.py

Lines changed: 114 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
11
import numpy as np
2+
import pyarrow.compute as pc
23
import pytest
34

5+
from valor_lite.cache import DataType
46
from valor_lite.exceptions import EmptyFilterError
57
from valor_lite.semantic_segmentation import DataLoader, Segmentation
8+
from valor_lite.semantic_segmentation.evaluator import Filter
69

710

8-
def test_filtering(segmentations_from_boxes: list[Segmentation]):
11+
def test_filtering_raises(segmentations_from_boxes: list[Segmentation]):
12+
13+
loader = DataLoader()
14+
loader.add_data(segmentations_from_boxes)
15+
evaluator = loader.finalize()
16+
assert evaluator._confusion_matrix.shape == (3, 3)
17+
18+
with pytest.raises(EmptyFilterError):
19+
evaluator.create_filter(datums=[])
20+
assert evaluator._confusion_matrix.shape == (3, 3)
21+
22+
23+
def test_filtering_by_datum(segmentations_from_boxes: list[Segmentation]):
924

1025
loader = DataLoader()
1126
loader.add_data(segmentations_from_boxes)
@@ -51,13 +66,105 @@ def test_filtering(segmentations_from_boxes: list[Segmentation]):
5166
evaluator.create_filter(datums=[])
5267

5368

54-
def test_filtering_raises(segmentations_from_boxes: list[Segmentation]):
69+
def test_filtering_by_annotation_metadata(
70+
segmentations_from_boxes: list[Segmentation],
71+
):
5572

56-
loader = DataLoader()
73+
loader = DataLoader(
74+
groundtruth_metadata_types={
75+
"gt_xmin": DataType.FLOAT,
76+
},
77+
prediction_metadata_types={
78+
"pd_xmin": DataType.FLOAT,
79+
},
80+
)
5781
loader.add_data(segmentations_from_boxes)
5882
evaluator = loader.finalize()
59-
assert evaluator._confusion_matrix.shape == (3, 3)
6083

61-
with pytest.raises(EmptyFilterError):
62-
evaluator.create_filter(datums=[])
63-
assert evaluator._confusion_matrix.shape == (3, 3)
84+
total_pixels = 540_000
85+
assert evaluator.metadata.number_of_datums == 2
86+
assert evaluator.metadata.number_of_labels == 2
87+
assert evaluator.metadata.number_of_ground_truths == 25000
88+
assert evaluator.metadata.number_of_predictions == 15000
89+
assert evaluator.metadata.number_of_pixels == total_pixels
90+
91+
# test groundtruth filtering
92+
filter_ = Filter(groundtruths=pc.field("gt_xmin") < 100)
93+
filtered_evaluator = evaluator.filter(filter_)
94+
confusion_matrix = filtered_evaluator._confusion_matrix
95+
assert np.all(
96+
confusion_matrix
97+
== np.array(
98+
[
99+
[520000, 5000, 5000],
100+
[5000, 5000, 0],
101+
[0, 0, 0],
102+
]
103+
)
104+
)
105+
assert confusion_matrix.sum() == total_pixels
106+
107+
filter_ = Filter(groundtruths=pc.field("gt_xmin") > 100)
108+
filtered_evaluator = evaluator.filter(filter_)
109+
confusion_matrix = filtered_evaluator._confusion_matrix
110+
assert np.all(
111+
confusion_matrix
112+
== np.array(
113+
[
114+
[510001, 10000, 4999],
115+
[0, 0, 0],
116+
[14999, 0, 1],
117+
]
118+
)
119+
)
120+
assert confusion_matrix.sum() == total_pixels
121+
122+
# test prediction filtering
123+
filter_ = Filter(predictions=pc.field("pd_xmin") < 100)
124+
filtered_evaluator = evaluator.filter(filter_)
125+
confusion_matrix = filtered_evaluator._confusion_matrix
126+
assert np.all(
127+
confusion_matrix
128+
== np.array(
129+
[
130+
[510000, 5000, 0],
131+
[5000, 5000, 0],
132+
[15000, 0, 0],
133+
]
134+
)
135+
)
136+
assert confusion_matrix.sum() == total_pixels
137+
138+
filter_ = Filter(predictions=pc.field("pd_xmin") > 100)
139+
filtered_evaluator = evaluator.filter(filter_)
140+
confusion_matrix = filtered_evaluator._confusion_matrix
141+
assert np.all(
142+
confusion_matrix
143+
== np.array(
144+
[
145+
[510001, 0, 4999],
146+
[10000, 0, 0],
147+
[14999, 0, 1],
148+
]
149+
)
150+
)
151+
assert confusion_matrix.sum() == total_pixels
152+
153+
# filter out all gts and pds
154+
filter_ = Filter(
155+
groundtruths=pc.field("gt_xmin") > 1000,
156+
predictions=pc.field("pd_xmin") > 1000,
157+
)
158+
filtered_evaluator = evaluator.filter(filter_)
159+
confusion_matrix = filtered_evaluator._confusion_matrix
160+
assert np.all(
161+
confusion_matrix
162+
== np.array(
163+
[
164+
[total_pixels, 0, 0],
165+
[0, 0, 0],
166+
[0, 0, 0],
167+
]
168+
)
169+
)
170+
assert confusion_matrix.sum() == total_pixels

0 commit comments

Comments
 (0)