|
"""Render score-summary figures from a parquet file of model scores.

Command-line arguments:
    --scores      parquet file read with ``polars.read_parquet``
    --output_dir  directory the SVG figures are written to (created if needed)
    --config      YAML file; ``config["season"]["start_month"]`` is passed to
                  ``plot_data.month_order`` to order the month axis

Columns read from the scores table: ``forecast_date``, ``score_fun``,
``score_value``, ``season``, ``geography`` and ``model``.
"""

import argparse
from pathlib import Path

import altair as alt
import polars as pl
import yaml
from plot_data import (
    MEDIAN_ENCODINGS,
    TICK_KWARGS,
    add_medians,
    gather_n,
    month_order,
)

LINE_OPACITY = 0.4

# Y-axis title shared by every end-of-season absolute-difference figure.
EOS_Y_TITLE = "Score (abs. end-of-season diff.)"

# Month label whose scores order the per-state facets.
# NOTE(review): this is fixed rather than derived from
# config["season"]["start_month"]; confirm it matches the month intended
# for the configured season.
FACET_SORT_MONTH = "Jul"


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line; ``argv=None`` falls back to ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--scores", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--config", required=True)
    return parser.parse_args(argv)


def _select_fit_scores(scores: pl.DataFrame) -> pl.DataFrame:
    """Keep the "mspe" rows at the latest forecast date, log-transformed."""
    return scores.filter(
        pl.col("forecast_date") == pl.col("forecast_date").max(),
        pl.col("score_fun") == pl.lit("mspe"),
    ).with_columns(pl.col("score_value").log())


def _select_season_scores(scores: pl.DataFrame) -> pl.DataFrame:
    """Keep the "eos_abs_diff" rows of the latest season and add ``month``.

    ``month`` is the abbreviated month name ("%b") of ``forecast_date``.
    """
    return scores.filter(
        pl.col("score_fun") == pl.lit("eos_abs_diff"),
        pl.col("season") == pl.col("season").max(),
    ).with_columns(month=pl.col("forecast_date").dt.to_string("%b"))


def _plot_fit_by_season(
    fit_scores: pl.DataFrame, enc_y_mspe: alt.Y, out_dir: Path
) -> None:
    """Write ``score_by_season.svg``: fit score points per season."""
    data = add_medians(fit_scores, group_by="season", value_col="score_value")
    chart = (
        alt.Chart(data)
        .mark_point()
        .encode(
            alt.X("season", title=None),
            enc_y_mspe,
            *MEDIAN_ENCODINGS,
        )
    )
    chart.save(out_dir / "score_by_season.svg")


def _plot_fit_by_geography(
    fit_scores: pl.DataFrame, enc_y_mspe: alt.Y, out_dir: Path
) -> None:
    """Write ``score_by_geo.svg``: fit score points per geography.

    Geographies are ordered by descending median ``score_value``.
    """
    data = add_medians(
        fit_scores, group_by="geography", value_col="score_value"
    )
    enc_x_geo = alt.X(
        "geography",
        title=None,
        sort=alt.EncodingSortField("score_value", "median", "descending"),
    )
    chart = (
        alt.Chart(data)
        .mark_point()
        .encode(enc_x_geo, enc_y_mspe, *MEDIAN_ENCODINGS)
    )
    chart.save(out_dir / "score_by_geo.svg")


def _plot_scores_increasing(
    sis_data: pl.DataFrame, enc_x_month: alt.X, out_dir: Path
) -> None:
    """Write ``scores_increasing.svg``: one line per geography over months.

    Rows at the latest forecast date in ``sis_data`` are sorted by score,
    passed through ``gather_n(..., 5)`` and drawn as labelled points on top
    of the lines.
    """
    lines = (
        alt.Chart(sis_data)
        .mark_line(color="black", opacity=LINE_OPACITY)
        .encode(
            enc_x_month,
            alt.Y("score_value", title=EOS_Y_TITLE),
            alt.Detail("geography"),
        )
    )

    # The max() here is taken over sis_data, i.e. the latest-season rows.
    labelled = (
        sis_data.filter(
            pl.col("forecast_date") == pl.col("forecast_date").max()
        )
        .sort("score_value")
        .pipe(gather_n, 5)
    )
    tick_base = alt.Chart(labelled).encode(
        enc_x_month,
        alt.Y("score_value"),
        alt.Text("geography"),
    )
    ticks = tick_base.mark_point(**TICK_KWARGS)
    labels = tick_base.mark_text(align="left", dx=15)

    (lines + ticks + labels).save(out_dir / "scores_increasing.svg")


