-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy patheval.py
More file actions
51 lines (42 loc) · 1.66 KB
/
Copy patheval.py
File metadata and controls
51 lines (42 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import List
import polars as pl
def mspe(
    obs: pl.DataFrame, pred: pl.DataFrame, grouping_factors: List[str]
) -> pl.DataFrame:
    """Score forecasts by the mean squared error of their median prediction.

    ``pred`` is collapsed to the median of its ``estimate`` column per
    (model, time_end, forecast_start, *grouping_factors). Those medians are
    right-joined onto ``obs`` by (time_end, *grouping_factors), so every
    observed row is retained. The squared difference between the observed
    ``estimate`` and ``pred_median`` is then averaged per
    (model, forecast_start, *grouping_factors).

    Args:
        obs: Observed data; must contain ``time_end``, ``estimate`` and every
            grouping-factor column.
        pred: Predictions; must contain ``model``, ``time_end``,
            ``forecast_start``, ``estimate`` and every grouping-factor column.
        grouping_factors: Extra key columns shared by ``obs`` and ``pred``.

    Returns:
        A DataFrame with columns ``model``, ``forecast_start``, the grouping
        factors, ``score_value`` (mean squared error) and ``score_fun``
        (always ``"mspe"``). Observations with no matching prediction yield
        rows whose ``model``/``forecast_start`` and ``score_value`` are null.
        NOTE(review): confirm those null-model rows are intended.
    """
    # Keys for: the per-forecast median, the obs match, and the final average.
    forecast_keys = ["model", "time_end", "forecast_start"] + grouping_factors
    match_keys = ["time_end"] + grouping_factors
    score_keys = ["model", "forecast_start"] + grouping_factors

    median_by_forecast = pred.group_by(forecast_keys).agg(
        pl.col("estimate").median().alias("pred_median")
    )
    # Right join: keep every observation, matched to a prediction or not.
    paired = median_by_forecast.join(obs, on=match_keys, how="right")
    squared_error = (pl.col("estimate") - pl.col("pred_median")) ** 2

    per_forecast = paired.with_columns(squared_error.alias("score_value"))
    scored = per_forecast.group_by(score_keys).agg(pl.col("score_value").mean())
    return scored.with_columns(pl.lit("mspe").alias("score_fun"))
def eos_abs_diff(
    obs: pl.DataFrame,
    pred: pl.DataFrame,
    grouping_factors: List[str],
) -> pl.DataFrame:
    """Absolute error of the median prediction at the last date of each season.

    For every combination of ``grouping_factors``, only the ``obs`` rows at
    the latest ``time_end`` are kept. Because ``"season"`` must be one of the
    grouping factors, this is the last observed date of each season.

    Those rows are left-joined to the median of ``pred``'s ``estimate`` per
    (model, time_end, forecast_start, *grouping_factors). The score is then
    ``|estimate - pred_median|``.

    Args:
        obs: Observed data; must contain ``time_end``, ``estimate`` and every
            grouping-factor column.
        pred: Predictions; must contain ``model``, ``time_end``,
            ``forecast_start``, ``estimate`` and every grouping-factor column.
        grouping_factors: Key columns shared by ``obs`` and ``pred``; must
            include ``"season"``.

    Returns:
        The filtered ``obs`` rows joined with ``model``, ``forecast_start`` and
        ``pred_median``, plus ``score_value`` (absolute difference) and
        ``score_fun`` (always ``"eos_abs_diff"``). Observations with no
        matching prediction are kept, with null ``model``, ``pred_median`` and
        ``score_value``.

    Raises:
        ValueError: If ``"season"`` is not in ``grouping_factors``.
    """
    # Explicit raise rather than `assert`: asserts are stripped under -O.
    if "season" not in grouping_factors:
        raise ValueError(
            f"grouping_factors must include 'season'; got {grouping_factors!r}"
        )
    median_pred = pred.group_by(
        ["model", "time_end", "forecast_start"] + grouping_factors
    ).agg(pred_median=pl.col("estimate").median())
    return (
        obs.filter(
            # Keep only the latest time_end within each group (window over
            # grouping_factors), i.e. the end-of-season observation.
            (pl.col("time_end") == pl.col("time_end").max()).over(grouping_factors)
        )
        # Left join keeps end-of-season observations even without predictions.
        .join(median_pred, on=["time_end"] + grouping_factors, how="left")
        .with_columns(
            score_value=(pl.col("estimate") - pl.col("pred_median")).abs(),
            score_fun=pl.lit("eos_abs_diff"),
        )
    )