1- import heapq
21import json
3- import tempfile
42from collections import defaultdict
53from dataclasses import dataclass
64from pathlib import Path
97import pyarrow as pa
108import pyarrow .compute as pc
119import pyarrow .dataset as ds
12- import pyarrow .parquet as pq
1310from numpy .typing import NDArray
1411
15- from valor_lite .cache import CacheReader , CacheWriter , DataType
12+ from valor_lite .cache import CacheReader , DataType
1613from valor_lite .object_detection .computation import (
1714 compute_average_precision ,
1815 compute_average_recall ,
1916 compute_confusion_matrix ,
2017 compute_counts ,
2118 compute_pair_classifications ,
2219 compute_precision_recall_f1 ,
23- rank_table ,
2420)
2521from valor_lite .object_detection .format import PathFormatter
2622from valor_lite .object_detection .metric import Metric , MetricType
@@ -72,24 +68,30 @@ def __init__(
7268 number_of_groundtruths_per_label
7369 )
7470
@property
def path(self) -> Path:
    """Return the cache path held by this evaluator (`self._path`)."""
    return self._path
74+
@property
def detailed(self) -> CacheReader:
    """Return the reader for the detailed cache (`self._detailed_cache`)."""
    return self._detailed_cache
78+
@property
def ranked(self) -> CacheReader:
    """Return the reader for the ranked cache (`self._ranked_cache`)."""
    return self._ranked_cache
82+
@property
def info(self) -> EvaluatorInfo:
    """Return the evaluator metadata object (`self._info`)."""
    return self._info