def _plot_eos_summary(
    sis_data: pl.DataFrame, enc_x_month: alt.X, out_dir: Path
) -> None:
    """Write ``eos_abs_diff_summary.svg``: min-max boxplot per month."""
    chart = (
        alt.Chart(sis_data)
        .mark_boxplot(extent="min-max")
        .encode(
            enc_x_month,
            alt.Y("score_value", title=EOS_Y_TITLE),
        )
    )
    chart.save(out_dir / "eos_abs_diff_summary.svg")


def _plot_eos_by_state(
    sis_data: pl.DataFrame,
    enc_x_month: alt.X,
    out_dir: Path,
    sort_month: str = FACET_SORT_MONTH,
) -> None:
    """Write ``eos_abs_diff_by_state.svg``: one facet per geography.

    Facets are ordered by ascending ``score_value`` among the rows whose
    ``month`` equals ``sort_month``.
    """
    state_sort = (
        sis_data.filter(pl.col("month") == sort_month)
        .sort(pl.col("score_value"))
        .get_column("geography")
        .to_list()
    )
    chart = (
        alt.Chart(sis_data)
        .mark_line(color="black", opacity=LINE_OPACITY)
        .encode(
            enc_x_month,
            alt.Y("score_value", title=EOS_Y_TITLE),
            alt.Facet("geography", columns=9, sort=state_sort),
        )
    )
    chart.save(out_dir / "eos_abs_diff_by_state.svg")


def _plot_forecast_vs_fit(
    scores: pl.DataFrame, fit_scores: pl.DataFrame, out_dir: Path
) -> None:
    """Write ``forecast_fit_compare.svg``: fit score vs. forecast score.

    The fit score is the median of ``fit_scores`` per (model, geography).
    The forecast score is the "eos_abs_diff" value of the latest season at
    the earliest forecast date, joined on the same keys.
    """
    avg_fit = (
        fit_scores.group_by(["model", "geography"])
        .agg(pl.col("score_value").median())
        .rename({"score_value": "fit_score"})
    )
    # Filter the full `scores` table rather than an already-filtered frame
    # so that max()/min() are evaluated on the same input as before.
    fc_goodness = (
        scores.filter(
            pl.col("score_fun") == pl.lit("eos_abs_diff"),
            pl.col("season") == pl.col("season").max(),
            pl.col("forecast_date") == pl.col("forecast_date").min(),
        )
        .select(["geography", "model", "score_value"])
        .rename({"score_value": "fc_score"})
    )

    joined = avg_fit.join(fc_goodness, on=["model", "geography"], how="inner")
    chart = (
        alt.Chart(joined)
        .mark_point(color="black")
        .encode(
            alt.X(
                "fit_score",
                title="Fit score (median MSPE over seasons)",
                scale=alt.Scale(zero=False),
            ),
            alt.Y(
                "fc_score",
                title="Forecast score (abs. end-of-season diff.)",
            ),
        )
    )
    chart.save(out_dir / "forecast_fit_compare.svg")


def main(argv=None) -> None:
    """Load the inputs and write every figure into the output directory."""
    args = _parse_args(argv)

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    scores = pl.read_parquet(args.scores)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sort_month = month_order(config["season"]["start_month"])
    enc_x_month = alt.X("month", title=None, sort=sort_month)
    enc_y_mspe = alt.Y(
        "score_value", title="Score (Log(MSPE))", scale=alt.Scale(zero=False)
    )

    # scores across seasons & states
    fit_scores = _select_fit_scores(scores)
    _plot_fit_by_season(fit_scores, enc_y_mspe, out_dir)
    _plot_fit_by_geography(fit_scores, enc_y_mspe, out_dir)

    # scores increasing through the season?
    # sis = score in season
    sis_data = _select_season_scores(scores)
    _plot_scores_increasing(sis_data, enc_x_month, out_dir)

    # summary of end-of-season abs diff
    _plot_eos_summary(sis_data, enc_x_month, out_dir)

    # end-of-season abs diff by state
    _plot_eos_by_state(sis_data, enc_x_month, out_dir)

    # score vs. forecast
    _plot_forecast_vs_fit(scores, fit_scores, out_dir)


if __name__ == "__main__":
    main()
0 commit comments