Skip to content

Commit d770fc8

Browse files
authored
add better exception handling for cache building and filtering (#845)
1 parent c04cc7d commit d770fc8

14 files changed

Lines changed: 526 additions & 617 deletions

File tree

src/valor_lite/classification/manager.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from dataclasses import asdict, dataclass
32

43
import numpy as np
@@ -17,6 +16,7 @@
1716
unpack_confusion_matrix_into_metric_list,
1817
unpack_precision_recall_rocauc_into_metric_lists,
1918
)
19+
from valor_lite.exceptions import EmptyEvaluatorException, EmptyFilterException
2020

2121
"""
2222
Usage
@@ -85,6 +85,18 @@ class Filter:
8585
valid_label_indices: NDArray[np.int32] | None
8686
metadata: Metadata
8787

88+
def __post_init__(self):
89+
# validate datum mask
90+
if not self.datum_mask.any():
91+
raise EmptyFilterException("filter removes all datums")
92+
93+
# validate label indices
94+
if (
95+
self.valid_label_indices is not None
96+
and self.valid_label_indices.size == 0
97+
):
98+
raise EmptyFilterException("filter removes all labels")
99+
88100

89101
class Evaluator:
90102
"""
@@ -155,7 +167,6 @@ def create_filter(
155167
datum_mask = np.ones(n_pairs, dtype=np.bool_)
156168
if datum_ids is not None:
157169
if not datum_ids:
158-
warnings.warn("no valid filtered pairs")
159170
return Filter(
160171
datum_mask=np.zeros_like(datum_mask),
161172
valid_label_indices=None,
@@ -173,7 +184,6 @@ def create_filter(
173184
valid_label_indices = None
174185
if labels is not None:
175186
if not labels:
176-
warnings.warn("no valid filtered pairs")
177187
return Filter(
178188
datum_mask=datum_mask,
179189
valid_label_indices=np.array([], dtype=np.int32),
@@ -224,21 +234,6 @@ def filter(
224234
NDArray[int32]
225235
The filtered label metadata.
226236
"""
227-
empty_datum_mask = not filter_.datum_mask.any()
228-
empty_label_mask = (
229-
filter_.valid_label_indices.size == 0
230-
if filter_.valid_label_indices is not None
231-
else False
232-
)
233-
if empty_datum_mask or empty_label_mask:
234-
if empty_datum_mask:
235-
warnings.warn("filter removes all datums")
236-
if empty_label_mask:
237-
warnings.warn("filter removes all labels")
238-
return (
239-
np.array([], dtype=np.float64),
240-
np.zeros((self.metadata.number_of_labels, 2), dtype=np.int32),
241-
)
242237
return filter_cache(
243238
detailed_pairs=self._detailed_pairs,
244239
datum_mask=filter_.datum_mask,
@@ -502,9 +497,7 @@ def finalize(self):
502497
A ready-to-use evaluator object.
503498
"""
504499
if self._detailed_pairs.size == 0:
505-
self._label_metadata = np.array([], dtype=np.int32)
506-
warnings.warn("evaluator is empty")
507-
return self
500+
raise EmptyEvaluatorException()
508501

509502
self._label_metadata = compute_label_metadata(
510503
ids=self._detailed_pairs[:, :3].astype(np.int32),

src/valor_lite/exceptions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
class EmptyEvaluatorException(Exception):
2+
def __init__(self):
3+
super().__init__(
4+
"evaluator cannot be finalized as it contains no data"
5+
)
6+
7+
8+
class EmptyFilterException(Exception):
9+
def __init__(self, message: str):
10+
super().__init__(message)
11+
12+
13+
class InternalCacheException(Exception):
14+
def __init__(self, message: str):
15+
super().__init__(message)

src/valor_lite/object_detection/computation.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from enum import IntFlag, auto
32

43
import numpy as np
@@ -280,14 +279,6 @@ def filter_cache(
280279
)
281280
detailed_pairs = detailed_pairs[~mask_null_pairs]
282281

283-
if detailed_pairs.size == 0:
284-
warnings.warn("no valid filtered pairs")
285-
return (
286-
np.array([], dtype=np.float64),
287-
np.array([], dtype=np.float64),
288-
np.zeros((n_labels, 2), dtype=np.int32),
289-
)
290-
291282
# sorts by score, iou with ground truth id as a tie-breaker
292283
indices = np.lexsort(
293284
(
@@ -441,15 +432,6 @@ def compute_precion_recall(
441432
counts = np.zeros((n_ious, n_scores, n_labels, 6), dtype=np.float64)
442433
pr_curve = np.zeros((n_ious, n_labels, 101, 2))
443434

444-
if ranked_pairs.size == 0:
445-
warnings.warn("no valid ranked pairs")
446-
return (
447-
(average_precision, mAP),
448-
(average_recall, mAR),
449-
counts,
450-
pr_curve,
451-
)
452-
453435
# start computation
454436
ids = ranked_pairs[:, :5].astype(np.int32)
455437
gt_ids = ids[:, 1]

0 commit comments

Comments
 (0)