Skip to content

Commit f98b740

Browse files
committed
passing tests
1 parent 161ad0d commit f98b740

10 files changed

Lines changed: 197 additions & 627 deletions

File tree

src/valor_lite/semantic_segmentation/computation.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def compute_intermediates(
109109

110110

111111
def compute_metrics(
112-
confusion_matrices: NDArray[np.int64],
112+
counts: NDArray[np.uint64],
113113
) -> tuple[
114114
NDArray[np.float64],
115115
NDArray[np.float64],
@@ -126,8 +126,8 @@ def compute_metrics(
126126
127127
Parameters
128128
----------
129-
confusion_matrices : NDArray[np.int64]
130-
A 3-D array containing confusion matrices for each datum with shape (n_datums, n_labels + 1, n_labels + 1).
129+
counts : NDArray[np.int64]
130+
A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
131131
label_metadata : NDArray[np.int64]
132132
A 2-D array containing label metadata with shape (n_labels, 2).
133133
Index 0: Ground Truth Label Count
@@ -150,15 +150,10 @@ def compute_metrics(
150150
NDArray[np.float64]
151151
Unmatched ground truth ratios.
152152
"""
153-
n_labels = confusion_matrices.shape[-1] - 1
154-
n_pixels = confusion_matrices.sum()
155-
label_metadata = np.zeros((n_labels, 2), dtype=np.int64)
156-
label_metadata[:, 0] = confusion_matrices[:, 1:, :].sum(axis=(0, 2))
157-
label_metadata[:, 1] = confusion_matrices[:, :, 1:].sum(axis=(0, 1))
158-
gt_counts = label_metadata[:, 0]
159-
pd_counts = label_metadata[:, 1]
160-
161-
counts = confusion_matrices.sum(axis=0)
153+
n_labels = counts.shape[0] - 1
154+
n_pixels = counts.sum()
155+
gt_counts = counts[1:, :].sum(axis=1)
156+
pd_counts = counts[:, 1:].sum(axis=0)
162157

163158
# compute iou, unmatched_ground_truth and unmatched predictions
164159
intersection_ = counts[1:, 1:]

src/valor_lite/semantic_segmentation/evaluator.py

Lines changed: 50 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,14 @@
11
import json
2-
from collections import defaultdict
3-
from dataclasses import asdict, dataclass
2+
from dataclasses import dataclass
43
from pathlib import Path
54

65
import numpy as np
6+
import pyarrow.compute as pc
7+
import pyarrow.dataset as ds
78
from numpy.typing import NDArray
8-
from pyarrow import pa
9-
from pyarrow.compute import pc
10-
from pyarrow.dataset import ds
11-
from tqdm import tqdm
12-
13-
from valor_lite.cache import (
14-
CacheReader,
15-
DataType,
16-
convert_type_mapping_to_schema,
17-
)
18-
from valor_lite.exceptions import EmptyCacheError, EmptyFilterError
19-
from valor_lite.semantic_segmentation.annotation import Segmentation
20-
from valor_lite.semantic_segmentation.computation import (
21-
compute_intermediates,
22-
compute_label_metadata,
23-
compute_metrics,
24-
filter_cache,
25-
)
9+
10+
from valor_lite.cache import DataType
11+
from valor_lite.semantic_segmentation.computation import compute_metrics
2612
from valor_lite.semantic_segmentation.metric import Metric, MetricType
2713
from valor_lite.semantic_segmentation.utilities import (
2814
unpack_precision_recall_iou_into_metric_lists,
@@ -31,12 +17,12 @@
3117

3218
@dataclass
3319
class EvaluatorInfo:
20+
number_of_rows: int = 0
3421
number_of_datums: int = 0
3522
number_of_labels: int = 0
3623
number_of_pixels: int = 0
3724
number_of_groundtruth_pixels: int = 0
3825
number_of_prediction_pixels: int = 0
39-
number_of_rows: int = 0
4026
datum_metadata_types: dict[str, DataType] | None = None
4127
groundtruth_metadata_types: dict[str, DataType] | None = None
4228
prediction_metadata_types: dict[str, DataType] | None = None
@@ -68,7 +54,7 @@ def __init__(
6854
# build evaluator meta
6955
(
7056
self._index_to_label,
71-
self._number_of_groundtruths_per_label,
57+
self._confusion_matrix,
7258
self._info,
7359
) = self.generate_meta(self._dataset, labels_override)
7460

@@ -111,12 +97,11 @@ def generate_meta(
11197
-------
11298
labels : dict[int, str]
11399
Mapping of label ID's to label values.
114-
number_of_groundtruths_per_label : NDArray[np.uint64]
115-
Array of size (n_labels,) containing ground truth counts.
100+
confusion_matrix : NDArray[np.uint64]
101+
Array of size (n_labels + 1, n_labels + 1) containing pair counts.
116102
info : EvaluatorInfo
117103
Evaluator cache details.
118104
"""
119-
gt_counts_per_lbl = defaultdict(int)
120105
labels = labels_override if labels_override else {}
121106
info = EvaluatorInfo()
122107

@@ -126,7 +111,7 @@ def generate_meta(
126111
"datum_id",
127112
"gt_label_id",
128113
"pd_label_id",
129-
"counts",
114+
"count",
130115
)
131116
ids = np.column_stack(
132117
[tbl[col].to_numpy() for col in columns]
@@ -159,21 +144,9 @@ def generate_meta(
159144
pd_labels.pop(-1, None)
160145
labels.update(pd_labels)
161146

162-
# count gts per label
163-
gts = ids[:, 1].astype(np.int64)
164-
unique_ann = np.unique(gts[gts[:, 0] >= 0], axis=0)
165-
unique_labels, label_counts = np.unique(
166-
unique_ann[:, 1], return_counts=True
167-
)
168-
for label_id, count in zip(unique_labels, label_counts):
169-
gt_counts_per_lbl[int(label_id)] += int(count)
170-
171147
# post-process
172148
labels.pop(-1, None)
173149

174-
# complete info object
175-
info.number_of_labels = len(labels)
176-
177150
# create confusion matrix
178151
n_labels = len(labels)
179152
matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.uint64)
@@ -187,8 +160,11 @@ def generate_meta(
187160
ids = np.column_stack(
188161
[tbl[col].to_numpy() for col in columns]
189162
).astype(np.int64)
190-
counts = tbl["counts"].to_numpy()
163+
counts = tbl["count"].to_numpy()
191164

165+
mask_null_gts = ids[:, 1] == -1
166+
mask_null_pds = ids[:, 2] == -1
167+
matrix[0, 0] = counts[mask_null_gts & mask_null_pds].sum()
192168
for idx in range(n_labels):
193169
mask_gts = ids[:, 1] == idx
194170
for pidx in range(n_labels):
@@ -197,35 +173,18 @@ def generate_meta(
197173
mask_gts & mask_pds
198174
].sum()
199175

200-
mask_unmatched_gts = mask_gts & (ids[:, 2] == -1)
176+
mask_unmatched_gts = mask_gts & mask_null_pds
201177
matrix[idx + 1, 0] = counts[mask_unmatched_gts].sum()
202-
mask_unmatched_pds = (ids[:, 1] == -1) & (ids[:, 2] == idx)
203-
matrix[0, idx + 1] = counts[mask_unmatched_pds]
178+
mask_unmatched_pds = mask_null_gts & (ids[:, 2] == idx)
179+
matrix[0, idx + 1] = counts[mask_unmatched_pds].sum()
204180

205-
return labels, matrix, info
206-
207-
@staticmethod
208-
def iterate_pairs(
209-
dataset: ds.Dataset,
210-
columns: list[str] | None = None,
211-
):
212-
for fragment in dataset.get_fragments():
213-
tbl = fragment.to_table(columns=columns)
214-
yield np.column_stack(
215-
[tbl.column(i).to_numpy() for i in range(tbl.num_columns)]
216-
)
181+
# complete info object
182+
info.number_of_labels = len(labels)
183+
info.number_of_pixels = matrix.sum()
184+
info.number_of_groundtruth_pixels = matrix[1:, :].sum()
185+
info.number_of_prediction_pixels = matrix[:, 1:].sum()
217186

218-
@staticmethod
219-
def iterate_pairs_with_table(
220-
dataset: ds.Dataset,
221-
columns: list[str] | None = None,
222-
):
223-
for fragment in dataset.get_fragments():
224-
tbl = fragment.to_table()
225-
columns = columns if columns else tbl.columns
226-
yield tbl, np.column_stack(
227-
[tbl[col].to_numpy() for col in columns]
228-
)
187+
return labels, matrix, info
229188

230189
def filter(
231190
self,
@@ -260,3 +219,29 @@ def filter(
260219
evaluator=self,
261220
filter_expr=filter_expr,
262221
)
222+
223+
def compute_precision_recall_iou(self) -> dict[MetricType, list]:
224+
"""
225+
Performs an evaluation and returns metrics.
226+
227+
Returns
228+
-------
229+
dict[MetricType, list]
230+
A dictionary mapping MetricType enumerations to lists of computed metrics.
231+
"""
232+
results = compute_metrics(counts=self._confusion_matrix)
233+
return unpack_precision_recall_iou_into_metric_lists(
234+
results=results,
235+
index_to_label=self._index_to_label,
236+
)
237+
238+
def evaluate(self) -> dict[MetricType, list[Metric]]:
239+
"""
240+
Computes all available metrics.
241+
242+
Returns
243+
-------
244+
dict[MetricType, list[Metric]]
245+
Lists of metrics organized by metric type.
246+
"""
247+
return self.compute_precision_recall_iou()

src/valor_lite/semantic_segmentation/loader.py

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ def __init__(
3939
self._labels: dict[str, int] = {}
4040
self._index_to_label: dict[int, str] = {}
4141
self._datum_count = 0
42-
self._groundtruth_pixel_count = 0
43-
self._prediction_pixel_count = 0
44-
self._total_pixel_count = 0
4542

4643
with open(self._metadata_path, "w") as f:
4744
types = {
@@ -247,18 +244,28 @@ def add_data(
247244
},
248245
]
249246
)
247+
rows.append(
248+
{
249+
# datum
250+
"datum_uid": segmentation.uid,
251+
"datum_id": self._datum_count,
252+
**datum_metadata,
253+
# groundtruth
254+
"gt_label": None,
255+
"gt_label_id": -1,
256+
# prediction
257+
"pd_label": None,
258+
"pd_label_id": -1,
259+
# pair
260+
"count": counts[0, 0],
261+
}
262+
)
250263
self._cache.write_rows(rows)
251264

252265
# update datum cache
253266
self._datum_count += 1
254267

255-
def finalize(
256-
self,
257-
rows_per_file: int | None = None,
258-
compression: str | None = None,
259-
write_batch_size: int | None = None,
260-
read_batch_size: int = 1000,
261-
):
268+
def finalize(self):
262269
"""
263270
Performs data finalization and some preprocessing steps.
264271
@@ -271,18 +278,10 @@ def finalize(
271278
if self._cache.dataset.count_rows() == 0:
272279
raise EmptyCacheError()
273280

274-
evaluator = Evaluator(
281+
return Evaluator(
275282
directory=self._directory,
276283
name=self._name,
277284
)
278-
evaluator.rank(
279-
where=self._ranked_path,
280-
rows_per_file=rows_per_file,
281-
compression=compression,
282-
write_batch_size=write_batch_size,
283-
read_batch_size=read_batch_size,
284-
)
285-
return evaluator
286285

287286
@classmethod
288287
def filter(
@@ -302,59 +301,60 @@ def filter(
302301
groundtruth_metadata_types=evaluator.info.groundtruth_metadata_types,
303302
prediction_metadata_types=evaluator.info.prediction_metadata_types,
304303
)
305-
for fragment in evaluator.detailed.get_fragments():
304+
for fragment in evaluator.dataset.get_fragments():
306305
tbl = fragment.to_table(filter=filter_expr.datums)
307306

308307
columns = (
309308
"datum_id",
310-
"gt_id",
311-
"pd_id",
312309
"gt_label_id",
313310
"pd_label_id",
314-
"iou",
315-
"score",
316311
)
317312
pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
318313

319314
n_pairs = pairs.shape[0]
320315
gt_ids = pairs[:, (0, 1)].astype(np.int64)
321316
pd_ids = pairs[:, (0, 2)].astype(np.int64)
322317

323-
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
324-
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
325-
326318
if filter_expr.groundtruths is not None:
319+
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
327320
gt_tbl = tbl.filter(filter_expr.groundtruths)
328321
gt_pairs = np.column_stack(
329-
[gt_tbl[col].to_numpy() for col in ("datum_id", "gt_id")]
322+
[
323+
gt_tbl[col].to_numpy()
324+
for col in ("datum_id", "gt_label_id")
325+
]
330326
).astype(np.int64)
331327
for gt in np.unique(gt_pairs, axis=0):
332328
mask_valid_gt |= (gt_ids == gt).all(axis=1)
329+
else:
330+
mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
333331

334332
if filter_expr.predictions is not None:
333+
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
335334
pd_tbl = tbl.filter(filter_expr.predictions)
336335
pd_pairs = np.column_stack(
337-
[pd_tbl[col].to_numpy() for col in ("datum_id", "pd_id")]
336+
[
337+
pd_tbl[col].to_numpy()
338+
for col in ("datum_id", "pd_label_id")
339+
]
338340
).astype(np.int64)
339341
for pd in np.unique(pd_pairs, axis=0):
340342
mask_valid_pd |= (pd_ids == pd).all(axis=1)
343+
else:
344+
mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
341345

342346
mask_valid = mask_valid_gt | mask_valid_pd
343347
mask_valid_gt &= mask_valid
344348
mask_valid_pd &= mask_valid
345349

346-
pairs[np.ix_(~mask_valid_gt, (1, 3))] = -1.0 # type: ignore - numpy ix_
347-
pairs[np.ix_(~mask_valid_pd, (2, 4, 6))] = -1.0 # type: ignore - numpy ix_
348-
pairs[~mask_valid_pd | ~mask_valid_gt, 5] = 0.0
350+
pairs[~mask_valid_gt, 1] = -1
351+
pairs[~mask_valid_pd, 2] = -1
349352

350353
for idx, col in enumerate(columns):
351354
tbl = tbl.set_column(
352355
tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
353356
)
354-
355-
mask_invalid = ~mask_valid | (pairs[:, (1, 2)] < 0).all(axis=1)
356-
filtered_tbl = tbl.filter(pa.array(~mask_invalid))
357-
loader._cache.write_table(filtered_tbl)
357+
loader._cache.write_table(tbl)
358358

359359
loader._cache.flush()
360360
if loader._cache.dataset.count_rows() == 0:
@@ -365,5 +365,4 @@ def filter(
365365
name=loader._name,
366366
labels_override=evaluator._index_to_label,
367367
)
368-
evaluator.rank(where=loader._ranked_path)
369368
return evaluator

0 commit comments

Comments
 (0)