-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy patheval.py
More file actions
110 lines (87 loc) · 3.82 KB
/
Copy patheval.py
File metadata and controls
110 lines (87 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import polars as pl
import yaml
from iup import CumulativeUptakeData, QuantileForecast, SampleForecast, eval
def _build_score_funcs(score_config: dict) -> dict:
    """Map score names to the scoring callables selected by the config's "scores" section.

    Keys "abs_diff_<date>" (one per entry of score_config["difference_by_date"]) hold
    `eval.abs_diff(date, pl.col("time_end"))`; each name in score_config["others"] maps
    to the attribute of the same name on the `eval` module.
    """
    score_funcs = {}
    if score_config["difference_by_date"] is not None:
        for date in score_config["difference_by_date"]:
            score_funcs[f"{eval.abs_diff.__name__}_{date}"] = eval.abs_diff(
                date, pl.col("time_end")
            )
    if score_config["others"] is not None:
        for score_fun_name in score_config["others"]:
            # NOTE: `eval` is the iup.eval module (it shadows the builtin `eval`).
            score_funcs[score_fun_name] = getattr(eval, score_fun_name)
    return score_funcs


def eval_all_forecasts(
    data: pl.DataFrame, pred: pl.DataFrame, config: dict
) -> pl.DataFrame:
    """
    Calculates the evaluation metrics selected by config, by model and forecast start date.
    -----------------------
    Arguments:
        data:
            observed data with at least "time_end" and "estimate" columns
        pred:
            forecast data as sample distribution with at least "time_end", "sample_id",
            "model", "forecast_start" and "estimate" columns
        config:
            parsed config; reads config["groups"], config["forecasts"]["end_date"] and
            config["scores"]["quantiles" / "difference_by_date" / "others"]
    Returns:
        A pl.DataFrame with score name and score values, grouped by model, forecast start,
        quantile, and possibly other grouping factors. An empty pl.DataFrame is returned
        when `pred` has no rows.
    Raises:
        ValueError: if config["scores"]["quantiles"] is missing or None.
    """
    # Validate once up front instead of on every inner iteration; a plain `assert`
    # would be stripped under `python -O`.
    quantiles = config["scores"].get("quantiles")
    if quantiles is None:
        raise ValueError(
            "Quantiles of posterior prediction distribution must be specified in the config file."
        )

    # group sample forecasts by date and grouping factors #
    if config["groups"] is not None:
        groups = [*config["groups"], "time_end"]
    else:
        groups = ["time_end"]

    # These depend only on the config, so build them once rather than per quantile.
    score_funcs = _build_score_funcs(config["scores"])
    end_date = config["forecasts"]["end_date"]

    # Only evaluate (model, forecast_start) pairs that actually occur in `pred`;
    # the full cross product would include pairs with no forecast rows at all.
    combos = pred.select("model", "forecast_start").unique(maintain_order=True)

    score_frames = []
    for model, forecast_start in combos.iter_rows():
        pred_by_start = SampleForecast(
            CumulativeUptakeData(
                pred.filter(
                    pl.col("model") == model,
                    pl.col("forecast_start") == forecast_start,
                )
            )
        )
        # Observed data inside the forecast window [forecast_start, end_date].
        test = CumulativeUptakeData(
            data.filter(
                pl.col("time_end") >= forecast_start,
                pl.col("time_end") <= end_date,
            )
        )
        # 1. Convert sample forecast pred into quantiles, then score each quantile.
        for quantile in quantiles:
            summary_pred = QuantileForecast(
                pred_by_start.group_by(groups)
                .agg(pl.col("estimate").quantile(quantile))
                .with_columns(quantile=quantile)
            )
            # NOTE(review): passes the raw config groups (without "time_end"), as the
            # original did — presumably summarize_score handles the date itself; confirm.
            scores = eval.summarize_score(
                test, summary_pred, config["groups"], score_funcs
            )
            score_frames.append(scores.with_columns(model=pl.lit(model)))

    # Concatenate once at the end; concatenating inside the loop re-copies every
    # accumulated row on each iteration.
    if not score_frames:
        return pl.DataFrame()
    return pl.concat(score_frames)
if __name__ == "__main__":
    # Command-line entry point: score the forecasts in --forecasts against the
    # observed data in --data using the metrics selected in the YAML --config,
    # and write the resulting score table to --output as parquet.
    p = argparse.ArgumentParser()
    p.add_argument("--config", help="config file", required=True)
    p.add_argument("--forecasts", help="forecasts parquet", required=True)
    p.add_argument("--data", help="observed data", required=True)
    p.add_argument("--output", help="output scores parquet", required=True)
    args = p.parse_args()

    # Read the config as UTF-8 explicitly rather than relying on the platform's
    # locale-dependent default encoding; safe_load avoids arbitrary object construction.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    pred = pl.read_parquet(args.forecasts)
    data = pl.read_parquet(args.data)

    eval_all_forecasts(data, pred, config).write_parquet(args.output)