Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,29 @@ OUTPUT_DIR = output/$(RUN_ID)
CONFIG_COPY = $(OUTPUT_DIR)/config.yaml
DATA = $(OUTPUT_DIR)/nis.parquet
FITS = $(OUTPUT_DIR)/fits.pkl
FORECASTS = $(OUTPUT_DIR)/forecasts.parquet
PREDS = $(OUTPUT_DIR)/predictions.parquet
DIAGNOSTICS = $(OUTPUT_DIR)/diagnostics/status.txt
SCORES = $(OUTPUT_DIR)/scores.parquet

DATA_PLOT = $(OUTPUT_DIR)/plots/data_national.png
PLOT_DATA = $(OUTPUT_DIR)/plots/data_one_season_by_state.png
PLOT_PREDS = $(OUTPUT_DIR)/plots/forecast_example.png


.PHONY: clean viz

all: $(SETTINGS) $(DATA) $(FITS) $(DIAGNOSTICS) $(FORECASTS) $(SCORES) $(DATA_PLOT)
all: $(CONFIG_COPY) $(DATA) $(FITS) $(DIAGNOSTICS) $(PREDS) $(SCORES) $(PLOT_DATA) $(PLOT_PREDS)

viz:
streamlit run scripts/viz.py -- \
--data=$(DATA) --forecasts=$(FORECASTS) --scores=$(SCORES) --config=$(CONFIG)
--data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --config=$(CONFIG)

$(SCORES): scripts/eval.py $(FORECASTS) $(DATA)
python $< --forecasts=$(FORECASTS) --data=$(DATA) --config=$(CONFIG) --output=$@
$(SCORES): scripts/eval.py $(PREDS) $(DATA)
python $< --preds=$(PREDS) --data=$(DATA) --config=$(CONFIG) --output=$@

$(FORECASTS): scripts/forecast.py $(DATA) $(FITS) $(CONFIG)
$(PLOT_PREDS): scripts/plot_preds.py $(CONFIG) $(DATA) $(PREDS) $(SCORES)
python $< --config=$(CONFIG) --data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --output=$@

$(PREDS): scripts/predict.py $(DATA) $(FITS) $(CONFIG)
python $< --data=$(DATA) --fits=$(FITS) --config=$(CONFIG) --output=$@

$(DIAGNOSTICS): scripts/diagnostics.py $(FITS) $(CONFIG)
Expand All @@ -34,15 +38,15 @@ $(DIAGNOSTICS): scripts/diagnostics.py $(FITS) $(CONFIG)
$(FITS): scripts/fit.py $(DATA) $(CONFIG)
python $< --data=$(DATA) --config=$(CONFIG) --output=$@

$(DATA_PLOT): scripts/describe_data.py $(DATA)
python $< --input=$(DATA) --output_dir=$(OUTPUT_DIR)/plots
$(PLOT_DATA): scripts/plot_data.py $(DATA)
python $< --config=$(CONFIG) --data=$(DATA) --output=$@

$(DATA): scripts/preprocess.py $(RAW_DATA) $(CONFIG)
python $< --config=$(CONFIG) --input=$(RAW_DATA) --output=$@

$(CONFIG_COPY): $(CONFIG)
mkdir -p $(OUTPUT_DIR)
cp $(CONFIG) $(CONFIG_COPY)
cp $(CONFIG) $@

clean:
rm -rf $(OUTPUT_DIR)
25 changes: 0 additions & 25 deletions iup/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime as dt
from typing import List

import polars as pl
Expand Down Expand Up @@ -46,30 +45,6 @@ def validate(self):
"""
self.assert_in_schema({"time_end": pl.Date, "estimate": pl.Float64})

@classmethod
def split_train_test(cls, uptake_data: "UptakeData", split_date: dt.date) -> tuple:
"""
Subset a training or test set from data.

Parameters
uptake_data: UptakeData
cumulative or incident uptake data
split_date: dt.date
date at which to split data

Returns
pl.DataFrames
training and test portions of the cumulative or uptake data

Details
Training data are before the start date; test data are on or after.
Infers what type of UptakeData to return from what type was given.
"""
train = uptake_data.sort("time_end").filter(pl.col("time_end") < split_date)
test = uptake_data.sort("time_end").filter(pl.col("time_end") >= split_date)

return type(uptake_data)(train), type(uptake_data)(test)


class IncidentUptakeData(UptakeData):
def validate(self):
Expand Down
1 change: 0 additions & 1 deletion iup/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def print_posterior_dist(model: iup.models.UptakeModel) -> pl.DataFrame:
def print_model_summary(model: iup.models.UptakeModel) -> pl.DataFrame:
idata = az.from_numpyro(model.mcmc)
summary_pd = az.summary(idata)
print(summary_pd.shape)
summary = pl.DataFrame(summary_pd)
summary = summary.with_columns(params=pl.Series(summary_pd.index)).select(
["params"] + [col for col in summary.columns if col != "params"]
Expand Down
194 changes: 32 additions & 162 deletions iup/eval.py
Original file line number Diff line number Diff line change
@@ -1,181 +1,51 @@
import datetime as dt
from typing import Callable, Dict, List
from typing import List

import polars as pl

from iup import CumulativeUptakeData, QuantileForecast


###### evaluation metrics #####
def check_date_match(
data: CumulativeUptakeData, pred: QuantileForecast, groups: List[str] | None
):
"""
Check the dates between data and pred.
Dates must be 1-on-1 equal and no duplicate.
----------------------

Parameters
data:
The observed data used for modeling. Should be CumulativeUptakeData
pred:
The forecast made by model. Can be QuantileForecast or PointForecast
groups:
A list of grouping factors

