Skip to content

Commit 79ee1ba

Browse files
committed
wip
1 parent be71fef commit 79ee1ba

2 files changed

Lines changed: 612 additions & 0 deletions

File tree

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import json
2+
from dataclasses import dataclass
3+
from pathlib import Path
4+
5+
import numpy as np
6+
import pyarrow.compute as pc
7+
import pyarrow.dataset as ds
8+
from numpy.typing import NDArray
9+
10+
from valor_lite.cache import DataType
11+
from valor_lite.classification.computation import compute_metrics
12+
from valor_lite.classification.metric import Metric, MetricType
13+
from valor_lite.classification.utilities import (
14+
unpack_precision_recall_iou_into_metric_lists,
15+
)
16+
17+
18+
@dataclass
19+
class EvaluatorInfo:
20+
number_of_rows: int = 0
21+
number_of_datums: int = 0
22+
number_of_labels: int = 0
23+
number_of_pixels: int = 0
24+
number_of_groundtruth_pixels: int = 0
25+
number_of_prediction_pixels: int = 0
26+
datum_metadata_types: dict[str, DataType] | None = None
27+
groundtruth_metadata_types: dict[str, DataType] | None = None
28+
prediction_metadata_types: dict[str, DataType] | None = None
29+
30+
31+
@dataclass
32+
class Filter:
33+
datums: pc.Expression | None = None
34+
groundtruths: pc.Expression | None = None
35+
predictions: pc.Expression | None = None
36+
37+
38+
class Evaluator:
39+
def __init__(
40+
self,
41+
name: str = "default",
42+
directory: str | Path = ".valor",
43+
labels_override: dict[int, str] | None = None,
44+
):
45+
self._directory = Path(directory)
46+
self._name = name
47+
self._path = self._directory / name
48+
self._cache_path = self._path / "counts"
49+
self._metadata_path = self._path / "metadata.json"
50+
51+
# link cache
52+
self._dataset = ds.dataset(self._cache_path, format="parquet")
53+
54+
# build evaluator meta
55+
(
56+
self._index_to_label,
57+
self._confusion_matrix,
58+
self._info,
59+
) = self.generate_meta(self._dataset, labels_override)
60+
61+
# read config
62+
with open(self._metadata_path, "r") as f:
63+
types = json.load(f)
64+
self._info.datum_metadata_types = types["datum"]
65+
self._info.groundtruth_metadata_types = types["groundtruth"]
66+
self._info.prediction_metadata_types = types["prediction"]
67+
with open(self._cache_path / ".cfg", "r") as f:
68+
cfg = json.load(f)
69+
self._detailed_batch_size = cfg["batch_size"]
70+
self._detailed_rows_per_file = cfg["rows_per_file"]
71+
self._detailed_compression = cfg["compression"]
72+
73+
@property
74+
def dataset(self) -> ds.Dataset:
75+
return self._dataset
76+
77+
@property
78+
def info(self) -> EvaluatorInfo:
79+
return self._info
80+
81+
@staticmethod
82+
def generate_meta(
83+
dataset: ds.Dataset,
84+
labels_override: dict[int, str] | None,
85+
) -> tuple[dict[int, str], NDArray[np.uint64], EvaluatorInfo]:
86+
"""
87+
Generate cache statistics.
88+
89+
Parameters
90+
----------
91+
dataset : Dataset
92+
Valor cache.
93+
labels_override : dict[int, str], optional
94+
Optional labels override. Use when operating over filtered data.
95+
96+
Returns
97+
-------
98+
labels : dict[int, str]
99+
Mapping of label ID's to label values.
100+
confusion_matrix : NDArray[np.uint64]
101+
Array of size (n_labels + 1, n_labels + 1) containing pair counts.
102+
info : EvaluatorInfo
103+
Evaluator cache details.
104+
"""
105+
labels = labels_override if labels_override else {}
106+
info = EvaluatorInfo()
107+
108+
for fragment in dataset.get_fragments():
109+
tbl = fragment.to_table()
110+
columns = (
111+
"datum_id",
112+
"gt_label_id",
113+
"pd_label_id",
114+
"count",
115+
)
116+
ids = np.column_stack(
117+
[tbl[col].to_numpy() for col in columns]
118+
).astype(np.int64)
119+
120+
# count number of rows
121+
info.number_of_rows += int(tbl.shape[0])
122+
123+
# count unique datums
124+
datum_ids = np.unique(ids[:, 0])
125+
info.number_of_datums += int(datum_ids.size)
126+
127+
# get gt labels
128+
gt_label_ids = ids[:, 1]
129+
gt_label_ids, gt_indices = np.unique(
130+
gt_label_ids, return_index=True
131+
)
132+
gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
133+
gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
134+
gt_labels.pop(-1, None)
135+
labels.update(gt_labels)
136+
137+
# get pd labels
138+
pd_label_ids = ids[:, 2]
139+
pd_label_ids, pd_indices, pd_counts = np.unique(
140+
pd_label_ids, return_index=True, return_counts=True
141+
)
142+
pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
143+
pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
144+
pd_labels.pop(-1, None)
145+
labels.update(pd_labels)
146+
147+
# post-process
148+
labels.pop(-1, None)
149+
150+
# create confusion matrix
151+
n_labels = len(labels)
152+
matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.uint64)
153+
for fragment in dataset.get_fragments():
154+
tbl = fragment.to_table()
155+
columns = (
156+
"datum_id",
157+
"gt_label_id",
158+
"pd_label_id",
159+
)
160+
ids = np.column_stack(
161+
[tbl[col].to_numpy() for col in columns]
162+
).astype(np.int64)
163+
counts = tbl["count"].to_numpy()
164+
165+
mask_null_gts = ids[:, 1] == -1
166+
mask_null_pds = ids[:, 2] == -1
167+
matrix[0, 0] = counts[mask_null_gts & mask_null_pds].sum()
168+
for idx in range(n_labels):
169+
mask_gts = ids[:, 1] == idx
170+
for pidx in range(n_labels):
171+
mask_pds = ids[:, 2] == pidx
172+
matrix[idx + 1, pidx + 1] = counts[
173+
mask_gts & mask_pds
174+
].sum()
175+
176+
mask_unmatched_gts = mask_gts & mask_null_pds
177+
matrix[idx + 1, 0] = counts[mask_unmatched_gts].sum()
178+
mask_unmatched_pds = mask_null_gts & (ids[:, 2] == idx)
179+
matrix[0, idx + 1] = counts[mask_unmatched_pds].sum()
180+
181+
# complete info object
182+
info.number_of_labels = len(labels)
183+
info.number_of_pixels = matrix.sum()
184+
info.number_of_groundtruth_pixels = matrix[1:, :].sum()
185+
info.number_of_prediction_pixels = matrix[:, 1:].sum()
186+
187+
return labels, matrix, info
188+
189+
def filter(
190+
self,
191+
filter_expr: Filter,
192+
name: str | None = None,
193+
directory: str | Path | None = None,
194+
) -> "Evaluator":
195+
"""
196+
Filter evaluator cache.
197+
198+
Parameters
199+
----------
200+
filter_expr : Filter
201+
An object containing filter expressions.
202+
name : str, optional
203+
Filtered cache name.
204+
directory : str | Path, optional
205+
The directory to store the filtered cache.
206+
207+
Returns
208+
-------
209+
Evaluator
210+
A new evaluator object containing the filtered cache.
211+
"""
212+
name = name if name else "filtered"
213+
directory = directory if directory else self._directory
214+
from valor_lite.classification.loader import Loader
215+
216+
return Loader.filter(
217+
name=name,
218+
directory=directory,
219+
evaluator=self,
220+
filter_expr=filter_expr,
221+
)
222+
223+
def compute_precision_recall_iou(self) -> dict[MetricType, list]:
224+
"""
225+
Performs an evaluation and returns metrics.
226+
227+
Returns
228+
-------
229+
dict[MetricType, list]
230+
A dictionary mapping MetricType enumerations to lists of computed metrics.
231+
"""
232+
results = compute_metrics(counts=self._confusion_matrix)
233+
return unpack_precision_recall_iou_into_metric_lists(
234+
results=results,
235+
index_to_label=self._index_to_label,
236+
)
237+
238+
def evaluate(self) -> dict[MetricType, list[Metric]]:
239+
"""
240+
Computes all available metrics.
241+
242+
Returns
243+
-------
244+
dict[MetricType, list[Metric]]
245+
Lists of metrics organized by metric type.
246+
"""
247+
return self.compute_precision_recall_iou()

0 commit comments

Comments
 (0)