1- import numpy as np
2-
3 1 from valor_lite .classification import Classification , DataLoader , Evaluator
4- from valor_lite .classification .computation import (
5- PairClassification ,
6- compute_confusion_matrix ,
7- )
8-
9-
10- def test_compute_confusion_matrix ():
11-
12- # groundtruth, prediction, score
13- data = np .array (
14- [
15- # datum 0
16- [0 , 0 , 0 , 1.0 , 1.0 ], # tp
17- [0 , 0 , 1 , 0.0 , 0.0 ], # tn
18- [0 , 0 , 2 , 0.0 , 0.0 ], # tn
19- [0 , 0 , 3 , 0.0 , 0.0 ], # tn
20- # datum 1
21- [1 , 0 , 0 , 0.0 , 0.0 ], # fn
22- [1 , 0 , 1 , 0.0 , 0.0 ], # tn
23- [1 , 0 , 2 , 1.0 , 1.0 ], # fp
24- [1 , 0 , 3 , 0.0 , 0.0 ], # tn
25- # datum 2
26- [2 , 3 , 0 , 0.0 , 0.0 ], # tn
27- [2 , 3 , 1 , 0.0 , 0.0 ], # tn
28- [2 , 3 , 2 , 0.0 , 0.0 ], # tn
29- [2 , 3 , 3 , 0.3 , 1.0 ], # fn for score threshold > 0.3
30- ],
31- dtype = np .float64 ,
32- )
33- score_thresholds = np .array ([0.25 , 0.75 ], dtype = np .float64 )
34 2
35- result = compute_confusion_matrix (
36- detailed_pairs = data ,
37- score_thresholds = score_thresholds ,
38- hardmax = True ,
39- )
40-
41- assert result .shape == (2 , 12 )
42- assert np .all (
43- result
44- == np .array (
45- [
46- [
47- PairClassification .TP ,
48- 0 ,
49- 0 ,
50- 0 ,
51- 0 ,
52- 0 ,
53- PairClassification .FP_FN_MISCLF ,
54- 0 ,
55- 0 ,
56- 0 ,
57- 0 ,
58- PairClassification .TP ,
59- ],
60- [
61- PairClassification .TP ,
62- 0 ,
63- 0 ,
64- 0 ,
65- 0 ,
66- 0 ,
67- PairClassification .FP_FN_MISCLF ,
68- 0 ,
69- PairClassification .FN_UNMATCHED ,
70- PairClassification .FN_UNMATCHED ,
71- PairClassification .FN_UNMATCHED ,
72- PairClassification .FN_UNMATCHED ,
73- ],
74- ],
75- dtype = np .uint8 ,
76- ),
77- )
3+ # def test_compute_confusion_matrix():
4+
5+ # # groundtruth, prediction, score
6+ # data = np.array(
7+ # [
8+ # # datum 0
9+ # [0, 0, 0, 1.0, 1.0], # tp
10+ # [0, 0, 1, 0.0, 0.0], # tn
11+ # [0, 0, 2, 0.0, 0.0], # tn
12+ # [0, 0, 3, 0.0, 0.0], # tn
13+ # # datum 1
14+ # [1, 0, 0, 0.0, 0.0], # fn
15+ # [1, 0, 1, 0.0, 0.0], # tn
16+ # [1, 0, 2, 1.0, 1.0], # fp
17+ # [1, 0, 3, 0.0, 0.0], # tn
18+ # # datum 2
19+ # [2, 3, 0, 0.0, 0.0], # tn
20+ # [2, 3, 1, 0.0, 0.0], # tn
21+ # [2, 3, 2, 0.0, 0.0], # tn
22+ # [2, 3, 3, 0.3, 1.0], # fn for score threshold > 0.3
23+ # ],
24+ # dtype=np.float64,
25+ # )
26+ # score_thresholds = np.array([0.25, 0.75], dtype=np.float64)
27+
28+ # result = compute_confusion_matrix(
29+ # detailed_pairs=data,
30+ # score_thresholds=score_thresholds,
31+ # hardmax=True,
32+ # )
33+
34+ # assert result.shape == (2, 12)
35+ # assert np.all(
36+ # result
37+ # == np.array(
38+ # [
39+ # [
40+ # PairClassification.TP,
41+ # 0,
42+ # 0,
43+ # 0,
44+ # 0,
45+ # 0,
46+ # PairClassification.FP_FN_MISCLF,
47+ # 0,
48+ # 0,
49+ # 0,
50+ # 0,
51+ # PairClassification.TP,
52+ # ],
53+ # [
54+ # PairClassification.TP,
55+ # 0,
56+ # 0,
57+ # 0,
58+ # 0,
59+ # 0,
60+ # PairClassification.FP_FN_MISCLF,
61+ # 0,
62+ # PairClassification.FN_UNMATCHED,
63+ # PairClassification.FN_UNMATCHED,
64+ # PairClassification.FN_UNMATCHED,
65+ # PairClassification.FN_UNMATCHED,
66+ # ],
67+ # ],
68+ # dtype=np.uint8,
69+ # ),
70+ # )
78 71
79 72
80 73 def test_compute_confusion_matrix_empty_pairs ():
@@ -101,14 +94,10 @@ def test_confusion_matrix_basic(basic_classifications: list[Classification]):
101 94 loader .add_data (basic_classifications )
102 95 evaluator = loader .finalize ()
103 96
104- assert evaluator .ignored_prediction_labels == ["1" , "2" ]
105- assert evaluator .missing_prediction_labels == []
106- assert evaluator .metadata .to_dict () == {
107- "number_of_datums" : 3 ,
108- "number_of_ground_truths" : 3 ,
109- "number_of_predictions" : 12 ,
110- "number_of_labels" : 4 ,
111- }
97+ assert evaluator .metadata .number_of_datums == 3
98+ assert evaluator .metadata .number_of_ground_truths == 3
99+ assert evaluator .metadata .number_of_predictions == 12
100+ assert evaluator .metadata .number_of_labels == 4
112 101
113 102 actual_metrics = evaluator .compute_confusion_matrix (
114 103 score_thresholds = [0.25 , 0.75 ],
@@ -418,14 +407,10 @@ def test_confusion_matrix_multiclass(
418 407 loader .add_data (classifications_multiclass )
419 408 evaluator = loader .finalize ()
420 409
421- assert evaluator .ignored_prediction_labels == []
422- assert evaluator .missing_prediction_labels == []
423- assert evaluator .metadata .to_dict () == {
424- "number_of_datums" : 5 ,
425- "number_of_ground_truths" : 5 ,
426- "number_of_labels" : 3 ,
427- "number_of_predictions" : 15 ,
428- }
410+ assert evaluator .metadata .number_of_datums == 5
411+ assert evaluator .metadata .number_of_ground_truths == 5
412+ assert evaluator .metadata .number_of_labels == 3
413+ assert evaluator .metadata .number_of_predictions == 15
429 414
430 415 actual_metrics = evaluator .compute_confusion_matrix (
431 416 score_thresholds = [0.05 , 0.5 , 0.85 ],
@@ -576,14 +561,10 @@ def test_confusion_matrix_without_hardmax_animal_example(
576 561 loader .add_data (classifications_multiclass_true_negatives_check )
577 562 evaluator = loader .finalize ()
578 563
579- assert evaluator .ignored_prediction_labels == ["bee" , "cat" ]
580- assert evaluator .missing_prediction_labels == []
581- assert evaluator .metadata .to_dict () == {
582- "number_of_datums" : 1 ,
583- "number_of_ground_truths" : 1 ,
584- "number_of_predictions" : 3 ,
585- "number_of_labels" : 3 ,
586- }
564+ assert evaluator .metadata .number_of_datums == 1
565+ assert evaluator .metadata .number_of_ground_truths == 1
566+ assert evaluator .metadata .number_of_predictions == 3
567+ assert evaluator .metadata .number_of_labels == 3
587 568
588 569 actual_metrics = evaluator .compute_confusion_matrix (
589 570 score_thresholds = [0.05 , 0.4 , 0.5 ],
0 commit comments