Skip to content

Commit d89e0cf

Browse files
committed
tidying up
1 parent 75d4727 commit d89e0cf

4 files changed

Lines changed: 52 additions & 148 deletions

File tree

src/valor_lite/object_detection/computation.py

Lines changed: 2 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -174,54 +174,7 @@ def compute_polygon_iou(
174174
return ious
175175

176176

177-
def compute_label_metadata(
178-
ids: NDArray[np.int32],
179-
n_labels: int,
180-
) -> NDArray[np.uint32]:
181-
"""
182-
Computes label metadata returning a count of annotations per label.
183-
184-
Parameters
185-
----------
186-
detailed_pairs : NDArray[np.int32]
187-
Detailed annotation pairings with shape (N, 7).
188-
Index 0 - Datum Index
189-
Index 1 - GroundTruth Index
190-
Index 2 - Prediction Index
191-
Index 3 - GroundTruth Label Index
192-
Index 4 - Prediction Label Index
193-
n_labels : int
194-
The total number of unique labels.
195-
196-
Returns
197-
-------
198-
NDArray[np.int32]
199-
The label metadata array with shape (n_labels, 2).
200-
Index 0 - Ground truth label count
201-
Index 1 - Prediction label count
202-
"""
203-
label_metadata = np.zeros((n_labels, 2), dtype=np.uint32)
204-
205-
ground_truth_pairs = ids[:, (0, 1, 3)]
206-
ground_truth_pairs = ground_truth_pairs[ground_truth_pairs[:, 1] >= 0]
207-
unique_pairs = np.unique(ground_truth_pairs, axis=0)
208-
label_indices, unique_counts = np.unique(
209-
unique_pairs[:, 2], return_counts=True
210-
)
211-
label_metadata[label_indices.astype(np.int32), 0] = unique_counts
212-
213-
prediction_pairs = ids[:, (0, 2, 4)]
214-
prediction_pairs = prediction_pairs[prediction_pairs[:, 1] >= 0]
215-
unique_pairs = np.unique(prediction_pairs, axis=0)
216-
label_indices, unique_counts = np.unique(
217-
unique_pairs[:, 2], return_counts=True
218-
)
219-
label_metadata[label_indices.astype(np.int32), 1] = unique_counts
220-
221-
return label_metadata
222-
223-
224-
def rank_pairs_returning_indices(sorted_pairs: NDArray[np.float64]):
177+
def rank_pairs(sorted_pairs: NDArray[np.float64]):
225178
"""
226179
Prunes and ranks prediction pairs.
227180
@@ -327,7 +280,7 @@ def rank_table(tbl: pa.Table, number_of_labels: int) -> pa.Table:
327280
pairs = np.column_stack(
328281
[sorted_tbl[col].to_numpy() for col in numeric_columns]
329282
)
330-
pairs, indices = rank_pairs_returning_indices(pairs)
283+
pairs, indices = rank_pairs(pairs)
331284
ranked_tbl = sorted_tbl.take(indices)
332285
lower_iou_bound, winning_predictions = calculate_ranking_boundaries(
333286
pairs, number_of_labels=number_of_labels
@@ -344,57 +297,6 @@ def rank_table(tbl: pa.Table, number_of_labels: int) -> pa.Table:
344297
return ranked_tbl
345298

346299

347-
def rank_pairs(
348-
detailed_pairs: NDArray[np.float64],
349-
) -> NDArray[np.float64]:
350-
"""
351-
Highly optimized pair ranking for computing precision and recall based metrics.
352-
353-
Only ground truths and predictions that provide unique information are kept. The unkept
354-
pairs are represented via the label metadata array.
355-
356-
Parameters
357-
----------
358-
detailed_pairs : NDArray[np.float64]
359-
Detailed annotation pairs with shape (n_pairs, 7).
360-
Index 0 - Datum Index
361-
Index 1 - GroundTruth Index
362-
Index 2 - Prediction Index
363-
Index 3 - GroundTruth Label Index
364-
Index 4 - Prediction Label Index
365-
Index 5 - IOU
366-
Index 6 - Score
367-
368-
Returns
369-
-------
370-
NDArray[np.float64]
371-
Array of ranked pairs for precision-recall metric computation.
372-
"""
373-
# remove unmatched ground truths
374-
pairs = detailed_pairs[detailed_pairs[:, 2] >= 0.0]
375-
376-
# find best fits for prediction
377-
mask_label_match = np.isclose(pairs[:, 3], pairs[:, 4])
378-
matched_predictions = np.unique(pairs[mask_label_match, 2])
379-
mask_unmatched_predictions = ~np.isin(pairs[:, 2], matched_predictions)
380-
pairs = pairs[mask_label_match | mask_unmatched_predictions]
381-
382-
# only keep the highest ranked pair
383-
_, indices = np.unique(pairs[:, [0, 2, 4]], axis=0, return_index=True)
384-
pairs = pairs[indices]
385-
386-
# np.unique orders its results by value, we need to sort the indices to maintain the results of the lexsort
387-
indices = np.lexsort(
388-
(
389-
-pairs[:, 5], # iou
390-
-pairs[:, 6], # score
391-
)
392-
)
393-
pairs = pairs[indices]
394-
395-
return pairs
396-
397-
398300
def compute_counts(
399301
ranked_pairs: NDArray[np.float64],
400302
iou_thresholds: NDArray[np.float64],

src/valor_lite/object_detection/evaluator.py

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import pyarrow.compute as pc
1111
import pyarrow.dataset as ds
1212
import pyarrow.parquet as pq
13-
from numpy.typing import NDArray
1413

1514
from valor_lite.cache import CacheReader, CacheWriter, DataType
1615
from valor_lite.object_detection.computation import (
@@ -24,27 +23,13 @@
2423
)
2524
from valor_lite.object_detection.metric import Metric, MetricType
2625
from valor_lite.object_detection.utilities import (
26+
create_mapping,
2727
unpack_confusion_matrix,
2828
unpack_examples,
2929
unpack_precision_recall_into_metric_lists,
3030
)
3131

3232

33-
def create_mapping(
34-
tbl: pa.Table,
35-
pairs: NDArray[np.float64],
36-
index: int,
37-
id_col: str,
38-
uid_col: str,
39-
) -> dict[int, str]:
40-
col = pairs[:, index].astype(np.int64)
41-
values, indices = np.unique(col, return_index=True)
42-
indices = indices[values >= 0]
43-
return {
44-
tbl[id_col][idx].as_py(): tbl[uid_col][idx].as_py() for idx in indices
45-
}
46-
47-
4833
@dataclass
4934
class EvaluatorInfo:
5035
number_of_datums: int = 0
@@ -65,7 +50,6 @@ class Filter:
6550
datums: pc.Expression | None = None
6651
groundtruths: pc.Expression | None = None
6752
predictions: pc.Expression | None = None
68-
labels: pc.Expression | None = None
6953

7054

7155
class Evaluator:
@@ -90,7 +74,7 @@ def __init__(
9074
self._index_to_label,
9175
self._number_of_groundtruths_per_label,
9276
self._info,
93-
) = self._generate_meta(labels_override)
77+
) = self.generate_meta(self._dataset, labels_override)
9478

9579
with open(self._metadata_path, "r") as f:
9680
types = json.load(f)
@@ -135,19 +119,22 @@ def detailed(self) -> ds.Dataset:
135119
return self._dataset
136120

137121
@property
138-
def ranked(self):
122+
def ranked(self) -> ds.Dataset:
139123
return ds.dataset(self._ranked_path, format="parquet")
140124

141125
@property
142126
def info(self) -> EvaluatorInfo:
143127
return self._info
144128

145-
def _generate_meta(self, labels_override: dict[int, str] | None):
129+
@staticmethod
130+
def generate_meta(
131+
dataset: ds.Dataset, labels_override: dict[int, str] | None
132+
):
146133
gt_counts_per_lbl = defaultdict(int)
147134
labels = labels_override if labels_override else {}
148135
info = EvaluatorInfo()
149136

150-
for fragment in self.detailed.get_fragments():
137+
for fragment in dataset.get_fragments():
151138
tbl = fragment.to_table()
152139
columns = (
153140
"datum_id",
@@ -217,6 +204,29 @@ def _generate_meta(self, labels_override: dict[int, str] | None):
217204

218205
return labels, number_of_groundtruths_per_label, info
219206

207+
@staticmethod
def iterate_pairs(
    dataset: ds.Dataset,
    columns: list[str] | None = None,
):
    """
    Iterate over a dataset one fragment at a time as stacked numpy arrays.

    Parameters
    ----------
    dataset : ds.Dataset
        Dataset whose fragments are read one by one.
    columns : list[str], optional
        Column selection forwarded to ``fragment.to_table``. When omitted,
        the fragment is read with ``columns=None``.

    Yields
    ------
    numpy.ndarray
        One array per fragment, built by column-stacking every column of
        the table that was read, in table order.
    """
    for fragment in dataset.get_fragments():
        # Only the requested columns are materialised for this fragment.
        tbl = fragment.to_table(columns=columns)
        # Walk the columns directly instead of indexing by position.
        yield np.column_stack([col.to_numpy() for col in tbl.columns])
217+
218+
@staticmethod
219+
def iterate_pairs_with_table(
220+
dataset: ds.Dataset,
221+
columns: list[str] | None = None,
222+
):
223+
for fragment in dataset.get_fragments():
224+
tbl = fragment.to_table()
225+
columns = columns if columns else tbl.columns
226+
yield tbl, np.column_stack(
227+
[tbl[col].to_numpy() for col in columns]
228+
)
229+
220230
def filter(
221231
self,
222232
filter_expr: Filter,
@@ -234,7 +244,7 @@ def filter(
234244
filter_expr=filter_expr,
235245
)
236246

237-
def create_ranked_cache(
247+
def rank(
238248
self,
239249
where: str | Path,
240250
rows_per_file: int | None = None,
@@ -338,29 +348,6 @@ def generate_heap_item(batches, batch_idx, row_idx):
338348
heap, generate_heap_item(batches, batch_idx, 0)
339349
)
340350

341-
@staticmethod
342-
def iterate_pairs(
343-
dataset: ds.Dataset,
344-
columns: list[str] | None = None,
345-
):
346-
for fragment in dataset.get_fragments():
347-
tbl = fragment.to_table(columns=columns)
348-
yield np.column_stack(
349-
[tbl.column(i).to_numpy() for i in range(tbl.num_columns)]
350-
)
351-
352-
@staticmethod
353-
def iterate_pairs_with_table(
354-
dataset: ds.Dataset,
355-
columns: list[str] | None = None,
356-
):
357-
for fragment in dataset.get_fragments():
358-
tbl = fragment.to_table()
359-
columns = columns if columns else tbl.columns
360-
yield tbl, np.column_stack(
361-
[tbl[col].to_numpy() for col in columns]
362-
)
363-
364351
def compute_precision_recall(
365352
self,
366353
iou_thresholds: list[float],

src/valor_lite/object_detection/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def filter(
453453
name=loader._name,
454454
labels_override=evaluator._index_to_label,
455455
)
456-
evaluator.create_ranked_cache(where=loader._ranked_path)
456+
evaluator.rank(where=loader._ranked_path)
457457
return evaluator
458458

459459
def finalize(self):
@@ -473,5 +473,5 @@ def finalize(self):
473473
directory=self._directory,
474474
name=self._name,
475475
)
476-
evaluator.create_ranked_cache(where=self._ranked_path)
476+
evaluator.rank(where=self._ranked_path)
477477
return evaluator

src/valor_lite/object_detection/utilities.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from collections import defaultdict
22

33
import numpy as np
4+
import pyarrow as pa
45
from numpy.typing import NDArray
56

6-
from valor_lite.object_detection.computation import PairClassification
77
from valor_lite.object_detection.metric import Metric, MetricType
88

99

@@ -328,6 +328,21 @@ def unpack_confusion_matrix(
328328
return metrics
329329

330330

331+
def create_mapping(
    tbl: pa.Table,
    pairs: NDArray[np.float64],
    index: int,
    id_col: str,
    uid_col: str,
) -> dict[int, str]:
    """
    Build a mapping from the values of one table column to another.

    Column ``index`` of ``pairs`` is cast to int64 and reduced to its unique
    values; for each non-negative unique value, the row of its first
    occurrence is looked up in ``tbl``.

    Parameters
    ----------
    tbl : pa.Table
        Table providing the ``id_col`` and ``uid_col`` columns.
        NOTE(review): rows of ``tbl`` are assumed to line up with rows of
        ``pairs`` - confirm against callers.
    pairs : NDArray[np.float64]
        2-D array whose column ``index`` selects the rows to keep.
    index : int
        Column of ``pairs`` to deduplicate.
    id_col : str
        Name of the table column supplying the mapping keys.
    uid_col : str
        Name of the table column supplying the mapping values.

    Returns
    -------
    dict[int, str]
        ``tbl[id_col]`` value mapped to ``tbl[uid_col]`` value, one entry per
        selected row.
    """
    col = pairs[:, index].astype(np.int64)
    # `return_index` gives the first row at which each unique value appears.
    values, indices = np.unique(col, return_index=True)
    # Negative values are dropped from the mapping.
    indices = indices[values >= 0]

    # Look the two columns up once rather than once per row.
    id_values = tbl[id_col]
    uid_values = tbl[uid_col]
    return {
        id_values[idx].as_py(): uid_values[idx].as_py() for idx in indices
    }
344+
345+
331346
def unpack_examples(
332347
detailed_pairs: NDArray[np.float64],
333348
mask_tp: NDArray[np.bool_],

0 commit comments

Comments
 (0)