11import tempfile
22from dataclasses import asdict , dataclass
3+ from pathlib import Path
34
45import numpy as np
56import pyarrow .compute as pc
@@ -44,33 +45,30 @@ def to_dict(self) -> dict[str, int | bool]:
4445 return asdict (self )
4546
4647
47- class Evaluator :
48+ class Evaluator ( CachedEvaluator ) :
4849 """
4950 Legacy Object Detection Evaluator
5051 """
5152
52- def __init__ (self , name : str = "default" ):
53- self ._evaluator = CachedEvaluator (name = name )
54-
5553 @property
5654 def metadata (self ) -> Metadata :
5755 """
5856 Evaluation metadata.
5957 """
6058 return Metadata (
61- number_of_datums = self ._evaluator . info .number_of_datums ,
62- number_of_labels = self ._evaluator . info .number_of_labels ,
63- number_of_ground_truths = self ._evaluator . info .number_of_groundtruth_annotations ,
64- number_of_predictions = self ._evaluator . info .number_of_prediction_annotations ,
59+ number_of_datums = self .info .number_of_datums ,
60+ number_of_labels = self .info .number_of_labels ,
61+ number_of_ground_truths = self .info .number_of_groundtruth_annotations ,
62+ number_of_predictions = self .info .number_of_prediction_annotations ,
6563 )
6664
6765 @property
6866 def _detailed_pairs (self ) -> np .ndarray :
6967 return np .concatenate (
7068 [
7169 pairs
72- for pairs in self ._evaluator . iterate_pairs (
73- self ._evaluator . _dataset ,
70+ for pairs in self .iterate_pairs (
71+ self ._dataset ,
7472 columns = [
7573 "datum_id" ,
7674 "gt_id" ,
@@ -87,7 +85,7 @@ def _detailed_pairs(self) -> np.ndarray:
8785 @property
8886 def _label_metadata (self ) -> np .ndarray :
8987 label_metadata = np .zeros (
90- (len (self ._evaluator . _index_to_label ), 2 ), dtype = np .int32
88+ (len (self ._index_to_label ), 2 ), dtype = np .int32
9189 )
9290
9391 # groundtruth labels
@@ -114,9 +112,10 @@ def _label_metadata(self) -> np.ndarray:
114112
115113 return label_metadata
116114
117- def filter (
118- self , filter_ : Filter
119- ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ], NDArray [np .int32 ],]:
115+ def filter ( # type: ignore - legacy function override does not match
116+ self ,
117+ filter_ : Filter ,
118+ ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ], NDArray [np .int32 ]]:
120119 """
121120 Performs filtering over the internal cache.
122121
@@ -135,12 +134,17 @@ def filter(
135134 Label metadata.
136135 """
137136 with tempfile .TemporaryDirectory () as tmpdir :
138- evaluator = Evaluator ()
139- evaluator . _evaluator = self . _evaluator .filter (
137+ name = "filtered"
138+ _evaluator = super () .filter (
140139 directory = tmpdir ,
141- name = "filtered" ,
140+ name = name ,
142141 filter_expr = filter_ ,
143142 )
143+ evaluator = Evaluator (
144+ name = name ,
145+ directory = tmpdir ,
146+ labels_override = _evaluator ._index_to_label ,
147+ )
144148 detailed_pairs = evaluator ._detailed_pairs
145149 label_metadata = evaluator ._label_metadata
146150 return detailed_pairs , detailed_pairs , label_metadata
@@ -229,7 +233,7 @@ def compute_precision_recall(
229233 """
230234 if filter_ is not None :
231235 with tempfile .TemporaryDirectory () as tmpdir :
232- evaluator = self . _evaluator .filter (
236+ evaluator = super () .filter (
233237 directory = tmpdir ,
234238 name = "filtered" ,
235239 filter_expr = filter_ ,
@@ -238,7 +242,7 @@ def compute_precision_recall(
238242 iou_thresholds = iou_thresholds ,
239243 score_thresholds = score_thresholds ,
240244 )
241- return self . _evaluator .compute_precision_recall (
245+ return super () .compute_precision_recall (
242246 iou_thresholds = iou_thresholds ,
243247 score_thresholds = score_thresholds ,
244248 )
@@ -268,7 +272,7 @@ def compute_confusion_matrix(
268272 """
269273 if filter_ is not None :
270274 with tempfile .TemporaryDirectory () as tmpdir :
271- evaluator = self . _evaluator .filter (
275+ evaluator = super () .filter (
272276 directory = tmpdir ,
273277 name = "filtered" ,
274278 filter_expr = filter_ ,
@@ -278,7 +282,7 @@ def compute_confusion_matrix(
278282 score_thresholds = score_thresholds ,
279283 )
280284 else :
281- metrics = self . _evaluator .compute_confusion_matrix_with_examples (
285+ metrics = super () .compute_confusion_matrix_with_examples (
282286 iou_thresholds = iou_thresholds ,
283287 score_thresholds = score_thresholds ,
284288 )
@@ -328,12 +332,29 @@ class DataLoader(CachedLoader):
328332 Legacy Object Detection DataLoader
329333 """
330334
331- def __init__ (self ):
332- super ().__init__ (
333- batch_size = 1_000 ,
334- rows_per_file = 10_000 ,
335+ def finalize (self ) -> Evaluator : # type: ignore - switching type
336+ evaluator = super ().finalize ()
337+ return Evaluator (
338+ name = evaluator ._name ,
339+ directory = evaluator ._directory ,
335340 )
336341
337- def finalize (self ) -> Evaluator : # type: ignore - switching type
338- _ = super ().finalize ()
339- return Evaluator ()
342+ @classmethod
343+ def filter (
344+ cls ,
345+ directory : str | Path ,
346+ name : str ,
347+ evaluator : CachedEvaluator ,
348+ filter_expr : Filter ,
349+ ) -> Evaluator :
350+ evaluator = super ().filter (
351+ directory = directory ,
352+ name = name ,
353+ evaluator = evaluator ,
354+ filter_expr = filter_expr ,
355+ )
356+ return Evaluator (
357+ directory = evaluator ._directory ,
358+ name = evaluator ._name ,
359+ labels_override = evaluator ._index_to_label ,
360+ )
0 commit comments