Skip to content

Commit 62ecc8d

Browse files
authored
Refactor forecasts and draft plots (#230)
- When predicting, do not distinguish between "forecast" and "postcheck". Instead, make predictions for every date and group for which there are observations. This makes the "scaffold" concept much simpler to implement.
- When doing analyses, distinguish between prospective "forecast" scores and retrospective "fit" scores.
1 parent 9c4b8ee commit 62ecc8d

19 files changed

Lines changed: 389 additions & 699 deletions

Makefile

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,29 @@ OUTPUT_DIR = output/$(RUN_ID)
77
CONFIG_COPY = $(OUTPUT_DIR)/config.yaml
88
DATA = $(OUTPUT_DIR)/nis.parquet
99
FITS = $(OUTPUT_DIR)/fits.pkl
10-
FORECASTS = $(OUTPUT_DIR)/forecasts.parquet
10+
PREDS = $(OUTPUT_DIR)/predictions.parquet
1111
DIAGNOSTICS = $(OUTPUT_DIR)/diagnostics/status.txt
1212
SCORES = $(OUTPUT_DIR)/scores.parquet
1313

14-
DATA_PLOT = $(OUTPUT_DIR)/plots/data_national.png
14+
PLOT_DATA = $(OUTPUT_DIR)/plots/data_one_season_by_state.png
15+
PLOT_PREDS = $(OUTPUT_DIR)/plots/forecast_example.png
1516

1617

1718
.PHONY: clean viz
1819

19-
all: $(SETTINGS) $(DATA) $(FITS) $(DIAGNOSTICS) $(FORECASTS) $(SCORES) $(DATA_PLOT)
20+
all: $(CONFIG_COPY) $(DATA) $(FITS) $(DIAGNOSTICS) $(PREDS) $(SCORES) $(PLOT_DATA) $(PLOT_PREDS)
2021

2122
viz:
2223
streamlit run scripts/viz.py -- \
23-
--data=$(DATA) --forecasts=$(FORECASTS) --scores=$(SCORES) --config=$(CONFIG)
24+
--data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --config=$(CONFIG)
2425

25-
$(SCORES): scripts/eval.py $(FORECASTS) $(DATA)
26-
python $< --forecasts=$(FORECASTS) --data=$(DATA) --config=$(CONFIG) --output=$@
26+
$(SCORES): scripts/eval.py $(PREDS) $(DATA)
27+
python $< --preds=$(PREDS) --data=$(DATA) --config=$(CONFIG) --output=$@
2728

28-
$(FORECASTS): scripts/forecast.py $(DATA) $(FITS) $(CONFIG)
29+
$(PLOT_PREDS): scripts/plot_preds.py $(CONFIG) $(DATA) $(PREDS) $(SCORES)
30+
python $< --config=$(CONFIG) --data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --output=$@
31+
32+
$(PREDS): scripts/predict.py $(DATA) $(FITS) $(CONFIG)
2933
python $< --data=$(DATA) --fits=$(FITS) --config=$(CONFIG) --output=$@
3034

3135
$(DIAGNOSTICS): scripts/diagnostics.py $(FITS) $(CONFIG)
@@ -34,15 +38,15 @@ $(DIAGNOSTICS): scripts/diagnostics.py $(FITS) $(CONFIG)
3438
$(FITS): scripts/fit.py $(DATA) $(CONFIG)
3539
python $< --data=$(DATA) --config=$(CONFIG) --output=$@
3640

37-
$(DATA_PLOT): scripts/describe_data.py $(DATA)
38-
python $< --input=$(DATA) --output_dir=$(OUTPUT_DIR)/plots
41+
$(PLOT_DATA): scripts/plot_data.py $(DATA)
42+
python $< --config=$(CONFIG) --data=$(DATA) --output=$@
3943

4044
$(DATA): scripts/preprocess.py $(RAW_DATA) $(CONFIG)
4145
python $< --config=$(CONFIG) --input=$(RAW_DATA) --output=$@
4246

4347
$(CONFIG_COPY): $(CONFIG)
4448
mkdir -p $(OUTPUT_DIR)
45-
cp $(CONFIG) $(CONFIG_COPY)
49+
cp $(CONFIG) $@
4650

4751
clean:
4852
rm -rf $(OUTPUT_DIR)

iup/__init__.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import datetime as dt
21
from typing import List
32

43
import polars as pl
@@ -46,30 +45,6 @@ def validate(self):
4645
"""
4746
self.assert_in_schema({"time_end": pl.Date, "estimate": pl.Float64})
4847

49-
@classmethod
50-
def split_train_test(cls, uptake_data: "UptakeData", split_date: dt.date) -> tuple:
51-
"""
52-
Subset a training or test set from data.
53-
54-
Parameters
55-
uptake_data: UptakeData
56-
cumulative or incident uptake data
57-
split_date: dt.date
58-
date at which to split data
59-
60-
Returns
61-
pl.DataFrames
62-
training and test portions of the cumulative or uptake data
63-
64-
Details
65-
Training data are before the start date; test data are on or after.
66-
Infers what type of UptakeData to return from what type was given.
67-
"""
68-
train = uptake_data.sort("time_end").filter(pl.col("time_end") < split_date)
69-
test = uptake_data.sort("time_end").filter(pl.col("time_end") >= split_date)
70-
71-
return type(uptake_data)(train), type(uptake_data)(test)
72-
7348

7449
class IncidentUptakeData(UptakeData):
7550
def validate(self):

iup/diagnostics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def print_posterior_dist(model: iup.models.UptakeModel) -> pl.DataFrame:
6565
def print_model_summary(model: iup.models.UptakeModel) -> pl.DataFrame:
6666
idata = az.from_numpyro(model.mcmc)
6767
summary_pd = az.summary(idata)
68-
print(summary_pd.shape)
6968
summary = pl.DataFrame(summary_pd)
7069
summary = summary.with_columns(params=pl.Series(summary_pd.index)).select(
7170
["params"] + [col for col in summary.columns if col != "params"]

iup/eval.py

Lines changed: 32 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -1,181 +1,51 @@
1-
import datetime as dt
2-
from typing import Callable, Dict, List
1+
from typing import List
32

43
import polars as pl
54

6-
from iup import CumulativeUptakeData, QuantileForecast
75

8-
9-
###### evaluation metrics #####
10-
def check_date_match(
11-
data: CumulativeUptakeData, pred: QuantileForecast, groups: List[str] | None
12-
):
13-
"""
14-
Check the dates between data and pred.
15-
Dates must be 1-on-1 equal and no duplicate.
16-
----------------------
17-
18-
Parameters
19-
data:
20-
The observed data used for modeling. Should be CumulativeUptakeData
21-
pred:
22-
The forecast made by model. Can be QuantileForecast or PointForecast
23-
groups:
24-
A list of grouping factors
25-
26-
Return
27-
Error if conditions fail to meet.
28-
29-
"""
30-
# sort data and pred by date #
31-
groups_and_time = ["time_end"] + groups if groups is not None else ["time_end"]
32-
33-
data = CumulativeUptakeData(data.sort(groups_and_time))
34-
pred = QuantileForecast(pred.sort(groups_and_time))
35-
36-
if groups is not None:
37-
# check if the forecast and the data have the same forecast dates for each level in each group #
38-
for group in groups:
39-
for group_value in data[group].unique().to_list():
40-
data_times = data.filter(pl.col(group) == group_value)[
41-
"time_end"
42-
].to_list()
43-
pred_times = pred.filter(pl.col(group) == group_value)[
44-
"time_end"
45-
].to_list()
46-
assert set(data_times) == set(pred_times), (
47-
"The forecast and the data should have the same forecast dates for each group. "
48-
f"Instead, we have {data_times=} and {pred_times=}"
49-
)
50-
else:
51-
assert (data["time_end"] == pred["time_end"]).all(), (
52-
"The forecast and the data should have the same forecast dates. "
53-
f"Instead, we have {data['time_end']=} and {pred['time_end']=}."
54-
)
55-
56-
if groups is not None:
57-
# check across all the combinations of groups #
58-
check = data.with_columns(dup=pl.col("time_end").is_duplicated().over(groups))
59-
assert not (check["dup"].any()), "Duplicated dates are found in data"
60-
else:
61-
assert not (any(data["time_end"].is_duplicated())), (
62-
"Duplicated dates are found in data."
63-
)
64-
65-
66-
def summarize_score(
67-
data: CumulativeUptakeData,
68-
pred: QuantileForecast,
69-
groups: List[str] | None,
70-
score_funs: Dict[str, Callable],
6+
def mspe(
7+
obs: pl.DataFrame, pred: pl.DataFrame, grouping_factors: List[str]
718
) -> pl.DataFrame:
72-
"""
73-
Calculate score between observed data and forecast.
74-
----------------------
75-
76-
Parameters
77-
data:
78-
The observed data used for modeling. Should be CumulativeUptakeData
79-
pred:
80-
The forecast made by model. Can be QuantileForecast or PointForecast
81-
groups:
82-
A list of grouping factors, specified in config file.
83-
score_funs:
84-
A dictionary of scoring functions. The key is the name of the score, and the value
85-
is the scoring function.
86-
87-
Return
88-
A pl.DataFrame of scores with information including score name and score values, grouped by quantile, forecast
89-
90-
"""
91-
92-
check_date_match(data, pred, groups)
93-
assert isinstance(data, CumulativeUptakeData)
94-
assert isinstance(pred, QuantileForecast)
95-
96-
assert len(pred["quantile"].unique()) == 1, (
97-
"The prediction should only have one quantile."
98-
)
99-
100-
if groups is None:
101-
columns_to_join = ["time_end"]
102-
else:
103-
columns_to_join = ["time_end"] + groups
104-
105-
joined_df = (
106-
pl.DataFrame(data)
107-
.join(pl.DataFrame(pred), on=columns_to_join, how="inner", validate="1:1")
108-
.rename({"estimate": "data", "estimate_right": "pred"})
9+
return (
10+
pred.group_by(["model", "time_end", "forecast_start"] + grouping_factors)
11+
.agg(pred_median=pl.col("estimate").median())
12+
.join(obs, on=["time_end"] + grouping_factors, how="right")
13+
.with_columns(score_value=(pl.col("estimate") - pl.col("pred_median")) ** 2)
14+
.group_by(["model", "forecast_start"] + grouping_factors)
15+
.agg(pl.col("score_value").mean())
16+
.with_columns(score_fun=pl.lit("mspe"))
10917
)
11018

111-
all_scores = pl.DataFrame()
112-
for score_name in score_funs:
113-
score = joined_df.group_by(groups).agg(
114-
score_name=pl.lit(score_name),
115-
score_value=score_funs[score_name](pl.col("data"), pl.col("pred")),
116-
)
117-
118-
if not score.is_empty():
119-
if isinstance(score["score_value"][0], pl.Series):
120-
score = score.with_columns(
121-
pl.col("score_value").list.drop_nulls().explode()
122-
)
123-
else:
124-
score = score.with_columns(score_value=None)
125-
126-
score = score.with_columns(
127-
quantile=joined_df["quantile"].first(),
128-
forecast_start=joined_df["time_end"].min(),
129-
forecast_end=joined_df["time_end"].max(),
130-
)
131-
132-
all_scores = pl.concat([all_scores, score])
133-
134-
return all_scores
135-
136-
137-
def mspe(x: pl.Expr, y: pl.Expr) -> pl.Expr:
138-
"""
139-
Calculate Mean Squared Prediction Error with polars column expression
140-
---------------------
141-
Arguments:
142-
x: either observed data or predictions
143-
y: either observed data or predictions
144-
Return:
145-
Mean Squared Prediction Error as a polars column expression
146-
147-
"""
148-
return ((x - y) ** 2).mean()
149-
15019

151-
def abs_diff(
152-
selected_date: dt.date, date_col: pl.Expr
153-
) -> Callable[[pl.Expr, pl.Expr], pl.Expr]:
20+
def eos_abs_diff(
21+
obs: pl.DataFrame,
22+
pred: pl.DataFrame,
23+
grouping_factors: List[str],
24+
) -> pl.DataFrame:
15425
"""
15526
Calculate the absolute difference between
156-
observed data and prediction on a certain date.
157-
----------------------
27+
observed data and prediction for the last date in a season
28+
15829
Arguments:
15930
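obs: observed data; must contain `time_end`, `estimate`, and every column named in `grouping_factors`
16031
pred: predictions; must contain `model`, `time_end`, `forecast_start`, `estimate`, and every column named in `grouping_factors`
grouping_factors: names of the columns that define a group; must include "season"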
16132
16233
Return:
16334
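A pl.DataFrame holding the observations on the last date of each group, left-joined to the median prediction per model and forecast start, with `score_value` (absolute difference between the observed estimate and the median prediction) and `score_fun` ("eos_abs_diff") columns added.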
16435
"""
36+
assert "season" in grouping_factors
16537

166-
lit_date = pl.lit(selected_date)
38+
median_pred = pred.group_by(
39+
["model", "time_end", "forecast_start"] + grouping_factors
40+
).agg(pred_median=pl.col("estimate").median())
16741

168-
def f(x: pl.Expr, y: pl.Expr) -> pl.Expr:
169-
"""
170-
Calculate the absolute difference between two polars column expressions
171-
-----------------------
172-
Arguments:
173-
x: either observed data or predictions
174-
y: either observed data or predictions
175-
176-
Return:
177-
A polars column expression that returns the absolute difference at the certain date, otherwise None
178-
"""
179-
return pl.when(date_col == lit_date).then((x - y).abs()).otherwise(None)
180-
181-
return f
42+
return (
43+
obs.filter(
44+
(pl.col("time_end") == pl.col("time_end").max()).over(grouping_factors)
45+
)
46+
.join(median_pred, on=["time_end"] + grouping_factors, how="left")
47+
.with_columns(
48+
score_value=(pl.col("estimate") - pl.col("pred_median")).abs(),
49+
score_fun=pl.lit("eos_abs_diff"),
50+
)
51+
)

0 commit comments

Comments
 (0)