Skip to content

Commit 4f7bf2f

Browse files
committed
all tests passing - need to add more tho
1 parent 0932a46 commit 4f7bf2f

3 files changed

Lines changed: 15 additions & 527 deletions

File tree

src/valor_lite/classification/utilities.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pyarrow as pa
55
from numpy.typing import NDArray
66

7-
from valor_lite.classification.computation import PairClassification
87
from valor_lite.classification.metric import Metric, MetricType
98

109

@@ -237,7 +236,6 @@ def _unpack_confusion_matrix_with_examples(
237236
{
238237
"datum_id": index_to_datum_id[unique_matches[idx, 0]],
239238
"score": float(scores[unique_match_indices[idx]]),
240-
"hardmax": float(winners[unique_match_indices[idx]]),
241239
}
242240
)
243241
if idx < n_unmatched_groundtruths:

tests/classification/test_confusion_matrix.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from valor_lite.classification import Classification, DataLoader, Evaluator
1+
from valor_lite.classification import Classification, DataLoader
22

33
# def test_compute_confusion_matrix():
44

@@ -70,12 +70,6 @@
7070
# )
7171

7272

73-
def test_compute_confusion_matrix_empty_pairs():
74-
evaluator = Evaluator()
75-
metrics = evaluator.compute_confusion_matrix()
76-
assert metrics == []
77-
78-
7973
def _filter_elements_with_zero_count(cm: dict, mp: dict):
8074
labels = list(mp.keys())
8175

@@ -135,6 +129,7 @@ def test_confusion_matrix_basic(basic_classifications: list[Classification]):
135129
},
136130
"parameters": {
137131
"score_threshold": 0.25,
132+
"hardmax": True,
138133
},
139134
},
140135
{
@@ -158,6 +153,7 @@ def test_confusion_matrix_basic(basic_classifications: list[Classification]):
158153
},
159154
"parameters": {
160155
"score_threshold": 0.75,
156+
"hardmax": True,
161157
},
162158
},
163159
]
@@ -229,6 +225,7 @@ def test_confusion_matrix_unit(
229225
},
230226
"parameters": {
231227
"score_threshold": 0.5,
228+
"hardmax": True,
232229
},
233230
},
234231
]
@@ -308,6 +305,7 @@ def test_confusion_matrix_with_animal_example(
308305
},
309306
"parameters": {
310307
"score_threshold": 0.5,
308+
"hardmax": True,
311309
},
312310
},
313311
]
@@ -387,6 +385,7 @@ def test_confusion_matrix_with_color_example(
387385
},
388386
"parameters": {
389387
"score_threshold": 0.5,
388+
"hardmax": True,
390389
},
391390
},
392391
]
@@ -437,7 +436,7 @@ def test_confusion_matrix_multiclass(
437436
"examples": [
438437
{
439438
"datum_id": "uid2",
440-
"score": 0.4076893257212283,
439+
"score": 0.4026564387187136,
441440
}
442441
],
443442
},
@@ -473,6 +472,7 @@ def test_confusion_matrix_multiclass(
473472
},
474473
"parameters": {
475474
"score_threshold": 0.05,
475+
"hardmax": True,
476476
},
477477
},
478478
{
@@ -515,6 +515,7 @@ def test_confusion_matrix_multiclass(
515515
},
516516
"parameters": {
517517
"score_threshold": 0.5,
518+
"hardmax": True,
518519
},
519520
},
520521
{
@@ -541,6 +542,7 @@ def test_confusion_matrix_multiclass(
541542
},
542543
"parameters": {
543544
"score_threshold": 0.85,
545+
"hardmax": True,
544546
},
545547
},
546548
]
@@ -611,6 +613,7 @@ def test_confusion_matrix_without_hardmax_animal_example(
611613
},
612614
"parameters": {
613615
"score_threshold": 0.05,
616+
"hardmax": False,
614617
},
615618
},
616619
{
@@ -633,6 +636,7 @@ def test_confusion_matrix_without_hardmax_animal_example(
633636
},
634637
"parameters": {
635638
"score_threshold": 0.4,
639+
"hardmax": False,
636640
},
637641
},
638642
{
@@ -652,6 +656,7 @@ def test_confusion_matrix_without_hardmax_animal_example(
652656
},
653657
"parameters": {
654658
"score_threshold": 0.5,
659+
"hardmax": False,
655660
},
656661
},
657662
]

0 commit comments

Comments
 (0)