Skip to content

Commit a576630

Browse files
committed
loader wip
1 parent 886b185 commit a576630

6 files changed

Lines changed: 360 additions & 39 deletions

File tree

src/valor_lite/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
class EmptyEvaluatorError(Exception):
1+
class EmptyCacheError(Exception):
22
def __init__(self):
33
super().__init__(
44
"evaluator cannot be finalized as it contains no data"

src/valor_lite/semantic_segmentation/computation.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ def filter_cache(
8282
return confusion_matrices, label_metadata
8383

8484

85-
def compute_intermediate_confusion_matrices(
85+
def compute_intermediates(
8686
groundtruths: NDArray[np.bool_],
8787
predictions: NDArray[np.bool_],
8888
groundtruth_labels: NDArray[np.int64],
8989
prediction_labels: NDArray[np.int64],
9090
n_labels: int,
91-
) -> NDArray[np.int64]:
91+
) -> NDArray[np.uint64]:
9292
"""
9393
Computes an intermediate confusion matrix containing label counts.
9494
@@ -107,7 +107,7 @@ def compute_intermediate_confusion_matrices(
107107
108108
Returns
109109
-------
110-
NDArray[np.int64]
110+
NDArray[np.uint64]
111111
A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
112112
"""
113113

@@ -125,7 +125,7 @@ def compute_intermediate_confusion_matrices(
125125
intersected_groundtruth_counts = intersection_counts.sum(axis=1)
126126
intersected_prediction_counts = intersection_counts.sum(axis=0)
127127

128-
confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.int64)
128+
confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.uint64)
129129
confusion_matrix[0, 0] = background_counts
130130
confusion_matrix[
131131
np.ix_(groundtruth_labels + 1, prediction_labels + 1)
@@ -136,7 +136,6 @@ def compute_intermediate_confusion_matrices(
136136
confusion_matrix[groundtruth_labels + 1, 0] = (
137137
groundtruth_counts - intersected_groundtruth_counts
138138
)
139-
140139
return confusion_matrix
141140

142141

Lines changed: 247 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,248 @@
1+
from dataclasses import asdict, dataclass
2+
3+
import json
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
from tqdm import tqdm
7+
from pathlib import Path
8+
from collections import defaultdict
9+
10+
from pyarrow import pa
11+
from pyarrow.compute import pc
12+
from pyarrow.dataset import ds
13+
14+
from valor_lite.exceptions import EmptyCacheError, EmptyFilterError
15+
from valor_lite.semantic_segmentation.annotation import Segmentation
16+
from valor_lite.semantic_segmentation.computation import (
17+
compute_intermediates,
18+
compute_label_metadata,
19+
compute_metrics,
20+
filter_cache,
21+
)
22+
from valor_lite.semantic_segmentation.metric import Metric, MetricType
23+
from valor_lite.semantic_segmentation.utilities import (
24+
unpack_precision_recall_iou_into_metric_lists,
25+
)
26+
from valor_lite.cache import CacheReader, DataType, convert_type_mapping_to_schema
27+
28+
29+
@dataclass
class EvaluatorInfo:
    """
    Summary statistics and metadata type mappings for an evaluator cache.

    All counters default to zero and all type mappings default to None, so a
    bare ``EvaluatorInfo()`` can be filled in incrementally.
    """

    # Counters describing the contents of the cache.
    number_of_datums: int = 0
    number_of_groundtruth_annotations: int = 0
    number_of_prediction_annotations: int = 0
    number_of_labels: int = 0
    number_of_rows: int = 0

    # Metadata field name -> data type, one mapping per entity kind.
    # NOTE(review): presumably None means "not loaded yet" -- confirm.
    datum_metadata_types: dict[str, DataType] | None = None
    groundtruth_metadata_types: dict[str, DataType] | None = None
    prediction_metadata_types: dict[str, DataType] | None = None
39+
40+
41+
@dataclass
class Filter:
    """
    Container for optional filter expressions.

    Each attribute holds either a pyarrow compute expression or None.
    NOTE(review): presumably None means "do not filter on this entity kind"
    -- confirm against the code that consumes this object.
    """

    datums: pc.Expression | None = None
    groundtruths: pc.Expression | None = None
    predictions: pc.Expression | None = None
46+
47+
148
class Evaluator:
    """
    Evaluator backed by a parquet cache on disk.

    The cache is expected at ``<directory>/<name>/counts`` with a sibling
    ``<directory>/<name>/metadata.json`` file and a ``.cfg`` file inside the
    ``counts`` directory.
    """

    def __init__(
        self,
        name: str = "default",
        directory: str | Path = ".valor",
        labels_override: dict[int, str] | None = None,
    ):
        """
        Open an existing cache and compute its summary statistics.

        Parameters
        ----------
        name : str, default="default"
            Name of the cache (sub-directory of `directory`).
        directory : str | Path, default=".valor"
            Directory containing the named cache.
        labels_override : dict[int, str], optional
            Optional labels override forwarded to `generate_meta`.
        """
        self._directory = Path(directory)
        self._name = name
        self._path = self._directory / name
        self._cache_path = self._path / "counts"
        self._metadata_path = self._path / "metadata.json"

        # link cache
        self._dataset = ds.dataset(self._cache_path, format="parquet")

        # build evaluator meta
        (
            self._index_to_label,
            self._number_of_groundtruths_per_label,
            self._info,
        ) = self.generate_meta(self._dataset, labels_override)

        # read metadata type mappings
        with open(self._metadata_path, "r") as f:
            types = json.load(f)
        self._info.datum_metadata_types = types["datum"]
        self._info.groundtruth_metadata_types = types["groundtruth"]
        self._info.prediction_metadata_types = types["prediction"]

        # read cache writer configuration
        with open(self._cache_path / ".cfg", "r") as f:
            cfg = json.load(f)
        self._detailed_batch_size = cfg["batch_size"]
        self._detailed_rows_per_file = cfg["rows_per_file"]
        self._detailed_compression = cfg["compression"]

    @property
    def dataset(self) -> ds.Dataset:
        """The pyarrow dataset linked to the cache directory."""
        return self._dataset

    @property
    def info(self) -> EvaluatorInfo:
        """Summary statistics and metadata types for the cache."""
        return self._info

    @staticmethod
    def generate_meta(
        dataset: ds.Dataset,
        labels_override: dict[int, str] | None,
    ) -> tuple[dict[int, str], NDArray[np.uint64], EvaluatorInfo]:
        """
        Generate cache statistics.

        Statistics are accumulated fragment by fragment; unique counts are
        computed within each fragment and then summed.

        Parameters
        ----------
        dataset : Dataset
            Valor cache.
        labels_override : dict[int, str], optional
            Optional labels override. Use when operating over filtered data.
            The mapping is copied; the caller's dictionary is never modified.

        Returns
        -------
        labels : dict[int, str]
            Mapping of label ID's to label values.
        number_of_groundtruths_per_label : NDArray[np.uint64]
            Array of size (n_labels,) containing ground truth counts.
        info : EvaluatorInfo
            Evaluator cache details.
        """
        gt_counts_per_lbl = defaultdict(int)
        # Copy so that the updates below do not leak into the caller's dict.
        labels = dict(labels_override) if labels_override else {}
        info = EvaluatorInfo()

        # Loop-invariant: the ID columns, in the order they are indexed below.
        id_columns = (
            "datum_id",
            "gt_id",
            "pd_id",
            "gt_label_id",
            "pd_label_id",
        )

        for fragment in dataset.get_fragments():
            tbl = fragment.to_table()
            ids = np.column_stack(
                [tbl[col].to_numpy() for col in id_columns]
            ).astype(np.int64)

            # count number of rows
            info.number_of_rows += int(tbl.shape[0])

            # count unique datums
            datum_ids = np.unique(ids[:, 0])
            info.number_of_datums += int(datum_ids.size)

            # count unique groundtruths (negative IDs are excluded)
            gt_ids = ids[:, 1]
            gt_ids = np.unique(gt_ids[gt_ids >= 0])
            info.number_of_groundtruth_annotations += int(gt_ids.shape[0])

            # count unique predictions (negative IDs are excluded)
            pd_ids = ids[:, 2]
            pd_ids = np.unique(pd_ids[pd_ids >= 0])
            info.number_of_prediction_annotations += int(pd_ids.shape[0])

            # get gt labels: first row index of each unique ID gives its value
            gt_label_ids, gt_indices = np.unique(
                ids[:, 3], return_index=True
            )
            gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
            gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
            gt_labels.pop(-1, None)
            labels.update(gt_labels)

            # get pd labels
            pd_label_ids, pd_indices = np.unique(
                ids[:, 4], return_index=True
            )
            pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
            pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
            pd_labels.pop(-1, None)
            labels.update(pd_labels)

            # count gts per label over unique (gt_id, gt_label_id) pairs
            gts = ids[:, (1, 3)]
            unique_ann = np.unique(gts[gts[:, 0] >= 0], axis=0)
            unique_labels, label_counts = np.unique(
                unique_ann[:, 1], return_counts=True
            )
            for label_id, count in zip(unique_labels, label_counts):
                gt_counts_per_lbl[int(label_id)] += int(count)

        # post-process: drop the -1 key in case the override carried one
        labels.pop(-1, None)

        # complete info object
        info.number_of_labels = len(labels)

        # convert gt counts to numpy
        # NOTE(review): indexing by label ID assumes IDs are contiguous in
        # [0, n_labels) -- confirm against the cache writer.
        number_of_groundtruths_per_label = np.zeros(
            len(labels), dtype=np.uint64
        )
        for k, v in gt_counts_per_lbl.items():
            number_of_groundtruths_per_label[int(k)] = v

        return labels, number_of_groundtruths_per_label, info

    @staticmethod
    def iterate_pairs(
        dataset: ds.Dataset,
        columns: list[str] | None = None,
    ):
        """
        Iterate over the cache one fragment at a time.

        Parameters
        ----------
        dataset : Dataset
            Valor cache.
        columns : list[str], optional
            Columns to read. Passed to the fragment reader unchanged.

        Yields
        ------
        NDArray
            One array per fragment, built by stacking the loaded table
            columns side by side.
        """
        for fragment in dataset.get_fragments():
            tbl = fragment.to_table(columns=columns)
            yield np.column_stack(
                [tbl.column(i).to_numpy() for i in range(tbl.num_columns)]
            )

    @staticmethod
    def iterate_pairs_with_table(
        dataset: ds.Dataset,
        columns: list[str] | None = None,
    ):
        """
        Iterate over the cache, yielding each full table with a stacked array.

        Parameters
        ----------
        dataset : Dataset
            Valor cache.
        columns : list[str], optional
            Names of the columns to stack. Defaults to every column of the
            fragment being read.

        Yields
        ------
        tuple[Table, NDArray]
            The complete fragment table and an array built by stacking the
            selected columns side by side.
        """
        for fragment in dataset.get_fragments():
            tbl = fragment.to_table()
            # Resolve per fragment without rebinding the parameter, and use
            # column *names*: `tbl.columns` holds the column data itself.
            selected = columns if columns else tbl.column_names
            yield tbl, np.column_stack(
                [tbl[col].to_numpy() for col in selected]
            )

    def filter(
        self,
        filter_expr: Filter,
        name: str | None = None,
        directory: str | Path | None = None,
    ) -> "Evaluator":
        """
        Filter evaluator cache.

        Parameters
        ----------
        filter_expr : Filter
            An object containing filter expressions.
        name : str, optional
            Filtered cache name. Defaults to "filtered".
        directory : str | Path, optional
            The directory to store the filtered cache. Defaults to the
            directory of this evaluator.

        Returns
        -------
        Evaluator
            A new evaluator object containing the filtered cache.
        """
        name = name if name else "filtered"
        directory = directory if directory else self._directory

        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from valor_lite.semantic_segmentation.loader import Loader

        return Loader.filter(
            directory=directory,
            name=name,
            evaluator=self,
            filter_expr=filter_expr,
        )

0 commit comments

Comments
 (0)