1+ import json
2+ from collections import defaultdict
13from dataclasses import asdict , dataclass
4+ from pathlib import Path
25
3- import json
46import numpy as np
57from numpy .typing import NDArray
6- from tqdm import tqdm
7- from pathlib import Path
8- from collections import defaultdict
9-
108from pyarrow import pa
119from pyarrow .compute import pc
1210from pyarrow .dataset import ds
11+ from tqdm import tqdm
1312
13+ from valor_lite .cache import (
14+ CacheReader ,
15+ DataType ,
16+ convert_type_mapping_to_schema ,
17+ )
1418from valor_lite .exceptions import EmptyCacheError , EmptyFilterError
1519from valor_lite .semantic_segmentation .annotation import Segmentation
1620from valor_lite .semantic_segmentation .computation import (
2327from valor_lite .semantic_segmentation .utilities import (
2428 unpack_precision_recall_iou_into_metric_lists ,
2529)
26- from valor_lite .cache import CacheReader , DataType , convert_type_mapping_to_schema
2730
2831
2932@dataclass
3033class EvaluatorInfo :
3134 number_of_datums : int = 0
32- number_of_groundtruth_annotations : int = 0
33- number_of_prediction_annotations : int = 0
3435 number_of_labels : int = 0
36+ number_of_pixels : int = 0
37+ number_of_groundtruth_pixels : int = 0
38+ number_of_prediction_pixels : int = 0
3539 number_of_rows : int = 0
3640 datum_metadata_types : dict [str , DataType ] | None = None
3741 groundtruth_metadata_types : dict [str , DataType ] | None = None
@@ -120,10 +124,9 @@ def generate_meta(
120124 tbl = fragment .to_table ()
121125 columns = (
122126 "datum_id" ,
123- "gt_id" ,
124- "pd_id" ,
125127 "gt_label_id" ,
126128 "pd_label_id" ,
129+ "counts" ,
127130 )
128131 ids = np .column_stack (
129132 [tbl [col ].to_numpy () for col in columns ]
@@ -136,18 +139,8 @@ def generate_meta(
136139 datum_ids = np .unique (ids [:, 0 ])
137140 info .number_of_datums += int (datum_ids .size )
138141
139- # count unique groundtruths
140- gt_ids = ids [:, 1 ]
141- gt_ids = np .unique (gt_ids [gt_ids >= 0 ])
142- info .number_of_groundtruth_annotations += int (gt_ids .shape [0 ])
143-
144- # count unique predictions
145- pd_ids = ids [:, 2 ]
146- pd_ids = np .unique (pd_ids [pd_ids >= 0 ])
147- info .number_of_prediction_annotations += int (pd_ids .shape [0 ])
148-
149142 # get gt labels
150- gt_label_ids = ids [:, 3 ]
143+ gt_label_ids = ids [:, 1 ]
151144 gt_label_ids , gt_indices = np .unique (
152145 gt_label_ids , return_index = True
153146 )
@@ -157,17 +150,17 @@ def generate_meta(
157150 labels .update (gt_labels )
158151
159152 # get pd labels
160- pd_label_ids = ids [:, 4 ]
161- pd_label_ids , pd_indices = np .unique (
162- pd_label_ids , return_index = True
153+ pd_label_ids = ids [:, 2 ]
154+ pd_label_ids , pd_indices , pd_counts = np .unique (
155+ pd_label_ids , return_index = True , return_counts = True
163156 )
164157 pd_labels = tbl ["pd_label" ].take (pd_indices ).to_pylist ()
165158 pd_labels = dict (zip (pd_label_ids .astype (int ).tolist (), pd_labels ))
166159 pd_labels .pop (- 1 , None )
167160 labels .update (pd_labels )
168161
169162 # count gts per label
170- gts = ids [:, ( 1 , 3 ) ].astype (np .int64 )
163+ gts = ids [:, 1 ].astype (np .int64 )
171164 unique_ann = np .unique (gts [gts [:, 0 ] >= 0 ], axis = 0 )
172165 unique_labels , label_counts = np .unique (
173166 unique_ann [:, 1 ], return_counts = True
@@ -181,14 +174,35 @@ def generate_meta(
181174 # complete info object
182175 info .number_of_labels = len (labels )
183176
184- # convert gt counts to numpy
185- number_of_groundtruths_per_label = np .zeros (
186- len (labels ), dtype = np .uint64
187- )
188- for k , v in gt_counts_per_lbl .items ():
189- number_of_groundtruths_per_label [int (k )] = v
177+ # create confusion matrix
178+ n_labels = len (labels )
179+ matrix = np .zeros ((n_labels + 1 , n_labels + 1 ), dtype = np .uint64 )
180+ for fragment in dataset .get_fragments ():
181+ tbl = fragment .to_table ()
182+ columns = (
183+ "datum_id" ,
184+ "gt_label_id" ,
185+ "pd_label_id" ,
186+ )
187+ ids = np .column_stack (
188+ [tbl [col ].to_numpy () for col in columns ]
189+ ).astype (np .int64 )
190+ counts = tbl ["counts" ].to_numpy ()
191+
192+ for idx in range (n_labels ):
193+ mask_gts = ids [:, 1 ] == idx
194+ for pidx in range (n_labels ):
195+ mask_pds = ids [:, 2 ] == pidx
196+ matrix [idx + 1 , pidx + 1 ] = counts [
197+ mask_gts & mask_pds
198+ ].sum ()
199+
200+ mask_unmatched_gts = mask_gts & (ids [:, 2 ] == - 1 )
201+ matrix [idx + 1 , 0 ] = counts [mask_unmatched_gts ].sum ()
202+ mask_unmatched_pds = (ids [:, 1 ] == - 1 ) & (ids [:, 2 ] == idx )
203+ matrix [0 , idx + 1 ] = counts [mask_unmatched_pds ]
190204
191- return labels , number_of_groundtruths_per_label , info
205+ return labels , matrix , info
192206
193207 @staticmethod
194208 def iterate_pairs (
@@ -241,8 +255,8 @@ def filter(
241255 from valor_lite .semantic_segmentation .loader import Loader
242256
243257 return Loader .filter (
244- directory = directory ,
245258 name = name ,
259+ directory = directory ,
246260 evaluator = self ,
247261 filter_expr = filter_expr ,
248- )
262+ )
0 commit comments