|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +import pytest |
| 4 | +from sklearn import metrics |
| 5 | + |
| 6 | +from msca.metrics import Metric # Replace with actual import path |
| 7 | + |
| 8 | + |
@pytest.fixture
def sample_data():
    """Four-row frame: observations, three prediction columns, uniform
    weights, and a two-level ``region`` grouping column."""
    columns = {
        "obs": [1.0, 2.0, 3.0, 4.0],
        "pred": [1.1, 1.9, 3.2, 3.8],
        "pred_alt": [1.2, 2.1, 3.1, 3.9],
        "pred_ref": [1.1, 2.0, 3.1, 4.0],
        "weights": [1.0, 1.0, 1.0, 1.0],
        "region": ["A", "A", "B", "B"],
    }
    return pd.DataFrame(columns)
| 21 | + |
| 22 | + |
@pytest.mark.parametrize(
    "metric",
    [
        Metric.MEAN_ABSOLUTE_ERROR,
        Metric.MEAN_SQUARED_ERROR,
        Metric.MEAN_ABSOLUTE_PERCENTAGE_ERROR,
        Metric.MEDIAN_ABSOLUTE_ERROR,
        Metric.ROOT_MEAN_SQUARED_ERROR,
    ],
)
def test_eval_single_metric(metric, sample_data):
    """Ungrouped eval returns a non-negative float for every error metric."""
    value = metric.eval(sample_data, "obs", "pred", "weights")
    assert isinstance(value, float)
    assert value >= 0
| 37 | + |
| 38 | + |
@pytest.mark.parametrize(
    "metric_enum",
    [Metric.MEAN_ABSOLUTE_ERROR, Metric.MEAN_SQUARED_ERROR],
)
def test_eval_grouped(metric_enum, sample_data):
    """Grouped eval yields a frame with one row per region and a metric
    column named after the prediction column and metric value."""
    grouped = metric_enum.eval(
        sample_data, "obs", "pred", "weights", groupby=["region"]
    )
    assert isinstance(grouped, pd.DataFrame)
    assert "region" in grouped.columns
    assert f"pred_{metric_enum.value}" in grouped.columns
    # one result row per distinct region
    assert len(grouped) == sample_data["region"].nunique()
| 55 | + |
| 56 | + |
def test_eval_skill_single(sample_data):
    """Ungrouped skill evaluation returns a float bounded above by 1."""
    score = Metric.MEAN_ABSOLUTE_ERROR.eval_skill(
        sample_data, "obs", "pred_alt", "pred_ref", "weights"
    )
    assert isinstance(score, float)
    # skill scores are bounded above by 1
    assert score <= 1
| 64 | + |
| 65 | + |
def test_eval_skill_grouped(sample_data):
    """Grouped skill evaluation returns a frame keyed by region with a
    ``<alt>_<metric>_skill`` column."""
    metric = Metric.MEAN_ABSOLUTE_ERROR
    result = metric.eval_skill(
        sample_data,
        "obs",
        "pred_alt",
        "pred_ref",
        "weights",
        groupby=["region"],
    )
    assert isinstance(result, pd.DataFrame)
    assert "region" in result.columns
    assert f"pred_alt_{metric.value}_skill" in result.columns
| 80 | + |
| 81 | + |
def test_eval_skill_zero_division_grouped(sample_data):
    """A perfect reference (zero error) makes the grouped skill undefined."""
    # Copy obs into pred_ref so the reference MAE is exactly zero.
    sample_data["pred_ref"] = sample_data["obs"]
    with pytest.raises(ZeroDivisionError):
        Metric.MEAN_ABSOLUTE_ERROR.eval_skill(
            sample_data,
            "obs",
            "pred_alt",
            "pred_ref",
            "weights",
            groupby=["region"],
        )
| 97 | + |
| 98 | + |
def test_eval_skill_zero_division_single(sample_data):
    """A perfect reference (zero error) makes the ungrouped skill undefined."""
    # Copy obs into pred_ref so the reference MAE is exactly zero.
    sample_data["pred_ref"] = sample_data["obs"]
    with pytest.raises(ZeroDivisionError):
        Metric.MEAN_ABSOLUTE_ERROR.eval_skill(
            sample_data, "obs", "pred_alt", "pred_ref", "weights"
        )
| 105 | + |
| 106 | + |
def test_eval_single_unsupported_metric(sample_data):
    """Constructing a Metric from an unknown value raises ValueError.

    The original test called ``fake._eval_single(...)`` after
    ``Metric("fake")`` inside the same ``pytest.raises`` block, but the
    enum lookup itself raises ValueError first, so that call was
    unreachable dead code and ``fake`` was never bound. Only the
    construction failure is observable behavior; the dead call is removed.
    """
    # sample_data kept in the signature for interface compatibility,
    # even though the construction fails before any evaluation.
    with pytest.raises(ValueError):
        Metric("fake")
0 commit comments