Skip to content

Commit 337554e

Browse files
committed
single script for score plotting
1 parent 9d06618 commit 337554e

1 file changed

Lines changed: 145 additions & 0 deletions

File tree

scripts/plot_scores.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
import altair as alt
5+
import polars as pl
6+
import yaml
7+
from plot_data import (
8+
MEDIAN_ENCODINGS,
9+
TICK_KWARGS,
10+
add_medians,
11+
gather_n,
12+
month_order,
13+
)
14+
15+
# Opacity passed to the black `mark_line` layers of the end-of-season-diff
# charts (the per-geography line chart and the per-state facet chart).
LINE_OPACITY = 0.4
16+
17+
# scores across seasons & states
18+
19+
if __name__ == "__main__":
    # ---- command line & inputs -------------------------------------------
    # Three mandatory options: the scores parquet file, the directory the
    # SVGs are written into, and a YAML config file.
    parser = argparse.ArgumentParser()
    for flag in ("--scores", "--output_dir", "--config"):
        parser.add_argument(flag, required=True)
    cli = parser.parse_args()

    with open(cli.config) as config_file:
        config = yaml.safe_load(config_file)

    scores = pl.read_parquet(cli.scores)

    out_dir = Path(cli.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- reusable polars expressions -------------------------------------
    # Expressions are lazy: each is evaluated against whichever frame it is
    # handed to, so `.max()` / `.min()` refer to that frame's own column.
    score_fun = pl.col("score_fun")
    season = pl.col("season")
    forecast_date = pl.col("forecast_date")

    is_mspe = score_fun == pl.lit("mspe")
    is_eos_diff = score_fun == pl.lit("eos_abs_diff")
    in_latest_season = season == season.max()
    at_latest_forecast = forecast_date == forecast_date.max()
    at_earliest_forecast = forecast_date == forecast_date.min()

    # ---- fit scores across seasons & states ------------------------------
    # MSPE rows from the most recent forecast date, put on a log scale.
    latest_mspe = scores.filter(at_latest_forecast, is_mspe)
    fit_scores = latest_mspe.with_columns(pl.col("score_value").log())

    # `month_order` comes from plot_data; presumably it yields the month
    # labels ordered from the configured season start -- TODO confirm.
    sort_month = month_order(config["season"]["start_month"])

    # Encodings shared by several of the charts below.
    enc_x_month = alt.X("month", title=None, sort=sort_month)
    enc_y_mspe = alt.Y(
        "score_value",
        title="Score (Log(MSPE))",
        scale=alt.Scale(zero=False),
    )
    enc_y_eos = alt.Y("score_value", title="Score (abs. end-of-season diff.)")

    # One point chart per grouping; `add_medians` / `MEDIAN_ENCODINGS` are
    # project helpers from plot_data.
    season_frame = add_medians(
        fit_scores, group_by="season", value_col="score_value"
    )
    season_chart = alt.Chart(season_frame).mark_point().encode(
        alt.X("season", title=None),
        enc_y_mspe,
        *MEDIAN_ENCODINGS,
    )
    season_chart.save(out_dir / "score_by_season.svg")

    geo_frame = add_medians(
        fit_scores, group_by="geography", value_col="score_value"
    )
    # Geographies are ordered on the x axis by descending median score.
    enc_x_geo = alt.X(
        "geography",
        title=None,
        sort=alt.EncodingSortField("score_value", "median", "descending"),
    )
    geo_chart = alt.Chart(geo_frame).mark_point().encode(
        enc_x_geo,
        enc_y_mspe,
        *MEDIAN_ENCODINGS,
    )
    geo_chart.save(out_dir / "score_by_geo.svg")

    # ---- do scores change through the season? ----------------------------
    # sis = "score in season": end-of-season abs-diff rows for the latest
    # season, tagged with the abbreviated month name of the forecast date.
    month_label = pl.col("forecast_date").dt.to_string("%b")
    sis_data = scores.filter(is_eos_diff, in_latest_season).with_columns(
        month=month_label
    )

    # One translucent line per geography.
    sis_line = (
        alt.Chart(sis_data)
        .mark_line(color="black", opacity=LINE_OPACITY)
        .encode(enc_x_month, enc_y_eos, alt.Detail("geography"))
    )

    # Rows at the last forecast date, sorted by score, then passed through
    # the project helper `gather_n` with n=5 (presumably thins the rows to a
    # handful worth labelling -- TODO confirm against plot_data).
    latest_sorted = sis_data.filter(at_latest_forecast).sort("score_value")
    label_rows = gather_n(latest_sorted, 5)

    sis_tick_base = alt.Chart(label_rows).encode(
        enc_x_month,
        alt.Y("score_value"),
        alt.Text("geography"),
    )
    sis_tick = sis_tick_base.mark_point(**TICK_KWARGS)
    sis_text = sis_tick_base.mark_text(align="left", dx=15)

    increasing_chart = sis_line + sis_tick + sis_text
    increasing_chart.save(out_dir / "scores_increasing.svg")

    # ---- summary of end-of-season abs diff -------------------------------
    # Box plot per month with whiskers spanning the full min-max range.
    summary_chart = (
        alt.Chart(sis_data)
        .mark_boxplot(extent="min-max")
        .encode(enc_x_month, enc_y_eos)
    )
    summary_chart.save(out_dir / "eos_abs_diff_summary.svg")

    # ---- end-of-season abs diff by state ---------------------------------
    # Facet order: geographies sorted by their score in the "Jul" rows.
    # NOTE(review): "Jul" is hard-coded; presumably the last month of the
    # season -- confirm it stays in sync with the configured start month.
    july_sorted = sis_data.filter(pl.col("month") == "Jul").sort(
        pl.col("score_value")
    )
    state_sort = july_sorted.select("geography").to_numpy().ravel().tolist()

    by_state_chart = (
        alt.Chart(sis_data)
        .mark_line(color="black", opacity=LINE_OPACITY)
        .encode(
            enc_x_month,
            enc_y_eos,
            alt.Facet("geography", columns=9, sort=state_sort),
        )
    )
    by_state_chart.save(out_dir / "eos_abs_diff_by_state.svg")

    # ---- fit score vs. forecast score ------------------------------------
    join_keys = ["model", "geography"]

    # Median log-MSPE per (model, geography).
    median_fit = fit_scores.group_by(join_keys).agg(
        pl.col("score_value").median()
    )
    avg_fit = median_fit.rename({"score_value": "fit_score"})

    # End-of-season abs diff for the latest season at the earliest forecast
    # date; all three predicates are evaluated against the full `scores`.
    earliest_eos = scores.filter(
        is_eos_diff,
        in_latest_season,
        at_earliest_forecast,
    )
    fc_goodness = earliest_eos.select(
        ["geography", "model", "score_value"]
    ).rename({"score_value": "fc_score"})

    compare_frame = avg_fit.join(fc_goodness, on=join_keys, how="inner")
    compare_chart = (
        alt.Chart(compare_frame)
        .mark_point(color="black")
        .encode(
            alt.X(
                "fit_score",
                title="Fit score (median MSPE over seasons)",
                scale=alt.Scale(zero=False),
            ),
            alt.Y(
                "fc_score",
                title="Forecast score (abs. end-of-season diff.)",
            ),
        )
    )
    compare_chart.save(out_dir / "forecast_fit_compare.svg")

0 commit comments

Comments
 (0)