|
1 | | -import datetime as dt |
2 | | -from typing import Callable, Dict, List |
| 1 | +from typing import List |
3 | 2 |
|
4 | 3 | import polars as pl |
5 | 4 |
|
6 | | -from iup import CumulativeUptakeData, QuantileForecast |
7 | 5 |
|
8 | | - |
9 | | -###### evaluation metrics ##### |
10 | | -def check_date_match( |
11 | | - data: CumulativeUptakeData, pred: QuantileForecast, groups: List[str] | None |
12 | | -): |
13 | | - """ |
14 | | - Check the dates between data and pred. |
15 | | - Dates must be 1-on-1 equal and no duplicate. |
16 | | - ---------------------- |
17 | | -
|
18 | | - Parameters |
19 | | - data: |
20 | | - The observed data used for modeling. Should be CumulativeUptakeData |
21 | | - pred: |
22 | | - The forecast made by model. Can be QuantileForecast or PointForecast |
23 | | - groups: |
24 | | - A list of grouping factors |
25 | | -
|
26 | | - Return |
27 | | - Error if conditions fail to meet. |
28 | | -
|
29 | | - """ |
30 | | - # sort data and pred by date # |
31 | | - groups_and_time = ["time_end"] + groups if groups is not None else ["time_end"] |
32 | | - |
33 | | - data = CumulativeUptakeData(data.sort(groups_and_time)) |
34 | | - pred = QuantileForecast(pred.sort(groups_and_time)) |
35 | | - |
36 | | - if groups is not None: |
37 | | - # check if the forecast and the data have the same forecast dates for each level in each group # |
38 | | - for group in groups: |
39 | | - for group_value in data[group].unique().to_list(): |
40 | | - data_times = data.filter(pl.col(group) == group_value)[ |
41 | | - "time_end" |
42 | | - ].to_list() |
43 | | - pred_times = pred.filter(pl.col(group) == group_value)[ |
44 | | - "time_end" |
45 | | - ].to_list() |
46 | | - assert set(data_times) == set(pred_times), ( |
47 | | - "The forecast and the data should have the same forecast dates for each group. " |
48 | | - f"Instead, we have {data_times=} and {pred_times=}" |
49 | | - ) |
50 | | - else: |
51 | | - assert (data["time_end"] == pred["time_end"]).all(), ( |
52 | | - "The forecast and the data should have the same forecast dates. " |
53 | | - f"Instead, we have {data['time_end']=} and {pred['time_end']=}." |
54 | | - ) |
55 | | - |
56 | | - if groups is not None: |
57 | | - # check across all the combinations of groups # |
58 | | - check = data.with_columns(dup=pl.col("time_end").is_duplicated().over(groups)) |
59 | | - assert not (check["dup"].any()), "Duplicated dates are found in data" |
60 | | - else: |
61 | | - assert not (any(data["time_end"].is_duplicated())), ( |
62 | | - "Duplicated dates are found in data." |
63 | | - ) |
64 | | - |
65 | | - |
66 | | -def summarize_score( |
67 | | - data: CumulativeUptakeData, |
68 | | - pred: QuantileForecast, |
69 | | - groups: List[str] | None, |
70 | | - score_funs: Dict[str, Callable], |
| 6 | +def mspe( |
| 7 | + obs: pl.DataFrame, pred: pl.DataFrame, grouping_factors: List[str] |
71 | 8 | ) -> pl.DataFrame: |
72 | | - """ |
73 | | - Calculate score between observed data and forecast. |
74 | | - ---------------------- |
75 | | -
|
76 | | - Parameters |
77 | | - data: |
78 | | - The observed data used for modeling. Should be CumulativeUptakeData |
79 | | - pred: |
80 | | - The forecast made by model. Can be QuantileForecast or PointForecast |
81 | | - groups: |
82 | | - A list of grouping factors, specified in config file. |
83 | | - score_funs: |
84 | | - A dictionary of scoring functions. The key is the name of the score, and the value |
85 | | - is the scoring function. |
86 | | -
|
87 | | - Return |
88 | | - A pl.DataFrame of scores with information including score name and score values, grouped by quantile, forecast |
89 | | -
|
90 | | - """ |
91 | | - |
92 | | - check_date_match(data, pred, groups) |
93 | | - assert isinstance(data, CumulativeUptakeData) |
94 | | - assert isinstance(pred, QuantileForecast) |
95 | | - |
96 | | - assert len(pred["quantile"].unique()) == 1, ( |
97 | | - "The prediction should only have one quantile." |
98 | | - ) |
99 | | - |
100 | | - if groups is None: |
101 | | - columns_to_join = ["time_end"] |
102 | | - else: |
103 | | - columns_to_join = ["time_end"] + groups |
104 | | - |
105 | | - joined_df = ( |
106 | | - pl.DataFrame(data) |
107 | | - .join(pl.DataFrame(pred), on=columns_to_join, how="inner", validate="1:1") |
108 | | - .rename({"estimate": "data", "estimate_right": "pred"}) |
| 9 | + return ( |
| 10 | + pred.group_by(["model", "time_end", "forecast_start"] + grouping_factors) |
| 11 | + .agg(pred_median=pl.col("estimate").median()) |
| 12 | + .join(obs, on=["time_end"] + grouping_factors, how="right") |
| 13 | + .with_columns(score_value=(pl.col("estimate") - pl.col("pred_median")) ** 2) |
| 14 | + .group_by(["model", "forecast_start"] + grouping_factors) |
| 15 | + .agg(pl.col("score_value").mean()) |
| 16 | + .with_columns(score_fun=pl.lit("mspe")) |
109 | 17 | ) |
110 | 18 |
|
111 | | - all_scores = pl.DataFrame() |
112 | | - for score_name in score_funs: |
113 | | - score = joined_df.group_by(groups).agg( |
114 | | - score_name=pl.lit(score_name), |
115 | | - score_value=score_funs[score_name](pl.col("data"), pl.col("pred")), |
116 | | - ) |
117 | | - |
118 | | - if not score.is_empty(): |
119 | | - if isinstance(score["score_value"][0], pl.Series): |
120 | | - score = score.with_columns( |
121 | | - pl.col("score_value").list.drop_nulls().explode() |
122 | | - ) |
123 | | - else: |
124 | | - score = score.with_columns(score_value=None) |
125 | | - |
126 | | - score = score.with_columns( |
127 | | - quantile=joined_df["quantile"].first(), |
128 | | - forecast_start=joined_df["time_end"].min(), |
129 | | - forecast_end=joined_df["time_end"].max(), |
130 | | - ) |
131 | | - |
132 | | - all_scores = pl.concat([all_scores, score]) |
133 | | - |
134 | | - return all_scores |
135 | | - |
136 | | - |
137 | | -def mspe(x: pl.Expr, y: pl.Expr) -> pl.Expr: |
138 | | - """ |
139 | | - Calculate Mean Squared Prediction Error with polars column expression |
140 | | - --------------------- |
141 | | - Arguments: |
142 | | - x: either observed data or predictions |
143 | | - y: either observed data or predictions |
144 | | - Return: |
145 | | - Mean Squared Prediction Error as a polars column expression |
146 | | -
|
147 | | - """ |
148 | | - return ((x - y) ** 2).mean() |
149 | | - |
150 | 19 |
|
151 | | -def abs_diff( |
152 | | - selected_date: dt.date, date_col: pl.Expr |
153 | | -) -> Callable[[pl.Expr, pl.Expr], pl.Expr]: |
| 20 | +def eos_abs_diff( |
| 21 | + obs: pl.DataFrame, |
| 22 | + pred: pl.DataFrame, |
| 23 | + grouping_factors: List[str], |
| 24 | +) -> pl.DataFrame: |
154 | 25 | """ |
155 | 26 | Generate a function that calculates the absolute difference between |
156 | | - observed data and prediction on a certain date. |
157 | | - ---------------------- |
| 27 | + observed data and prediction for the last date in a season |
| 28 | +
|
158 | 29 | Arguments: |
159 | 30 | selected_date: a datetime date object to specify which date to do the calculation |
160 | 31 | date_col: a polars column expression used to select the date |
161 | 32 |
|
162 | 33 | Return: |
163 | 34 | A function that takes two polars column expressions to do the calculation. |
164 | 35 | """ |
| 36 | + assert "season" in grouping_factors |
165 | 37 |
|
166 | | - lit_date = pl.lit(selected_date) |
| 38 | + median_pred = pred.group_by( |
| 39 | + ["model", "time_end", "forecast_start"] + grouping_factors |
| 40 | + ).agg(pred_median=pl.col("estimate").median()) |
167 | 41 |
|
168 | | - def f(x: pl.Expr, y: pl.Expr) -> pl.Expr: |
169 | | - """ |
170 | | - Calculate the absolute difference between two polars column expressions |
171 | | - ----------------------- |
172 | | - Arguments: |
173 | | - x: either observed data or predictions |
174 | | - y: either observed data or predictions |
175 | | -
|
176 | | - Return: |
177 | | - A polars column expression that returns the absolute difference at the certain date, otherwise None |
178 | | - """ |
179 | | - return pl.when(date_col == lit_date).then((x - y).abs()).otherwise(None) |
180 | | - |
181 | | - return f |
| 42 | + return ( |
| 43 | + obs.filter( |
| 44 | + (pl.col("time_end") == pl.col("time_end").max()).over(grouping_factors) |
| 45 | + ) |
| 46 | + .join(median_pred, on=["time_end"] + grouping_factors, how="left") |
| 47 | + .with_columns( |
| 48 | + score_value=(pl.col("estimate") - pl.col("pred_median")).abs(), |
| 49 | + score_fun=pl.lit("eos_abs_diff"), |
| 50 | + ) |
| 51 | + ) |
0 commit comments