Skip to content

Commit 75d3c19

Browse files
authored
remove suppression of counts, pr curves (#852)
1 parent cc4073c commit 75d3c19

6 files changed

Lines changed: 43 additions & 8 deletions

File tree

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
* @czaloom @ekorman @jyono @rsbowman-striveworks
1+
* @Striveworks/models

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ site/*
2828
*.jpeg
2929
*.pt
3030
*.png
31-
*.jpg
31+
*.jpg
32+
*.parquet

src/valor_lite/object_detection/manager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,6 @@ def compute_precision_recall(
410410
)
411411
return unpack_precision_recall_into_metric_lists(
412412
results=results,
413-
label_metadata=label_metadata,
414413
iou_thresholds=iou_thresholds,
415414
score_thresholds=score_thresholds,
416415
index_to_label=self.index_to_label,

src/valor_lite/object_detection/utilities.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def unpack_precision_recall_into_metric_lists(
2323
iou_thresholds: list[float],
2424
score_thresholds: list[float],
2525
index_to_label: list[str],
26-
label_metadata: NDArray[np.int32],
2726
):
2827
(
2928
(
@@ -125,13 +124,9 @@ def unpack_precision_recall_into_metric_lists(
125124
)
126125
for iou_idx, iou_threshold in enumerate(iou_thresholds)
127126
for label_idx, label in enumerate(index_to_label)
128-
if label_metadata[label_idx, 0] > 0
129127
]
130128

131129
for label_idx, label in enumerate(index_to_label):
132-
if label_metadata[label_idx, 0] == 0:
133-
continue
134-
135130
for score_idx, score_threshold in enumerate(score_thresholds):
136131
for iou_idx, iou_threshold in enumerate(iou_thresholds):
137132

tests/object_detection/test_counts.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,24 @@ def test_counts_ranked_pair_ordering(
592592
"label": "label3",
593593
},
594594
},
595+
{
596+
"type": "Counts",
597+
"value": {"tp": 0, "fp": 1, "fn": 0},
598+
"parameters": {
599+
"iou_threshold": 0.5,
600+
"score_threshold": 0.0,
601+
"label": "label4",
602+
},
603+
},
604+
{
605+
"type": "Counts",
606+
"value": {"tp": 0, "fp": 1, "fn": 0},
607+
"parameters": {
608+
"iou_threshold": 0.75,
609+
"score_threshold": 0.0,
610+
"label": "label4",
611+
},
612+
},
595613
]
596614
for m in actual_metrics:
597615
assert m in expected_metrics

tests/object_detection/test_pr_curve.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,28 @@ def test_pr_curve_using_torch_metrics_example(
151151
"label": "2",
152152
},
153153
},
154+
{
155+
"type": "PrecisionRecallCurve",
156+
"value": {
157+
"precisions": [0.0 for _ in range(101)],
158+
"scores": [0.318] + [0.0 for _ in range(100)],
159+
},
160+
"parameters": {
161+
"iou_threshold": 0.5,
162+
"label": "3",
163+
},
164+
},
165+
{
166+
"type": "PrecisionRecallCurve",
167+
"value": {
168+
"precisions": [0.0 for _ in range(101)],
169+
"scores": [0.318] + [0.0 for _ in range(100)],
170+
},
171+
"parameters": {
172+
"iou_threshold": 0.75,
173+
"label": "3",
174+
},
175+
},
154176
{
155177
"type": "PrecisionRecallCurve",
156178
"value": {

0 commit comments

Comments
 (0)