Return
Error if conditions fail to meet.

"""
# sort data and pred by date #
groups_and_time = ["time_end"] + groups if groups is not None else ["time_end"]

data = CumulativeUptakeData(data.sort(groups_and_time))
pred = QuantileForecast(pred.sort(groups_and_time))

if groups is not None:
# check if the forecast and the data have the same forecast dates for each level in each group #
for group in groups:
for group_value in data[group].unique().to_list():
data_times = data.filter(pl.col(group) == group_value)[
"time_end"
].to_list()
pred_times = pred.filter(pl.col(group) == group_value)[
"time_end"
].to_list()
assert set(data_times) == set(pred_times), (
"The forecast and the data should have the same forecast dates for each group. "
f"Instead, we have {data_times=} and {pred_times=}"
)
else:
assert (data["time_end"] == pred["time_end"]).all(), (
"The forecast and the data should have the same forecast dates. "
f"Instead, we have {data['time_end']=} and {pred['time_end']=}."
)

if groups is not None:
# check across all the combinations of groups #
check = data.with_columns(dup=pl.col("time_end").is_duplicated().over(groups))
assert not (check["dup"].any()), "Duplicated dates are found in data"
else:
assert not (any(data["time_end"].is_duplicated())), (
"Duplicated dates are found in data."
)


def summarize_score(
data: CumulativeUptakeData,
pred: QuantileForecast,
groups: List[str] | None,
score_funs: Dict[str, Callable],
def mspe(
obs: pl.DataFrame, pred: pl.DataFrame, grouping_factors: List[str]
) -> pl.DataFrame:
"""
Calculate score between observed data and forecast.
----------------------

Parameters
data:
The observed data used for modeling. Should be CumulativeUptakeData
pred:
The forecast made by model. Can be QuantileForecast or PointForecast
groups:
A list of grouping factors, specified in config file.
score_funs:
A dictionary of scoring functions. The key is the name of the score, and the value
is the scoring function.

Return
A pl.DataFrame of scores with information including score name and score values, grouped by quantile, forecast

"""

check_date_match(data, pred, groups)
assert isinstance(data, CumulativeUptakeData)
assert isinstance(pred, QuantileForecast)

assert len(pred["quantile"].unique()) == 1, (
"The prediction should only have one quantile."
)

if groups is None:
columns_to_join = ["time_end"]
else:
columns_to_join = ["time_end"] + groups

joined_df = (
pl.DataFrame(data)
.join(pl.DataFrame(pred), on=columns_to_join, how="inner", validate="1:1")
.rename({"estimate": "data", "estimate_right": "pred"})
return (
pred.group_by(["model", "time_end", "forecast_start"] + grouping_factors)
.agg(pred_median=pl.col("estimate").median())
.join(obs, on=["time_end"] + grouping_factors, how="right")
.with_columns(score_value=(pl.col("estimate") - pl.col("pred_median")) ** 2)
.group_by(["model", "forecast_start"] + grouping_factors)
.agg(pl.col("score_value").mean())
.with_columns(score_fun=pl.lit("mspe"))
)

all_scores = pl.DataFrame()
for score_name in score_funs:
score = joined_df.group_by(groups).agg(
score_name=pl.lit(score_name),
score_value=score_funs[score_name](pl.col("data"), pl.col("pred")),
)

if not score.is_empty():
if isinstance(score["score_value"][0], pl.Series):
score = score.with_columns(
pl.col("score_value").list.drop_nulls().explode()
)
else:
score = score.with_columns(score_value=None)

score = score.with_columns(
quantile=joined_df["quantile"].first(),
forecast_start=joined_df["time_end"].min(),
forecast_end=joined_df["time_end"].max(),
)

all_scores = pl.concat([all_scores, score])

return all_scores


def mspe(x: pl.Expr, y: pl.Expr) -> pl.Expr:
"""
Calculate Mean Squared Prediction Error with polars column expression
---------------------
Arguments:
x: either observed data or predictions
y: either observed data or predictions
Return:
Mean Squared Prediction Error as a polars column expression

"""
return ((x - y) ** 2).mean()


def abs_diff(
selected_date: dt.date, date_col: pl.Expr
) -> Callable[[pl.Expr, pl.Expr], pl.Expr]:
def eos_abs_diff(
obs: pl.DataFrame,
pred: pl.DataFrame,
grouping_factors: List[str],
) -> pl.DataFrame:
"""
Generate a function that calculates the absolute difference between
observed data and prediction on a certain date.
----------------------
observed data and prediction for the last date in a season

Arguments:
selected_date: a datetime date object to specify which date to do the calculation
date_col: a polars column expression used to select the date

Return:
A function that takes two polars column expressions to do the calculation.
"""
assert "season" in grouping_factors

lit_date = pl.lit(selected_date)
median_pred = pred.group_by(
["model", "time_end", "forecast_start"] + grouping_factors
).agg(pred_median=pl.col("estimate").median())

def f(x: pl.Expr, y: pl.Expr) -> pl.Expr:
"""
Calculate the absolute difference between two polars column expressions
-----------------------
Arguments:
x: either observed data or predictions
y: either observed data or predictions

Return:
A polars column expression that returns the absolute difference at the certain date, otherwise None
"""
return pl.when(date_col == lit_date).then((x - y).abs()).otherwise(None)

return f
return (
obs.filter(
(pl.col("time_end") == pl.col("time_end").max()).over(grouping_factors)
)
.join(median_pred, on=["time_end"] + grouping_factors, how="left")
.with_columns(
score_value=(pl.col("estimate") - pl.col("pred_median")).abs(),
score_fun=pl.lit("eos_abs_diff"),
)
)
Loading