Skip to content

Commit d1719c4

Browse files
authored
Merge pull request #24 from ihmeuw-msca/bugfix/metrics-include_groups
Bugfix/metrics include groups
2 parents d75afb8 + 48fd8d0 commit d1719c4

File tree

4 files changed

+112
-2
lines changed

4 files changed

+112
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"scipy",
2323
"pydantic",
2424
"scikit-learn",
25+
"pandas",
2526
]
2627

2728
[project.optional-dependencies]

src/msca/metrics/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import StrEnum, auto
2+
23
import numpy as np
34
import pandas as pd
4-
55
from sklearn import metrics
66

77

@@ -246,7 +246,6 @@ def _eval_grouped(
246246
obs,
247247
pred,
248248
weights,
249-
include_groups=False,
250249
)
251250
.reset_index()
252251
)

tests/metrics/__init__.py

Whitespace-only changes.

tests/metrics/test_metrics.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
from sklearn import metrics
5+
6+
from msca.metrics import Metric # Replace with actual import path
7+
8+
9+
@pytest.fixture
def sample_data():
    """Four-row frame with observed/predicted columns and a two-level grouping key."""
    columns = {
        "obs": [1.0, 2.0, 3.0, 4.0],
        "pred": [1.1, 1.9, 3.2, 3.8],
        "pred_alt": [1.2, 2.1, 3.1, 3.9],
        "pred_ref": [1.1, 2.0, 3.1, 4.0],
        "weights": [1.0, 1.0, 1.0, 1.0],
        "region": ["A", "A", "B", "B"],
    }
    return pd.DataFrame(columns)
21+
22+
23+
@pytest.mark.parametrize(
    "metric",
    [
        Metric.MEAN_ABSOLUTE_ERROR,
        Metric.MEAN_SQUARED_ERROR,
        Metric.MEAN_ABSOLUTE_PERCENTAGE_ERROR,
        Metric.MEDIAN_ABSOLUTE_ERROR,
        Metric.ROOT_MEAN_SQUARED_ERROR,
    ],
)
def test_eval_single_metric(metric, sample_data):
    """Ungrouped evaluation returns a single non-negative float for every metric."""
    value = metric.eval(sample_data, "obs", "pred", "weights")
    assert isinstance(value, float)
    assert value >= 0
37+
38+
39+
@pytest.mark.parametrize(
    "metric_enum",
    [
        Metric.MEAN_ABSOLUTE_ERROR,
        Metric.MEAN_SQUARED_ERROR,
    ],
)
def test_eval_grouped(metric_enum, sample_data):
    """Grouped evaluation yields a DataFrame with one row per region and a named metric column."""
    grouped = metric_enum.eval(
        sample_data, "obs", "pred", "weights", groupby=["region"]
    )
    expected_col = f"pred_{metric_enum.value}"
    assert isinstance(grouped, pd.DataFrame)
    assert "region" in grouped.columns
    assert expected_col in grouped.columns
    assert len(grouped) == sample_data["region"].nunique()
55+
56+
57+
def test_eval_skill_single(sample_data):
    """Ungrouped skill score is a float bounded above by 1."""
    skill = Metric.MEAN_ABSOLUTE_ERROR.eval_skill(
        sample_data, "obs", "pred_alt", "pred_ref", "weights"
    )
    assert isinstance(skill, float)
    assert skill <= 1  # skill score range
64+
65+
66+
def test_eval_skill_grouped(sample_data):
    """Grouped skill evaluation returns a frame keyed by region with a skill column."""
    mae = Metric.MEAN_ABSOLUTE_ERROR
    result = mae.eval_skill(
        sample_data,
        "obs",
        "pred_alt",
        "pred_ref",
        "weights",
        groupby=["region"],
    )
    expected_col = f"pred_alt_{mae.value}_skill"
    assert isinstance(result, pd.DataFrame)
    assert "region" in result.columns
    assert expected_col in result.columns
80+
81+
82+
def test_eval_skill_zero_division_grouped(sample_data):
    """A perfect reference (zero reference error) makes the grouped skill score undefined."""
    # Make obs == pred_ref so the reference MAE is exactly zero in every group.
    sample_data["pred_ref"] = sample_data["obs"]
    with pytest.raises(ZeroDivisionError):
        Metric.MEAN_ABSOLUTE_ERROR.eval_skill(
            sample_data,
            "obs",
            "pred_alt",
            "pred_ref",
            "weights",
            groupby=["region"],
        )
97+
98+
99+
def test_eval_skill_zero_division_single(sample_data):
    """A perfect reference (zero reference error) makes the ungrouped skill score undefined."""
    sample_data["pred_ref"] = sample_data["obs"]  # reference MAE becomes zero
    with pytest.raises(ZeroDivisionError):
        Metric.MEAN_ABSOLUTE_ERROR.eval_skill(
            sample_data, "obs", "pred_alt", "pred_ref", "weights"
        )
105+
106+
107+
def test_eval_single_unsupported_metric(sample_data):
    """Constructing a Metric from an unknown name raises ValueError.

    NOTE(review): the previous version also called ``_eval_single`` inside
    the ``raises`` block, but that line was unreachable — ``Metric("fake")``
    raises ``ValueError`` first (standard enum lookup behavior), so
    ``_eval_single`` was never exercised. The dead call is removed to keep
    the test honest about what it covers.
    """
    with pytest.raises(ValueError):
        Metric("fake")

0 commit comments

Comments
 (0)