Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions src/valor_lite/classification/manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from dataclasses import asdict, dataclass

import numpy as np
Expand All @@ -17,6 +16,7 @@
unpack_confusion_matrix_into_metric_list,
unpack_precision_recall_rocauc_into_metric_lists,
)
from valor_lite.exceptions import EmptyEvaluatorException, EmptyFilterException

"""
Usage
Expand Down Expand Up @@ -85,6 +85,18 @@ class Filter:
valid_label_indices: NDArray[np.int32] | None
metadata: Metadata

def __post_init__(self):
# validate datum mask
if not self.datum_mask.any():
raise EmptyFilterException("filter removes all datums")

# validate label indices
if (
self.valid_label_indices is not None
and self.valid_label_indices.size == 0
):
raise EmptyFilterException("filter removes all labels")


class Evaluator:
"""
Expand Down Expand Up @@ -155,7 +167,6 @@ def create_filter(
datum_mask = np.ones(n_pairs, dtype=np.bool_)
if datum_ids is not None:
if not datum_ids:
warnings.warn("no valid filtered pairs")
return Filter(
datum_mask=np.zeros_like(datum_mask),
valid_label_indices=None,
Expand All @@ -173,7 +184,6 @@ def create_filter(
valid_label_indices = None
if labels is not None:
if not labels:
warnings.warn("no valid filtered pairs")
return Filter(
datum_mask=datum_mask,
valid_label_indices=np.array([], dtype=np.int32),
Expand Down Expand Up @@ -224,21 +234,6 @@ def filter(
NDArray[int32]
The filtered label metadata.
"""
empty_datum_mask = not filter_.datum_mask.any()
empty_label_mask = (
filter_.valid_label_indices.size == 0
if filter_.valid_label_indices is not None
else False
)
if empty_datum_mask or empty_label_mask:
if empty_datum_mask:
warnings.warn("filter removes all datums")
if empty_label_mask:
warnings.warn("filter removes all labels")
return (
np.array([], dtype=np.float64),
np.zeros((self.metadata.number_of_labels, 2), dtype=np.int32),
)
return filter_cache(
detailed_pairs=self._detailed_pairs,
datum_mask=filter_.datum_mask,
Expand Down Expand Up @@ -502,9 +497,7 @@ def finalize(self):
A ready-to-use evaluator object.
"""
if self._detailed_pairs.size == 0:
self._label_metadata = np.array([], dtype=np.int32)
warnings.warn("evaluator is empty")
return self
raise EmptyEvaluatorException
Comment thread
czaloom marked this conversation as resolved.
Outdated

self._label_metadata = compute_label_metadata(
ids=self._detailed_pairs[:, :3].astype(np.int32),
Expand Down
15 changes: 15 additions & 0 deletions src/valor_lite/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class EmptyEvaluatorException(Exception):
def __init__(self):
super().__init__(
"evaluator cannot be finalized as it contains no data"
)


class EmptyFilterException(Exception):
def __init__(self, message: str):
super().__init__(message)


class InternalCacheException(Exception):
def __init__(self, message: str):
super().__init__(message)
18 changes: 0 additions & 18 deletions src/valor_lite/object_detection/computation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from enum import IntFlag, auto

import numpy as np
Expand Down Expand Up @@ -280,14 +279,6 @@ def filter_cache(
)
detailed_pairs = detailed_pairs[~mask_null_pairs]

if detailed_pairs.size == 0:
warnings.warn("no valid filtered pairs")
return (
np.array([], dtype=np.float64),
np.array([], dtype=np.float64),
np.zeros((n_labels, 2), dtype=np.int32),
)

# sorts by score, iou with ground truth id as a tie-breaker
indices = np.lexsort(
(
Expand Down Expand Up @@ -441,15 +432,6 @@ def compute_precion_recall(
counts = np.zeros((n_ious, n_scores, n_labels, 6), dtype=np.float64)
pr_curve = np.zeros((n_ious, n_labels, 101, 2))

if ranked_pairs.size == 0:
warnings.warn("no valid ranked pairs")
return (
(average_precision, mAP),
(average_recall, mAR),
counts,
pr_curve,
)

# start computation
ids = ranked_pairs[:, :5].astype(np.int32)
gt_ids = ids[:, 1]
Expand Down
Loading