1- import warnings
21from dataclasses import asdict , dataclass
32
43import numpy as np
1716 unpack_confusion_matrix_into_metric_list ,
1817 unpack_precision_recall_rocauc_into_metric_lists ,
1918)
19+ from valor_lite .exceptions import EmptyEvaluatorException , EmptyFilterException
2020
2121"""
2222Usage
@@ -85,6 +85,18 @@ class Filter:
8585 valid_label_indices : NDArray [np .int32 ] | None
8686 metadata : Metadata
8787
88+ def __post_init__ (self ):
89+ # validate datum mask
90+ if not self .datum_mask .any ():
91+ raise EmptyFilterException ("filter removes all datums" )
92+
93+ # validate label indices
94+ if (
95+ self .valid_label_indices is not None
96+ and self .valid_label_indices .size == 0
97+ ):
98+ raise EmptyFilterException ("filter removes all labels" )
99+
88100
89101class Evaluator :
90102 """
@@ -155,7 +167,6 @@ def create_filter(
155167 datum_mask = np .ones (n_pairs , dtype = np .bool_ )
156168 if datum_ids is not None :
157169 if not datum_ids :
158- warnings .warn ("no valid filtered pairs" )
159170 return Filter (
160171 datum_mask = np .zeros_like (datum_mask ),
161172 valid_label_indices = None ,
@@ -173,7 +184,6 @@ def create_filter(
173184 valid_label_indices = None
174185 if labels is not None :
175186 if not labels :
176- warnings .warn ("no valid filtered pairs" )
177187 return Filter (
178188 datum_mask = datum_mask ,
179189 valid_label_indices = np .array ([], dtype = np .int32 ),
@@ -224,21 +234,6 @@ def filter(
224234 NDArray[int32]
225235 The filtered label metadata.
226236 """
227- empty_datum_mask = not filter_ .datum_mask .any ()
228- empty_label_mask = (
229- filter_ .valid_label_indices .size == 0
230- if filter_ .valid_label_indices is not None
231- else False
232- )
233- if empty_datum_mask or empty_label_mask :
234- if empty_datum_mask :
235- warnings .warn ("filter removes all datums" )
236- if empty_label_mask :
237- warnings .warn ("filter removes all labels" )
238- return (
239- np .array ([], dtype = np .float64 ),
240- np .zeros ((self .metadata .number_of_labels , 2 ), dtype = np .int32 ),
241- )
242237 return filter_cache (
243238 detailed_pairs = self ._detailed_pairs ,
244239 datum_mask = filter_ .datum_mask ,
@@ -502,9 +497,7 @@ def finalize(self):
502497 A ready-to-use evaluator object.
503498 """
504499 if self ._detailed_pairs .size == 0 :
505- self ._label_metadata = np .array ([], dtype = np .int32 )
506- warnings .warn ("evaluator is empty" )
507- return self
500+ raise EmptyEvaluatorException ()
508501
509502 self ._label_metadata = compute_label_metadata (
510503 ids = self ._detailed_pairs [:, :3 ].astype (np .int32 ),
0 commit comments