Skip to content

Commit bdd6444

Browse files
authored
Precision-Recall Curve Bugfix (#862)
1 parent 58bb07f commit bdd6444

8 files changed

Lines changed: 68897 additions & 9318 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ site/*
3030
*.png
3131
*.jpg
3232
*.parquet
33+
*.npy
3334

3435
.valor/*

src/valor_lite/object_detection/computation.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -261,16 +261,21 @@ def calculate_ranking_boundaries(
261261
NDArray[np.float64]
262262
A 1-D array containing the lower IOU boundary for classifying pairs as true-positive across chunks.
263263
"""
264-
# groundtruths defined as (datum_id, groundtruth_id, groundtruth_label_id)
265-
gts = ranked_pairs[:, (0, 1, 3)].astype(np.int64)
264+
ids = ranked_pairs[:, (0, 1, 2, 3, 4)].astype(np.int64)
265+
gts = ids[:, (0, 1, 3)]
266+
gt_labels = ids[:, 3]
267+
pd_labels = ids[:, 4]
266268
ious = ranked_pairs[:, 5]
267269

268-
iou_boundary = np.ones_like(ious) * 2 # impossible bound
270+
# set default boundary to 2.0 as it will be used to check lower boundary in range [0-1].
271+
iou_boundary = np.ones_like(ious) * 2
269272

273+
mask_matching_labels = gt_labels == pd_labels
270274
mask_valid_gts = gts[:, 1] >= 0
271275
unique_gts = np.unique(gts[mask_valid_gts], axis=0)
272276
for gt in unique_gts:
273277
mask_gt = (gts == gt).all(axis=1)
278+
mask_gt &= mask_matching_labels
274279
if mask_gt.sum() <= 1:
275280
iou_boundary[mask_gt] = 0.0
276281
continue
@@ -444,10 +449,6 @@ def compute_counts(
444449
minlength=n_labels,
445450
)
446451

447-
# create true-positive mask score threshold
448-
mask_tps = mask_tp_outer
449-
true_positives_mask = mask_tps & mask_iou_prev
450-
451452
# count running tp and total for AP
452453
for pd_label in unique_pd_labels:
453454
mask_pd_label = pd_labels == pd_label
@@ -463,7 +464,7 @@ def compute_counts(
463464
running_counts[iou_idx, pd_label, 0] += total_count
464465

465466
# running true-positive count
466-
mask_tp_for_counting = mask_pd_label & true_positives_mask
467+
mask_tp_for_counting = mask_pd_label & mask_tp_outer
467468
tp_count = mask_tp_for_counting.sum()
468469
running_tp_count[iou_idx, mask_tp_for_counting] = np.arange(
469470
running_counts[iou_idx, pd_label, 1] + 1,
@@ -488,17 +489,43 @@ def compute_counts(
488489
)
489490
recall_index = np.floor(recall * 100.0).astype(np.int32)
490491

491-
# bin precision-recall curve
492+
# sort precision in descending order
493+
precision_indices = np.argsort(-precision, axis=1)
494+
495+
# populate precision-recall curve
492496
for iou_idx in range(n_ious):
493-
pr_curve[iou_idx, pd_labels, recall_index[iou_idx], 0] = np.maximum(
494-
pr_curve[iou_idx, pd_labels, recall_index[iou_idx], 0],
495-
precision[iou_idx],
497+
labeled_recall = np.hstack(
498+
[
499+
pd_labels.reshape(-1, 1),
500+
recall_index[iou_idx, :].reshape(-1, 1),
501+
]
502+
)
503+
504+
# extract maximum score per (label, recall) bin
505+
# arrays are already ordered by descending score
506+
lr_pairs, recall_indices = np.unique(
507+
labeled_recall, return_index=True, axis=0
508+
)
509+
li = lr_pairs[:, 0]
510+
ri = lr_pairs[:, 1]
511+
pr_curve[iou_idx, li, ri, 1] = np.maximum(
512+
pr_curve[iou_idx, li, ri, 1],
513+
scores[recall_indices],
514+
)
515+
516+
# extract maximum precision per (label, recall) bin
517+
# reorder arrays into descending precision order
518+
indices = precision_indices[iou_idx]
519+
sorted_precision = precision[iou_idx, indices]
520+
sorted_labeled_recall = labeled_recall[indices]
521+
lr_pairs, recall_indices = np.unique(
522+
sorted_labeled_recall, return_index=True, axis=0
496523
)
497-
pr_curve[
498-
iou_idx, pd_labels[::-1], recall_index[iou_idx][::-1], 1
499-
] = np.maximum(
500-
pr_curve[iou_idx, pd_labels[::-1], recall_index[iou_idx][::-1], 1],
501-
scores[::-1],
524+
li = lr_pairs[:, 0]
525+
ri = lr_pairs[:, 1]
526+
pr_curve[iou_idx, li, ri, 0] = np.maximum(
527+
pr_curve[iou_idx, li, ri, 0],
528+
sorted_precision[recall_indices],
502529
)
503530

504531
return counts

src/valor_lite/object_detection/evaluator.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,7 @@ def persistent(
144144
metadata_fields=metadata_fields,
145145
)
146146

147-
def _rank(
148-
self,
149-
n_labels: int,
150-
batch_size: int = 1_000,
151-
):
147+
def _rank(self, batch_size: int = 1_000):
152148
"""Perform pair ranking over the detailed cache."""
153149

154150
detailed_reader = self._detailed_writer.to_reader()
@@ -203,10 +199,7 @@ def finalize(
203199
)
204200

205201
# populate ranked cache
206-
self._rank(
207-
n_labels=len(index_to_label),
208-
batch_size=batch_size,
209-
)
202+
self._rank(batch_size)
210203

211204
ranked_reader = self._ranked_writer.to_reader()
212205
return Evaluator(

tests/object_detection/conftest.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,7 +1345,7 @@ def fixture_path() -> Path:
13451345

13461346

13471347
@pytest.fixture
1348-
def coco_detections_v0_36_6(
1348+
def coco_detections_v0_37_3(
13491349
fixture_path: Path,
13501350
) -> list[Detection[BoundingBox]]:
13511351
path = fixture_path / "coco_input.json"
@@ -1362,8 +1362,13 @@ def coco_detections_v0_36_6(
13621362

13631363

13641364
@pytest.fixture
1365-
def coco_metrics_v0_36_6(fixture_path: Path) -> dict[str, list[dict]]:
1366-
# metrics are given back in dictionary format
1367-
path = fixture_path / "coco_output.json"
1368-
with open(path, "r") as f:
1365+
def coco_metrics_path_v0_37_3(fixture_path: Path) -> Path:
1366+
return fixture_path / "coco_output.json"
1367+
1368+
1369+
@pytest.fixture
1370+
def coco_metrics_v0_37_3(
1371+
coco_metrics_path_v0_37_3: Path,
1372+
) -> dict[str, list[dict]]:
1373+
with open(coco_metrics_path_v0_37_3, "r") as f:
13691374
return json.load(f)

0 commit comments

Comments
 (0)