Commit a836578

Merge pull request #286 from rudrakatkar/add-tests-metrics
Added tests for Metrics Module
2 parents: 73ef68a + 08857d5
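
The new tests can be run with pytest. A typical invocation from the repository root (assuming a standard pytest setup; the path matches the file added in this commit):

    pytest tests/test_metrics.py -v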

File tree

1 file changed: +116 −0 lines changed


tests/test_metrics.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import numpy as np
import pytest
from detectionmetrics.utils.metrics import MetricsFactory


@pytest.fixture
def metrics_factory():
    """Fixture to create a MetricsFactory instance for testing"""
    return MetricsFactory(n_classes=3)


def test_update_confusion_matrix(metrics_factory):
    """Test that the confusion matrix updates correctly"""
    pred = np.array([0, 1, 2, 2, 1])
    gt = np.array([0, 1, 1, 2, 2])

    metrics_factory.update(pred, gt)
    confusion_matrix = metrics_factory.get_confusion_matrix()

    # Rows index the ground-truth class, columns the predicted class:
    # e.g. the sample with gt=1, pred=2 lands in expected[1][2].
    expected = np.array([
        [1, 0, 0],  # True class 0
        [0, 1, 1],  # True class 1
        [0, 1, 1],  # True class 2
    ])
    assert np.array_equal(confusion_matrix, expected), "Confusion matrix mismatch"


def test_get_tp_fp_fn_tn(metrics_factory):
    """Test per-class TP/FP/FN/TN counts when predictions match ground truth exactly"""
    pred = np.array([0, 1, 1, 2, 2])
    gt = np.array([0, 1, 1, 2, 2])
    metrics_factory.update(pred, gt)

    assert np.array_equal(metrics_factory.get_tp(), np.array([1, 2, 2]))
    assert np.array_equal(metrics_factory.get_fp(), np.array([0, 0, 0]))
    assert np.array_equal(metrics_factory.get_fn(), np.array([0, 0, 0]))
    assert np.array_equal(metrics_factory.get_tn(), np.array([4, 3, 3]))


def test_recall(metrics_factory):
    """Test recall calculation"""
    pred = np.array([0, 1, 2, 2, 1])
    gt = np.array([0, 1, 1, 2, 2])

    metrics_factory.update(pred, gt)

    expected_recall = np.array([1.0, 0.5, 0.5])
    computed_recall = metrics_factory.get_recall()

    assert np.allclose(computed_recall, expected_recall, equal_nan=True)


def test_accuracy(metrics_factory):
    """Test global accuracy calculation (not per-class)"""
    pred = np.array([0, 1, 2, 2, 1])
    gt = np.array([0, 1, 1, 2, 2])

    metrics_factory.update(pred, gt)

    TP = metrics_factory.get_tp(per_class=False)
    FP = metrics_factory.get_fp(per_class=False)
    FN = metrics_factory.get_fn(per_class=False)
    TN = metrics_factory.get_tn(per_class=False)

    total = TP + FP + FN + TN
    expected_accuracy = (TP + TN) / total if total > 0 else np.nan

    computed_accuracy = metrics_factory.get_accuracy(per_class=False)
    assert np.isclose(computed_accuracy, expected_accuracy, equal_nan=True)


def test_f1_score(metrics_factory):
    """Test F1-score calculation"""
    pred = np.array([0, 1, 2, 2, 1])
    gt = np.array([0, 1, 1, 2, 2])

    metrics_factory.update(pred, gt)

    # For this prediction/ground-truth pair, per-class precision equals recall
    precision = np.array([1.0, 0.5, 0.5])
    recall = np.array([1.0, 0.5, 0.5])
    expected_f1 = 2 * (precision * recall) / (precision + recall)

    computed_f1 = metrics_factory.get_f1_score()

    assert np.allclose(computed_f1, expected_f1, equal_nan=True)


def test_edge_cases(metrics_factory):
    """Test edge cases such as empty arrays and division by zero"""
    pred = np.array([])
    gt = np.array([])

    with pytest.raises(AssertionError):
        metrics_factory.update(pred, gt)

    # A factory that has never been updated should yield NaN metrics
    empty_metrics_factory = MetricsFactory(n_classes=3)

    assert np.isnan(empty_metrics_factory.get_precision(per_class=False))
    assert np.isnan(empty_metrics_factory.get_recall(per_class=False))
    assert np.isnan(empty_metrics_factory.get_f1_score(per_class=False))
    assert np.isnan(empty_metrics_factory.get_iou(per_class=False))


def test_macro_micro_weighted(metrics_factory):
    """Test macro, micro, and weighted metric averaging"""
    pred = np.array([0, 1, 2, 2, 1])
    gt = np.array([0, 1, 1, 2, 2])

    metrics_factory.update(pred, gt)

    macro_f1 = metrics_factory.get_averaged_metric("f1_score", method="macro")
    micro_f1 = metrics_factory.get_averaged_metric("f1_score", method="micro")

    weights = np.array([0.2, 0.5, 0.3])
    weighted_f1 = metrics_factory.get_averaged_metric("f1_score", method="weighted", weights=weights)

    assert 0 <= macro_f1 <= 1
    assert 0 <= micro_f1 <= 1
    assert 0 <= weighted_f1 <= 1
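
For reference, a minimal usage sketch of the MetricsFactory API as exercised above. Every constructor argument, method, and parameter here is taken directly from the test file; the inline comments restate what the tests assert rather than independently verified behavior:

import numpy as np
from detectionmetrics.utils.metrics import MetricsFactory

# Accumulate one batch of predictions against ground truth
mf = MetricsFactory(n_classes=3)
mf.update(np.array([0, 1, 2, 2, 1]), np.array([0, 1, 1, 2, 2]))

print(mf.get_confusion_matrix())         # rows: ground-truth class, cols: predicted class
print(mf.get_recall())                   # per-class recall, here [1.0, 0.5, 0.5]
print(mf.get_accuracy(per_class=False))  # single global accuracy value
print(mf.get_averaged_metric("f1_score", method="macro"))  # macro-averaged F1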