86+
7587 @classmethod
76- def create (
88+ def load (
7789 cls ,
7890 path : str | Path ,
79- batch_size : int = 1_000 ,
8091 index_to_label_override : dict [int , str ] | None = None ,
8192 ):
82- """
83- Create a ranked pair cache.
84-
85- Parameters
86- ----------
87- path : str | Path
88- Where to store the evaluator cache.
89- batch_size : int, default=1_000
90- Sets the batch size for reading. Defaults to 1_000.
91- """
9293 detailed_cache = CacheReader (cls ._generate_detailed_cache_path (path ))
94+ ranked_cache = CacheReader (cls ._generate_ranked_cache_path (path ))
9395
9496 # build evaluator meta
9597 (
@@ -108,119 +110,6 @@ def create(
108110 ]
109111 info .prediction_metadata_types = types ["prediction_metadata_types" ]
110112
111- # create ranked cache schema
112- annotation_metadata_keys = {
113- * (
114- set (info .groundtruth_metadata_types .keys ())
115- if info .groundtruth_metadata_types
116- else {}
117- ),
118- * (
119- set (info .prediction_metadata_types .keys ())
120- if info .prediction_metadata_types
121- else {}
122- ),
123- }
124- pruned_schema = pa .schema (
125- [
126- field
127- for field in detailed_cache .schema
128- if field .name not in annotation_metadata_keys
129- ]
130- )
131- ranked_schema = pruned_schema .append (
132- pa .field ("iou_prev" , pa .float64 ())
133- )
134- ranked_schema = ranked_schema .append (
135- pa .field ("high_score" , pa .bool_ ())
136- )
137-
138- n_labels = len (index_to_label )
139-
140- with CacheWriter .create (
141- path = cls ._generate_ranked_cache_path (path ),
142- schema = ranked_schema ,
143- batch_size = detailed_cache .batch_size ,
144- rows_per_file = detailed_cache .rows_per_file ,
145- compression = detailed_cache .compression ,
146- ) as ranked_cache :
147- if detailed_cache .num_dataset_files == 1 :
148- pf = pq .ParquetFile (detailed_cache .dataset_files [0 ])
149- tbl = pf .read ()
150- ranked_tbl = rank_table (tbl , n_labels )
151- ranked_cache .write_table (ranked_tbl )
152- else :
153- pruned_detailed_columns = [
154- field .name for field in pruned_schema
155- ]
156- with tempfile .TemporaryDirectory () as tmpdir :
157-
158- # rank individual files
159- tmpfiles = []
160- for idx , fragment in enumerate (
161- detailed_cache .dataset .get_fragments ()
162- ):
163- fragment_path = Path (tmpdir ) / f"{ idx :06d} .parquet"
164- tbl = fragment .to_table (
165- columns = pruned_detailed_columns
166- )
167- ranked_tbl = rank_table (tbl , n_labels )
168- pq .write_table (ranked_tbl , fragment_path )
169- tmpfiles .append (fragment_path )
170-
171- def generate_heap_item (batches , batch_idx , row_idx ):
172- score = batches [batch_idx ]["score" ][row_idx ].as_py ()
173- iou = batches [batch_idx ]["iou" ][row_idx ].as_py ()
174- return (
175- - score ,
176- - iou ,
177- batch_idx ,
178- row_idx ,
179- )
180-
181- # merge sorted rows
182- heap = []
183- batch_iterators = []
184- batches = []
185- for batch_idx , batch_path in enumerate (tmpfiles ):
186- pf = pq .ParquetFile (batch_path )
187- batch_iter = pf .iter_batches (batch_size = batch_size )
188- batch_iterators .append (batch_iter )
189- batches .append (next (batch_iterators [batch_idx ], None ))
190- if (
191- batches [batch_idx ] is not None
192- and len (batches [batch_idx ]) > 0
193- ):
194- heapq .heappush (
195- heap , generate_heap_item (batches , batch_idx , 0 )
196- )
197-
198- while heap :
199- _ , _ , batch_idx , row_idx = heapq .heappop (heap )
200- row_table = batches [batch_idx ].slice (row_idx , 1 )
201- ranked_cache .write_batch (row_table )
202- row_idx += 1
203- if row_idx < len (batches [batch_idx ]):
204- heapq .heappush (
205- heap ,
206- generate_heap_item (
207- batches , batch_idx , row_idx
208- ),
209- )
210- else :
211- batches [batch_idx ] = next (
212- batch_iterators [batch_idx ], None
213- )
214- if (
215- batches [batch_idx ] is not None
216- and len (batches [batch_idx ]) > 0
217- ):
218- heapq .heappush (
219- heap ,
220- generate_heap_item (batches , batch_idx , 0 ),
221- )
222-
223- ranked_cache = CacheReader (cls ._generate_ranked_cache_path (path ))
224113 return cls (
225114 path = path ,
226115 detailed_cache = detailed_cache ,
@@ -230,56 +119,101 @@ def generate_heap_item(batches, batch_idx, row_idx):
230119 number_of_groundtruths_per_label = number_of_groundtruths_per_label ,
231120 )
232121
233- @classmethod
234- def load (
235- cls ,
def filter(
    self,
    path: str | Path,
    filter_expr: Filter,
    batch_size: int = 1_000,
) -> "Evaluator":
    """
    Filter evaluator cache.

    Every fragment of the detailed cache is read with the datum filter
    applied. Ground truths and predictions that fail their respective
    filters are blanked out in the surviving pairs, pairs left with no
    valid annotation are dropped, and the result is written to a new cache
    at `path`.

    Parameters
    ----------
    path : str | Path
        Where to store the filtered cache.
    filter_expr : Filter
        An object containing filter expressions. Its `datums`,
        `groundtruths` and `predictions` attributes are used; the latter
        two may be None to disable that filter.
    batch_size : int, default=1_000
        The maximum number of rows read into memory per file. Forwarded to
        `Loader.finalize`.

    Returns
    -------
    Evaluator
        A new evaluator object containing the filtered cache.
    """
    from valor_lite.object_detection.loader import Loader

    def rows_in(rows, allowed):
        """Return a boolean mask flagging rows of `rows` present in `allowed`."""
        n_rows = rows.shape[0]
        n_allowed = allowed.shape[0]
        if n_rows == 0 or n_allowed == 0:
            return np.zeros(n_rows, dtype=np.bool_)
        # Give every distinct row one integer label, then compare labels.
        # This replaces a Python loop over unique rows that rescanned all
        # pairs on each iteration.
        _, labels = np.unique(
            np.concatenate((allowed, rows), axis=0),
            axis=0,
            return_inverse=True,
        )
        # The inverse is not guaranteed to be 1-D on every NumPy release.
        labels = labels.reshape(-1)
        return np.isin(labels[n_allowed:], labels[:n_allowed])

    # Column order matters: the integer indices used below refer to it.
    columns = (
        "datum_id",  # 0
        "gt_id",  # 1
        "pd_id",  # 2
        "gt_label_id",  # 3
        "pd_label_id",  # 4
        "iou",  # 5
        "score",  # 6
    )

    loader = Loader.create(
        path=path,
        batch_size=self.detailed.batch_size,
        rows_per_file=self.detailed.rows_per_file,
        compression=self.detailed.compression,
        datum_metadata_types=self.info.datum_metadata_types,
        groundtruth_metadata_types=self.info.groundtruth_metadata_types,
        prediction_metadata_types=self.info.prediction_metadata_types,
    )
    for fragment in self.detailed.dataset.get_fragments():
        tbl = fragment.to_table(filter=filter_expr.datums)

        # NOTE(review): stacking id columns with iou/score presumably yields
        # one float array, so ids round-trip through floats and are written
        # back as such below - confirm the cache writer tolerates this.
        pairs = np.column_stack([tbl[col].to_numpy() for col in columns])

        n_pairs = pairs.shape[0]
        gt_ids = pairs[:, (0, 1)].astype(np.int64)
        pd_ids = pairs[:, (0, 2)].astype(np.int64)

        if filter_expr.groundtruths is not None:
            gt_tbl = tbl.filter(filter_expr.groundtruths)
            kept_gt_ids = np.column_stack(
                [gt_tbl[col].to_numpy() for col in ("datum_id", "gt_id")]
            ).astype(np.int64)
            mask_valid_gt = rows_in(gt_ids, kept_gt_ids)
        else:
            mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)

        if filter_expr.predictions is not None:
            pd_tbl = tbl.filter(filter_expr.predictions)
            kept_pd_ids = np.column_stack(
                [pd_tbl[col].to_numpy() for col in ("datum_id", "pd_id")]
            ).astype(np.int64)
            mask_valid_pd = rows_in(pd_ids, kept_pd_ids)
        else:
            mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)

        # A pair survives if at least one of its two annotations is valid.
        mask_valid = mask_valid_gt | mask_valid_pd

        # Blank the ground-truth side (id, label) where it was filtered out.
        pairs[np.ix_(~mask_valid_gt, (1, 3))] = -1.0  # type: ignore - numpy ix_
        # Blank the prediction side (id, label, score) where it was filtered out.
        pairs[np.ix_(~mask_valid_pd, (2, 4, 6))] = -1.0  # type: ignore - numpy ix_
        # A pair missing either side has its IOU reset to zero.
        pairs[~mask_valid_pd | ~mask_valid_gt, 5] = 0.0

        for idx, col in enumerate(columns):
            tbl = tbl.set_column(
                tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
            )

        # Drop rows with no valid annotation, or whose gt and pd ids are
        # both negative.
        mask_invalid = ~mask_valid | (pairs[:, (1, 2)] < 0).all(axis=1)
        filtered_tbl = tbl.filter(pa.array(~mask_invalid))
        loader._cache.write_table(filtered_tbl)

    return loader.finalize(
        batch_size=batch_size,
        index_to_label_override=self._index_to_label,
    )
283217
284218 @staticmethod
285219 def generate_meta (
@@ -406,34 +340,6 @@ def iterate_pairs_with_table(
406340 [tbl [col ].to_numpy () for col in columns ]
407341 )
408342
409- def filter (
410- self ,
411- path : str | Path ,
412- filter_expr : Filter ,
413- ) -> "Evaluator" :
414- """
415- Filter evaluator cache.
416-
417- Parameters
418- ----------
419- path : str | Path
420- Where to store the filtered cache.
421- filter_expr : Filter
422- An object containing filter expressions.
423-
424- Returns
425- -------
426- Evaluator
427- A new evaluator object containing the filtered cache.
428- """
429- from valor_lite .object_detection .loader import Loader
430-
431- return Loader .filter (
432- path = path ,
433- evaluator = self ,
434- filter_expr = filter_expr ,
435- )
436-
437343 def compute_precision_recall (
438344 self ,
439345 iou_thresholds : list [float ],
0 commit comments