Skip to content

Commit cf67f04

Browse files
committed
annotation filtering tests
1 parent 0c0d137 commit cf67f04

5 files changed

Lines changed: 181 additions & 32 deletions

File tree

src/valor_lite/object_detection/legacy.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import tempfile
22
from dataclasses import asdict, dataclass
3+
from pathlib import Path
34

45
import numpy as np
56
import pyarrow.compute as pc
@@ -44,33 +45,30 @@ def to_dict(self) -> dict[str, int | bool]:
4445
return asdict(self)
4546

4647

47-
class Evaluator:
48+
class Evaluator(CachedEvaluator):
4849
"""
4950
Legacy Object Detection Evaluator
5051
"""
5152

52-
def __init__(self, name: str = "default"):
53-
self._evaluator = CachedEvaluator(name=name)
54-
5553
@property
5654
def metadata(self) -> Metadata:
5755
"""
5856
Evaluation metadata.
5957
"""
6058
return Metadata(
61-
number_of_datums=self._evaluator.info.number_of_datums,
62-
number_of_labels=self._evaluator.info.number_of_labels,
63-
number_of_ground_truths=self._evaluator.info.number_of_groundtruth_annotations,
64-
number_of_predictions=self._evaluator.info.number_of_prediction_annotations,
59+
number_of_datums=self.info.number_of_datums,
60+
number_of_labels=self.info.number_of_labels,
61+
number_of_ground_truths=self.info.number_of_groundtruth_annotations,
62+
number_of_predictions=self.info.number_of_prediction_annotations,
6563
)
6664

