11import json
2- from collections import defaultdict
3- from dataclasses import asdict , dataclass
2+ from dataclasses import dataclass
43from pathlib import Path
54
65import numpy as np
6+ import pyarrow .compute as pc
7+ import pyarrow .dataset as ds
78from numpy .typing import NDArray
8- from pyarrow import pa
9- from pyarrow .compute import pc
10- from pyarrow .dataset import ds
11- from tqdm import tqdm
12-
13- from valor_lite .cache import (
14- CacheReader ,
15- DataType ,
16- convert_type_mapping_to_schema ,
17- )
18- from valor_lite .exceptions import EmptyCacheError , EmptyFilterError
19- from valor_lite .semantic_segmentation .annotation import Segmentation
20- from valor_lite .semantic_segmentation .computation import (
21- compute_intermediates ,
22- compute_label_metadata ,
23- compute_metrics ,
24- filter_cache ,
25- )
9+
10+ from valor_lite .cache import DataType
11+ from valor_lite .semantic_segmentation .computation import compute_metrics
2612from valor_lite .semantic_segmentation .metric import Metric , MetricType
2713from valor_lite .semantic_segmentation .utilities import (
2814 unpack_precision_recall_iou_into_metric_lists ,
3117
3218@dataclass
3319class EvaluatorInfo :
20+ number_of_rows : int = 0
3421 number_of_datums : int = 0
3522 number_of_labels : int = 0
3623 number_of_pixels : int = 0
3724 number_of_groundtruth_pixels : int = 0
3825 number_of_prediction_pixels : int = 0
39- number_of_rows : int = 0
4026 datum_metadata_types : dict [str , DataType ] | None = None
4127 groundtruth_metadata_types : dict [str , DataType ] | None = None
4228 prediction_metadata_types : dict [str , DataType ] | None = None
@@ -68,7 +54,7 @@ def __init__(
6854 # build evaluator meta
6955 (
7056 self ._index_to_label ,
71- self ._number_of_groundtruths_per_label ,
57+ self ._confusion_matrix ,
7258 self ._info ,
7359 ) = self .generate_meta (self ._dataset , labels_override )
7460
@@ -111,12 +97,11 @@ def generate_meta(
11197 -------
11298 labels : dict[int, str]
11399 Mapping of label ID's to label values.
114- number_of_groundtruths_per_label : NDArray[np.uint64]
115- Array of size (n_labels, ) containing ground truth counts.
100+ confusion_matrix : NDArray[np.uint64]
101+ Array of size (n_labels + 1, n_labels + 1 ) containing pair counts.
116102 info : EvaluatorInfo
117103 Evaluator cache details.
118104 """
119- gt_counts_per_lbl = defaultdict (int )
120105 labels = labels_override if labels_override else {}
121106 info = EvaluatorInfo ()
122107
@@ -126,7 +111,7 @@ def generate_meta(
126111 "datum_id" ,
127112 "gt_label_id" ,
128113 "pd_label_id" ,
129- "counts " ,
114+ "count " ,
130115 )
131116 ids = np .column_stack (
132117 [tbl [col ].to_numpy () for col in columns ]
@@ -159,21 +144,9 @@ def generate_meta(
159144 pd_labels .pop (- 1 , None )
160145 labels .update (pd_labels )
161146
162- # count gts per label
163- gts = ids [:, 1 ].astype (np .int64 )
164- unique_ann = np .unique (gts [gts [:, 0 ] >= 0 ], axis = 0 )
165- unique_labels , label_counts = np .unique (
166- unique_ann [:, 1 ], return_counts = True
167- )
168- for label_id , count in zip (unique_labels , label_counts ):
169- gt_counts_per_lbl [int (label_id )] += int (count )
170-
171147 # post-process
172148 labels .pop (- 1 , None )
173149
174- # complete info object
175- info .number_of_labels = len (labels )
176-
177150 # create confusion matrix
178151 n_labels = len (labels )
179152 matrix = np .zeros ((n_labels + 1 , n_labels + 1 ), dtype = np .uint64 )
@@ -187,8 +160,11 @@ def generate_meta(
187160 ids = np .column_stack (
188161 [tbl [col ].to_numpy () for col in columns ]
189162 ).astype (np .int64 )
190- counts = tbl ["counts " ].to_numpy ()
163+ counts = tbl ["count " ].to_numpy ()
191164
165+ mask_null_gts = ids [:, 1 ] == - 1
166+ mask_null_pds = ids [:, 2 ] == - 1
167+ matrix [0 , 0 ] = counts [mask_null_gts & mask_null_pds ].sum ()
192168 for idx in range (n_labels ):
193169 mask_gts = ids [:, 1 ] == idx
194170 for pidx in range (n_labels ):
@@ -197,35 +173,18 @@ def generate_meta(
197173 mask_gts & mask_pds
198174 ].sum ()
199175
200- mask_unmatched_gts = mask_gts & ( ids [:, 2 ] == - 1 )
176+ mask_unmatched_gts = mask_gts & mask_null_pds
201177 matrix [idx + 1 , 0 ] = counts [mask_unmatched_gts ].sum ()
202- mask_unmatched_pds = ( ids [:, 1 ] == - 1 ) & (ids [:, 2 ] == idx )
203- matrix [0 , idx + 1 ] = counts [mask_unmatched_pds ]
178+ mask_unmatched_pds = mask_null_gts & (ids [:, 2 ] == idx )
179+ matrix [0 , idx + 1 ] = counts [mask_unmatched_pds ]. sum ()
204180
205- return labels , matrix , info
206-
207- @staticmethod
208- def iterate_pairs (
209- dataset : ds .Dataset ,
210- columns : list [str ] | None = None ,
211- ):
212- for fragment in dataset .get_fragments ():
213- tbl = fragment .to_table (columns = columns )
214- yield np .column_stack (
215- [tbl .column (i ).to_numpy () for i in range (tbl .num_columns )]
216- )
181+ # complete info object
182+ info .number_of_labels = len (labels )
183+ info .number_of_pixels = matrix .sum ()
184+ info .number_of_groundtruth_pixels = matrix [1 :, :].sum ()
185+ info .number_of_prediction_pixels = matrix [:, 1 :].sum ()
217186
218- @staticmethod
219- def iterate_pairs_with_table (
220- dataset : ds .Dataset ,
221- columns : list [str ] | None = None ,
222- ):
223- for fragment in dataset .get_fragments ():
224- tbl = fragment .to_table ()
225- columns = columns if columns else tbl .columns
226- yield tbl , np .column_stack (
227- [tbl [col ].to_numpy () for col in columns ]
228- )
187+ return labels , matrix , info
229188
230189 def filter (
231190 self ,
@@ -260,3 +219,29 @@ def filter(
260219 evaluator = self ,
261220 filter_expr = filter_expr ,
262221 )
222+
def compute_precision_recall_iou(self) -> dict[MetricType, list[Metric]]:
    """
    Performs an evaluation and returns metrics.

    The evaluator's stored confusion matrix is passed to
    `compute_metrics` and the result is unpacked into per-type
    metric lists using the evaluator's index-to-label mapping.

    Returns
    -------
    dict[MetricType, list[Metric]]
        A dictionary mapping MetricType enumerations to lists of computed metrics.
    """
    # The confusion matrix is built once when the evaluator is constructed,
    # so no pass over the underlying dataset is needed here.
    results = compute_metrics(counts=self._confusion_matrix)
    return unpack_precision_recall_iou_into_metric_lists(
        results=results,
        index_to_label=self._index_to_label,
    )
237+
def evaluate(self) -> dict[MetricType, list[Metric]]:
    """
    Computes all available metrics.

    Delegates to `compute_precision_recall_iou`, which is the only
    metric computation invoked by this method.

    Returns
    -------
    dict[MetricType, list[Metric]]
        Lists of metrics organized by metric type.
    """
    return self.compute_precision_recall_iou()
0 commit comments