33from pathlib import Path
44
55import numpy as np
6+ import pyarrow as pa
67import pyarrow .compute as pc
78import pyarrow .dataset as ds
89from numpy .typing import NDArray
910
10- from valor_lite .cache import DataType
11+ from valor_lite .cache import CacheReader , DataType
12+ from valor_lite .exceptions import EmptyCacheError
1113from valor_lite .semantic_segmentation .computation import compute_metrics
14+ from valor_lite .semantic_segmentation .format import PathFormatter
1215from valor_lite .semantic_segmentation .metric import Metric , MetricType
1316from valor_lite .semantic_segmentation .utilities import (
1417 unpack_precision_recall_iou_into_metric_lists ,
@@ -35,44 +38,171 @@ class Filter:
3538 predictions : pc .Expression | None = None
3639
3740
38- class Evaluator :
41+ class Evaluator ( PathFormatter ) :
def __init__(
    self,
    path: str | Path,
    cache: CacheReader,
    info: EvaluatorInfo,
    index_to_label: dict[int, str],
    confusion_matrix: NDArray[np.uint64],
):
    """
    Wrap already-loaded evaluator state.

    The constructor performs no I/O; it only stores the objects it is
    given. Use `Evaluator.load` to build an instance from a directory.

    Parameters
    ----------
    path : str | Path
        Location of the evaluator on disk; stored as a `Path`.
    cache : CacheReader
        Reader over the evaluator cache.
    info : EvaluatorInfo
        Evaluator summary information.
    index_to_label : dict[int, str]
        Mapping of label index to label string.
    confusion_matrix : NDArray[np.uint64]
        Confusion matrix associated with the cache.
    """
    self._path = Path(path)
    self._cache = cache
    self._info = info
    self._index_to_label = index_to_label
    self._confusion_matrix = confusion_matrix
55+
@classmethod
def load(
    cls,
    path: str | Path,
    index_to_label_override: dict[int, str] | None = None,
) -> "Evaluator":
    """
    Load an evaluator from a directory on disk.

    Parameters
    ----------
    path : str | Path
        Directory containing the evaluator cache and metadata file.
    index_to_label_override : dict[int, str], optional
        Passed through to `generate_meta` together with the cache dataset.

    Returns
    -------
    Evaluator
        An evaluator bound to the cache found at `path`.

    Raises
    ------
    FileNotFoundError
        If `path` does not exist.
    NotADirectoryError
        If `path` exists but is not a directory.
    """
    # validate path
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Directory does not exist: {path}")
    if not path.is_dir():
        raise NotADirectoryError(
            f"Path exists but is not a directory: {path}"
        )

    # load cache
    cache = CacheReader.load(cls._generate_cache_path(path))

    # build evaluator meta
    (
        index_to_label,
        confusion_matrix,
        info,
    ) = cls.generate_meta(cache.dataset, index_to_label_override)

    # read config: the metadata file supplies the metadata type mappings
    # that are attached to the freshly generated info object
    metadata_path = cls._generate_metadata_path(path)
    with open(metadata_path, "r") as f:
        types = json.load(f)
    info.datum_metadata_types = types["datum_metadata_types"]
    info.groundtruth_metadata_types = types["groundtruth_metadata_types"]
    info.prediction_metadata_types = types["prediction_metadata_types"]

    return cls(
        path=path,
        cache=cache,
        info=info,
        index_to_label=index_to_label,
        confusion_matrix=confusion_matrix,
    )
98+
def filter(
    self,
    path: str | Path,
    filter_expr: Filter,
) -> "Evaluator":
    """
    Filter evaluator cache.

    Each cache fragment is read with the datum filter applied. Rows whose
    (datum_id, label_id) pair is rejected by the ground truth or
    prediction filter keep their row but have the corresponding label id
    replaced with -1.

    Parameters
    ----------
    path : str | Path
        Where to store the filtered cache.
    filter_expr : Filter
        An object containing filter expressions.

    Returns
    -------
    Evaluator
        A new evaluator object containing the filtered cache.

    Raises
    ------
    EmptyCacheError
        If the filtered cache contains no rows.
    """
    # imported lazily, as in the original; presumably avoids a circular
    # import with the loader module - TODO confirm
    from valor_lite.semantic_segmentation.loader import Loader

    columns = (
        "datum_id",
        "gt_label_id",
        "pd_label_id",
    )

    def _valid_rows(tbl, expr, label_col, row_keys):
        """Mark rows whose (datum_id, label_col) key survives `expr`."""
        n_rows = row_keys.shape[0]
        if expr is None:
            return np.ones(n_rows, dtype=np.bool_)
        kept_tbl = tbl.filter(expr)
        kept_keys = np.column_stack(
            [kept_tbl[col].to_numpy() for col in ("datum_id", label_col)]
        ).astype(np.int64)
        mask = np.zeros(n_rows, dtype=np.bool_)
        for key in np.unique(kept_keys, axis=0):
            mask |= (row_keys == key).all(axis=1)
        return mask

    loader = Loader.create(
        path=path,
        batch_size=self.cache.batch_size,
        rows_per_file=self.cache.rows_per_file,
        compression=self.cache.compression,
        datum_metadata_types=self.info.datum_metadata_types,
        groundtruth_metadata_types=self.info.groundtruth_metadata_types,
        prediction_metadata_types=self.info.prediction_metadata_types,
    )
    for fragment in self.cache.dataset.get_fragments():
        tbl = fragment.to_table(filter=filter_expr.datums)

        pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
        gt_ids = pairs[:, (0, 1)].astype(np.int64)
        pd_ids = pairs[:, (0, 2)].astype(np.int64)

        mask_valid_gt = _valid_rows(
            tbl, filter_expr.groundtruths, "gt_label_id", gt_ids
        )
        mask_valid_pd = _valid_rows(
            tbl, filter_expr.predictions, "pd_label_id", pd_ids
        )

        # rejected labels are masked with -1 rather than dropping the row
        pairs[~mask_valid_gt, 1] = -1
        pairs[~mask_valid_pd, 2] = -1

        for idx, col in enumerate(columns):
            tbl = tbl.set_column(
                tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
            )
        loader._cache.write_table(tbl)

    loader._cache.flush()
    if loader._cache.dataset.count_rows() == 0:
        raise EmptyCacheError()

    return loader.finalize()
190+
def delete(self) -> None:
    """
    Delete evaluator cache.

    Delegates to `Loader.delete` with this evaluator's path.
    """
    # imported lazily, as in the original; presumably avoids a circular
    # import with the loader module - TODO confirm
    from valor_lite.semantic_segmentation.loader import Loader

    Loader.delete(self.path)
198+
@property
def path(self) -> Path:
    """Path this evaluator was constructed with."""
    return self._path
72202
@property
def cache(self) -> CacheReader:
    """Cache reader this evaluator was constructed with."""
    return self._cache
76206
77207 @property
78208 def info (self ) -> EvaluatorInfo :
@@ -185,105 +315,6 @@ def generate_meta(
185315
186316 return labels , matrix , info
187317
188- def filter (
189- self ,
190- filter_expr : Filter ,
191- name : str | None = None ,
192- directory : str | Path | None = None ,
193- ) -> "Evaluator" :
194- """
195- Filter evaluator cache.
196-
197- Parameters
198- ----------
199- filter_expr : Filter
200- An object containing filter expressions.
201- name : str, optional
202- Filtered cache name.
203- directory : str | Path, optional
204- The directory to store the filtered cache.
205-
206- Returns
207- -------
208- Evaluator
209- A new evaluator object containing the filtered cache.
210- """
211- loader = cls (
212- directory = directory ,
213- name = name ,
214- batch_size = evaluator ._detailed_batch_size ,
215- rows_per_file = evaluator ._detailed_rows_per_file ,
216- compression = evaluator ._detailed_compression ,
217- datum_metadata_types = evaluator .info .datum_metadata_types ,
218- groundtruth_metadata_types = evaluator .info .groundtruth_metadata_types ,
219- prediction_metadata_types = evaluator .info .prediction_metadata_types ,
220- )
221- for fragment in evaluator .dataset .get_fragments ():
222- tbl = fragment .to_table (filter = filter_expr .datums )
223-
224- columns = (
225- "datum_id" ,
226- "gt_label_id" ,
227- "pd_label_id" ,
228- )
229- pairs = np .column_stack ([tbl [col ].to_numpy () for col in columns ])
230-
231- n_pairs = pairs .shape [0 ]
232- gt_ids = pairs [:, (0 , 1 )].astype (np .int64 )
233- pd_ids = pairs [:, (0 , 2 )].astype (np .int64 )
234-
235- if filter_expr .groundtruths is not None :
236- mask_valid_gt = np .zeros (n_pairs , dtype = np .bool_ )
237- gt_tbl = tbl .filter (filter_expr .groundtruths )
238- gt_pairs = np .column_stack (
239- [
240- gt_tbl [col ].to_numpy ()
241- for col in ("datum_id" , "gt_label_id" )
242- ]
243- ).astype (np .int64 )
244- for gt in np .unique (gt_pairs , axis = 0 ):
245- mask_valid_gt |= (gt_ids == gt ).all (axis = 1 )
246- else :
247- mask_valid_gt = np .ones (n_pairs , dtype = np .bool_ )
248-
249- if filter_expr .predictions is not None :
250- mask_valid_pd = np .zeros (n_pairs , dtype = np .bool_ )
251- pd_tbl = tbl .filter (filter_expr .predictions )
252- pd_pairs = np .column_stack (
253- [
254- pd_tbl [col ].to_numpy ()
255- for col in ("datum_id" , "pd_label_id" )
256- ]
257- ).astype (np .int64 )
258- for pd in np .unique (pd_pairs , axis = 0 ):
259- mask_valid_pd |= (pd_ids == pd ).all (axis = 1 )
260- else :
261- mask_valid_pd = np .ones (n_pairs , dtype = np .bool_ )
262-
263- mask_valid = mask_valid_gt | mask_valid_pd
264- mask_valid_gt &= mask_valid
265- mask_valid_pd &= mask_valid
266-
267- pairs [~ mask_valid_gt , 1 ] = - 1
268- pairs [~ mask_valid_pd , 2 ] = - 1
269-
270- for idx , col in enumerate (columns ):
271- tbl = tbl .set_column (
272- tbl .schema .names .index (col ), col , pa .array (pairs [:, idx ])
273- )
274- loader ._cache .write_table (tbl )
275-
276- loader ._cache .flush ()
277- if loader ._cache .dataset .count_rows () == 0 :
278- raise EmptyCacheError ()
279-
280- evaluator = Evaluator (
281- directory = loader ._directory ,
282- name = loader ._name ,
283- labels_override = evaluator ._index_to_label ,
284- )
285- return evaluator
286-
287318 def compute_precision_recall_iou (self ) -> dict [MetricType , list ]:
288319 """
289320 Performs an evaluation and returns metrics.
0 commit comments