6765
@property
6866
def _detailed_pairs(self) -> np.ndarray:
6967
return np.concatenate(
7068
[
7169
pairs
72-
for pairs in self._evaluator.iterate_pairs(
73-
self._evaluator._dataset,
70+
for pairs in self.iterate_pairs(
71+
self._dataset,
7472
columns=[
7573
"datum_id",
7674
"gt_id",
@@ -87,7 +85,7 @@ def _detailed_pairs(self) -> np.ndarray:
8785
@property
8886
def _label_metadata(self) -> np.ndarray:
8987
label_metadata = np.zeros(
90-
(len(self._evaluator._index_to_label), 2), dtype=np.int32
88+
(len(self._index_to_label), 2), dtype=np.int32
9189
)
9290

9391
# groundtruth labels
@@ -114,9 +112,10 @@ def _label_metadata(self) -> np.ndarray:
114112

115113
return label_metadata
116114

117-
def filter(
118-
self, filter_: Filter
119-
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
115+
def filter( # type: ignore - legacy function override does not match
116+
self,
117+
filter_: Filter,
118+
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32]]:
120119
"""
121120
Performs filtering over the internal cache.
122121
@@ -135,12 +134,17 @@ def filter(
135134
Label metadata.
136135
"""
137136
with tempfile.TemporaryDirectory() as tmpdir:
138-
evaluator = Evaluator()
139-
evaluator._evaluator = self._evaluator.filter(
137+
name = "filtered"
138+
_evaluator = super().filter(
140139
directory=tmpdir,
141-
name="filtered",
140+
name=name,
142141
filter_expr=filter_,
143142
)
143+
evaluator = Evaluator(
144+
name=name,
145+
directory=tmpdir,
146+
labels_override=_evaluator._index_to_label,
147+
)
144148
detailed_pairs = evaluator._detailed_pairs
145149
label_metadata = evaluator._label_metadata
146150
return detailed_pairs, detailed_pairs, label_metadata
@@ -229,7 +233,7 @@ def compute_precision_recall(
229233
"""
230234
if filter_ is not None:
231235
with tempfile.TemporaryDirectory() as tmpdir:
232-
evaluator = self._evaluator.filter(
236+
evaluator = super().filter(
233237
directory=tmpdir,
234238
name="filtered",
235239
filter_expr=filter_,
@@ -238,7 +242,7 @@ def compute_precision_recall(
238242
iou_thresholds=iou_thresholds,
239243
score_thresholds=score_thresholds,
240244
)
241-
return self._evaluator.compute_precision_recall(
245+
return super().compute_precision_recall(
242246
iou_thresholds=iou_thresholds,
243247
score_thresholds=score_thresholds,
244248
)
@@ -268,7 +272,7 @@ def compute_confusion_matrix(
268272
"""
269273
if filter_ is not None:
270274
with tempfile.TemporaryDirectory() as tmpdir:
271-
evaluator = self._evaluator.filter(
275+
evaluator = super().filter(
272276
directory=tmpdir,
273277
name="filtered",
274278
filter_expr=filter_,
@@ -278,7 +282,7 @@ def compute_confusion_matrix(
278282
score_thresholds=score_thresholds,
279283
)
280284
else:
281-
metrics = self._evaluator.compute_confusion_matrix_with_examples(
285+
metrics = super().compute_confusion_matrix_with_examples(
282286
iou_thresholds=iou_thresholds,
283287
score_thresholds=score_thresholds,
284288
)
@@ -328,12 +332,29 @@ class DataLoader(CachedLoader):
328332
Legacy Object Detection DataLoader
329333
"""
330334

331-
def __init__(self):
332-
super().__init__(
333-
batch_size=1_000,
334-
rows_per_file=10_000,
335+
def finalize(self) -> Evaluator: # type: ignore - switching type
336+
evaluator = super().finalize()
337+
return Evaluator(
338+
name=evaluator._name,
339+
directory=evaluator._directory,
335340
)
336341

337-
def finalize(self) -> Evaluator: # type: ignore - switching type
338-
_ = super().finalize()
339-
return Evaluator()
342+
@classmethod
343+
def filter(
344+
cls,
345+
directory: str | Path,
346+
name: str,
347+
evaluator: CachedEvaluator,
348+
filter_expr: Filter,
349+
) -> Evaluator:
350+
evaluator = super().filter(
351+
directory=directory,
352+
name=name,
353+
evaluator=evaluator,
354+
filter_expr=filter_expr,
355+
)
356+
return Evaluator(
357+
directory=evaluator._directory,
358+
name=evaluator._name,
359+
labels_override=evaluator._index_to_label,
360+
)

src/valor_lite/object_detection/loader.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,24 +390,27 @@ def filter(
390390
gt_ids = pairs[:, (0, 1)].astype(np.int64)
391391
pd_ids = pairs[:, (0, 2)].astype(np.int64)
392392

393-
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
394-
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
395-
396393
if filter_expr.groundtruths is not None:
394+
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
397395
gt_tbl = tbl.filter(filter_expr.groundtruths)
398396
gt_pairs = np.column_stack(
399397
[gt_tbl[col].to_numpy() for col in ("datum_id", "gt_id")]
400398
).astype(np.int64)
401399
for gt in np.unique(gt_pairs, axis=0):
402400
mask_valid_gt |= (gt_ids == gt).all(axis=1)
401+
else:
402+
mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
403403

404404
if filter_expr.predictions is not None:
405+
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
405406
pd_tbl = tbl.filter(filter_expr.predictions)
406407
pd_pairs = np.column_stack(
407408
[pd_tbl[col].to_numpy() for col in ("datum_id", "pd_id")]
408409
).astype(np.int64)
409410
for pd in np.unique(pd_pairs, axis=0):
410411
mask_valid_pd |= (pd_ids == pd).all(axis=1)
412+
else:
413+
mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
411414

412415
mask_valid = mask_valid_gt | mask_valid_pd
413416
mask_valid_gt &= mask_valid

tests/object_detection/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def basic_detections(
181181
ymin=rect1[2],
182182
ymax=rect1[3],
183183
labels=["v1"],
184+
metadata={
185+
"gt_rect": "rect1",
186+
},
184187
),
185188
BoundingBox(
186189
uid=str(uuid4()),
@@ -189,6 +192,9 @@ def basic_detections(
189192
ymin=rect3[2],
190193
ymax=rect3[3],
191194
labels=["v2"],
195+
metadata={
196+
"gt_rect": "rect3",
197+
},
192198
),
193199
],
194200
predictions=[
@@ -200,6 +206,9 @@ def basic_detections(
200206
ymax=rect1[3],
201207
labels=["v1"],
202208
scores=[0.3],
209+
metadata={
210+
"pd_rect": "rect1",
211+
},
203212
),
204213
],
205214
),
@@ -213,6 +222,9 @@ def basic_detections(
213222
ymin=rect2[2],
214223
ymax=rect2[3],
215224
labels=["v1"],
225+
metadata={
226+
"gt_rect": "rect2",
227+
},
216228
),
217229
],
218230
predictions=[
@@ -224,6 +236,9 @@ def basic_detections(
224236
ymax=rect2[3],
225237
labels=["v2"],
226238
scores=[0.98],
239+
metadata={
240+
"pd_rect": "rect2",
241+
},
227242
),
228243
],
229244
),

tests/object_detection/test_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_iou_computation():
7676
loader.add_bounding_boxes([detection])
7777
evaluator = loader.finalize()
7878

79-
tbl = evaluator._evaluator._dataset.to_table()
79+
tbl = evaluator._dataset.to_table()
8080
assert tbl.shape == (7, 12)
8181

8282
# show that three unique IOUs exist

tests/object_detection/test_filtering.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
from uuid import uuid4
33

44
import numpy as np
5+
import pyarrow.compute as pc
56
import pytest
67

8+
from valor_lite.cache import DataType
79
from valor_lite.exceptions import EmptyCacheError, EmptyFilterError
810
from valor_lite.object_detection import (
911
BoundingBox,
1012
DataLoader,
1113
Detection,
1214
MetricType,
1315
)
16+
from valor_lite.object_detection.evaluator import Filter
1417

1518

1619
@pytest.fixture
@@ -590,3 +593,110 @@ def test_filtering_four_detections_by_indices(
590593
assert m in expected_metrics
591594
for m in expected_metrics:
592595
assert m in actual_metrics
596+
597+
598+
def test_filtering_four_detections_by_annotation_metadata(
599+
four_detections: list[Detection],
600+
):
601+
"""
602+
Basic object detection test that combines the labels of basic_detections_first_class and basic_detections_second_class.
603+
604+
groundtruths
605+
datum uid1
606+
box 1 - label v1 - tp
607+
box 3 - label v2 - fn unmatched ground truths
608+
datum uid2
609+
box 2 - label v1 - fn misclassification
610+
datum uid3
611+
box 1 - label v1 - tp
612+
box 3 - label v2 - fn unmatched ground truths
613+
datum uid4
614+
box 2 - label v1 - fn misclassification
615+
616+
predictions
617+
datum uid1
618+
box 1 - label v1 - score 0.3 - tp
619+
datum uid2
620+
box 2 - label v2 - score 0.98 - fp misclassification
621+
datum uid3
622+
box 1 - label v1 - score 0.3 - tp
623+
datum uid4
624+
box 2 - label v2 - score 0.98 - fp misclassification
625+
"""
626+
627+
loader = DataLoader(
628+
groundtruth_metadata_types={
629+
"gt_rect": DataType.STRING,
630+
},
631+
prediction_metadata_types={
632+
"pd_rect": DataType.STRING,
633+
},
634+
)
635+
loader.add_bounding_boxes(four_detections)
636+
evaluator = loader.finalize()
637+
638+
# remove all FN groundtruths
639+
filter_ = Filter(
640+
groundtruths=pc.field("gt_rect") == "rect1",
641+
)
642+
metrics = evaluator.evaluate(
643+
iou_thresholds=[0.5], score_thresholds=[0.1], filter_=filter_
644+
)
645+
actual_metrics = [m.to_dict() for m in metrics[MetricType.Counts]]
646+
expected_metrics = [
647+
{
648+
"type": "Counts",
649+
"value": {"tp": 2, "fp": 0, "fn": 0},
650+
"parameters": {
651+
"iou_threshold": 0.5,
652+
"score_threshold": 0.1,
653+
"label": "v1",
654+
},
655+
},
656+
{
657+
"type": "Counts",
658+
"value": {"tp": 0, "fp": 2, "fn": 0},
659+
"parameters": {
660+
"iou_threshold": 0.5,
661+
"score_threshold": 0.1,
662+
"label": "v2",
663+
},
664+
},
665+
]
666+
for m in actual_metrics:
667+
assert m in expected_metrics
668+
for m in expected_metrics:
669+
assert m in actual_metrics
670+
671+
# remove TP ground truths
672+
filter_ = Filter(
673+
groundtruths=pc.field("gt_rect") != "rect1",
674+
)
675+
metrics = evaluator.evaluate(
676+
iou_thresholds=[0.5], score_thresholds=[0.1], filter_=filter_
677+
)
678+
actual_metrics = [m.to_dict() for m in metrics[MetricType.Counts]]
679+
expected_metrics = [
680+
{
681+
"type": "Counts",
682+
"value": {"tp": 0, "fp": 2, "fn": 2},
683+
"parameters": {
684+
"iou_threshold": 0.5,
685+
"score_threshold": 0.1,
686+
"label": "v1",
687+
},
688+
},
689+
{
690+
"type": "Counts",
691+
"value": {"tp": 0, "fp": 2, "fn": 2},
692+
"parameters": {
693+
"iou_threshold": 0.5,
694+
"score_threshold": 0.1,
695+
"label": "v2",
696+
},
697+
},
698+
]
699+
for m in actual_metrics:
700+
assert m in expected_metrics
701+
for m in expected_metrics:
702+
assert m in actual_metrics

0 commit comments

Comments
 (0)