Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,46 +1,57 @@
RUN_ID = test
RUN_ID = full3

RAW_DATA = data/raw.parquet
CONFIG = scripts/config.yaml

OUTPUT_DIR = output/$(RUN_ID)
CONFIG_COPY = $(OUTPUT_DIR)/config.yaml
DATA = $(OUTPUT_DIR)/nis.parquet
FITS = $(OUTPUT_DIR)/fits.pkl
PREDS = $(OUTPUT_DIR)/predictions.parquet
DIAGNOSTICS = $(OUTPUT_DIR)/diagnostics/status.txt
DATA = $(OUTPUT_DIR)/data.parquet
PRED_DIR = $(OUTPUT_DIR)/pred
PREDS_FLAG = $(PRED_DIR)/.checkpoint
SCORES = $(OUTPUT_DIR)/scores.parquet

PLOT_DATA = $(OUTPUT_DIR)/plots/data_one_season_by_state.png
PLOT_PREDS = $(OUTPUT_DIR)/plots/forecast_example.png

FORECAST_STARTS = $(shell python scripts/get_forecast_starts.py --config=$(CONFIG))
PREDS = $(foreach date,$(FORECAST_STARTS),$(PRED_DIR)/forecast_start=$(date)/part-0.parquet)
FITS = $(foreach date,$(FORECAST_STARTS),$(OUTPUT_DIR)/fits/fit_$(date).pkl)

.PHONY: clean viz
# This variable because the pattern `forecast_date=2020-01-01` confuses make.
# It thinks `=%` is variable assignment, not pattern matching.
# So we need `forecast_date$(EQ)%`.
EQ = =

all: $(CONFIG_COPY) $(DATA) $(FITS) $(DIAGNOSTICS) $(PREDS) $(SCORES) $(PLOT_DATA) $(PLOT_PREDS)
.PHONY: clean viz dx

all: $(CONFIG_COPY) $(DATA) $(FITS) $(PREDS) $(SCORES) $(PLOT_DATA) $(PLOT_PREDS)

viz:
streamlit run scripts/viz.py -- \
--data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --config=$(CONFIG)

$(SCORES): scripts/eval.py $(PREDS) $(DATA)
python $< --preds=$(PREDS) --data=$(DATA) --config=$(CONFIG) --output=$@
--data=$(DATA) --preds=$(PRED_DIR) --scores=$(SCORES) --config=$(CONFIG)

$(PLOT_PREDS): scripts/plot_preds.py $(CONFIG) $(DATA) $(PREDS) $(SCORES)
python $< --config=$(CONFIG) --data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --output=$@
dx: scripts.diagnostics $(FITS) $(CONFIG)
python $< --fit_dir=$(OUTPUT_DIR)/fits --output_dir=$(OUTPUT_DIR)/diagnostics --config=$(CONFIG)

$(PREDS): scripts/predict.py $(DATA) $(FITS) $(CONFIG)
python $< --data=$(DATA) --fits=$(FITS) --config=$(CONFIG) --output=$@
$(SCORES): scripts/eval.py $(PREDS_FLAG) $(DATA) $(CONFIG)
python $< --preds=$(PRED_DIR) --data=$(DATA) --config=$(CONFIG) --output=$@

$(DIAGNOSTICS): scripts/diagnostics.py $(FITS) $(CONFIG)
python $< --fits=$(FITS) --config=$(CONFIG) --output=$@

$(FITS): scripts/fit.py $(DATA) $(CONFIG)
python $< --data=$(DATA) --config=$(CONFIG) --output=$@
$(PLOT_PREDS): scripts/plot_preds.py $(CONFIG) $(DATA) $(PREDS_FLAG) $(SCORES)
python $< --config=$(CONFIG) --data=$(DATA) --preds=$(PRED_DIR) --scores=$(SCORES) --output=$@

$(PLOT_DATA): scripts/plot_data.py $(DATA)
python $< --config=$(CONFIG) --data=$(DATA) --output=$@

$(PLOT_DATA): scripts/plot_data.py $(DATA) $(CONFIG)
python $< --config=$(CONFIG) --data=$(DATA) --output=$@

# output/run_id/pred/forecast_start=2021-01-01/part-0.parquet <== output/fits/fit_2021-01-01.pkl
$(PRED_DIR)/forecast_start$(EQ)%/part-0.parquet: scripts/predict.py $(OUTPUT_DIR)/fits/fit_%.pkl $(DATA) $(CONFIG)
python $< --data=$(DATA) --fits=$(OUTPUT_DIR)/fits/fit_$*.pkl --config=$(CONFIG) --output=$@

$(OUTPUT_DIR)/fits/fit_%.pkl: scripts/fit.py $(DATA) $(CONFIG)
python $< --data=$(DATA) --forecast_start=$* --config=$(CONFIG) --output=$@

$(DATA): scripts/preprocess.py $(RAW_DATA) $(CONFIG)
python $< --config=$(CONFIG) --input=$(RAW_DATA) --output=$@

Expand Down
15 changes: 8 additions & 7 deletions iup/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import os

# silence Jax CPU warning
os.environ["JAX_PLATFORMS"] = "cpu"

import abc
from typing import List

Expand Down Expand Up @@ -271,12 +276,7 @@ def fit(
N_tot = data["N_tot"].to_numpy()

self.kernel = NUTS(self.model, init_strategy=init_to_sample)
self.mcmc = MCMC(
self.kernel,
num_warmup=mcmc["num_warmup"],
num_samples=mcmc["num_samples"],
num_chains=mcmc["num_chains"],
)
self.mcmc = MCMC(self.kernel, **mcmc)

self.mcmc.run(
self.fit_key,
Expand All @@ -300,7 +300,8 @@ def fit(
d_rate=params["d_rate"],
)

self.mcmc.print_summary()
if "progress_bar" in mcmc and mcmc["progress_bar"]:
self.mcmc.print_summary()

return self

Expand Down
9 changes: 5 additions & 4 deletions scripts/config_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ mcmc:
num_warmup: 1000
num_samples: 1000
num_chains: 4
progress_bar: false

forecasts:
start_date:
Expand All @@ -49,7 +50,7 @@ forecast_plots:
n_trajectories: 20

diagnostics:
forecast_date:
model: [LPLModel]
plot: [posterior_density_plot, parameter_trace_plot]
table: [print_posterior_dist, print_model_summary]
forecast_starts: [2021-07-01]
models: [LPLModel]
plots: [posterior_density_plot, parameter_trace_plot]
tables: [print_posterior_dist, print_model_summary]
127 changes: 44 additions & 83 deletions scripts/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import datetime as dt
import pickle
from pathlib import Path
from typing import Any, Dict, Tuple

import yaml

Expand All @@ -12,102 +11,64 @@


def diagnostic_plot(
models: Dict[Tuple[str, dt.date], iup.models.UptakeModel],
config: Dict[str, Any],
output_dir,
plot_name: str, fit: iup.models.UptakeModel, output_path: str | Path
):
"""select the fitted model using model name and training end date
and generate selected diagnostic plots"""

sel_model_dict = select_model_to_diagnose(models, config)

diagnose_plot_names = config["diagnostics"]["plot"]

for key, model in sel_model_dict.items():
for plot_name in diagnose_plot_names:
plot_func = getattr(iup.diagnostics, plot_name)
axes = plot_func(model)
fig = axes.ravel()[0].figure
fig.savefig(
Path(
output_dir,
f"model={key[0]}_forecast_start={str(key[1])}_{plot_name}.png",
)
)
plot_func = getattr(iup.diagnostics, plot_name)
axes = plot_func(fit)
fig = axes.ravel()[0].figure
fig.savefig(output_path)


def diagnostic_table(
models: Dict[Tuple[str, dt.date], iup.models.UptakeModel],
config: Dict[str, Any],
output_dir,
table_name: str, fit: iup.models.UptakeModel, output_path: str | Path
):
"""select the fitted model using model name and training end date
and generate selected diagnostics: summary/posterior as parquet"""

sel_model_dict = select_model_to_diagnose(models, config)

diagnose_table_names = config["diagnostics"]["table"]

for key, model in sel_model_dict.items():
for table_name in diagnose_table_names:
table_func = getattr(iup.diagnostics, table_name)
output = table_func(model)

output.write_parquet(
Path(
output_dir,
f"model={key[0]}_forecast_start={str(key[1])}_{table_name}.parquet",
)
)


def select_model_to_diagnose(
models: Dict[Tuple[str, dt.date], iup.models.UptakeModel], config
) -> dict:
"""Select the model to diagnose based on the model name and the training end date"""

forecast_dates = config["diagnostics"]["forecast_date"]

if forecast_dates is None:
sel_keys = [
(model, date)
for model, date in models.keys()
if model in config["diagnostics"]["model"]
]
else:
assert isinstance(forecast_dates, list)
assert all(isinstance(x, dt.date) for x in forecast_dates)
sel_keys = [
(model, date)
for model, date in models.keys()
if model in config["diagnostics"]["model"] and date in forecast_dates
]

return {key: models[key] for key in sel_keys}
table_func = getattr(iup.diagnostics, table_name)
output = table_func(fit)
output.write_csv(output_path)


if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("--config", help="config file")
p.add_argument("--fits", help="fits pickle")
p.add_argument(
"--output", help="output status file; other files put in the same directory"
)
p.add_argument("--config", help="config file", required=True)
p.add_argument("--fits_dir", help="directory with fit pickles", required=True)
p.add_argument("--output_dir", required=True)
args = p.parse_args()

Path(args.output_dir).mkdir(parents=True, exist_ok=True)

with open(args.config, "r") as f:
config = yaml.safe_load(f)

with open(args.fits, "rb") as f:
models = pickle.load(f)
for key in ["forecast_starts", "models", "tables", "plots"]:
assert isinstance(config["diagnostics"][key], list), (
f"config['diagnostics']['{key}'] should be a list"
)

output_dir = Path(args.output).parent
output_dir.mkdir(parents=True, exist_ok=True)
for forecast_start in config["diagnostics"]["forecast_starts"]:
fc_date = dt.date.fromisoformat(forecast_start)

# write the other plots to the same folder
diagnostic_plot(models, config, output_dir)
diagnostic_table(models, config, output_dir)
for model in config["diagnostics"]["models"]:
with open(Path(args.fits_dir) / f"fit_{fc_date}.pkl", "rb") as f:
fits = pickle.load(f)

# write the status file
with open(args.output, "w") as f:
f.write(dt.datetime.now().isoformat())
fit = fits[(model, fc_date)]

for table in config["diagnostics"]["tables"]:
diagnostic_table(
table_name=table,
fit=fit,
output_path=Path(
args.output_dir,
f"model={model}_forecast_start={fc_date}_{table}.csv",
),
)

for plot in config["diagnostics"]["plots"]:
diagnostic_plot(
plot_name=plot,
fit=fit,
output_path=Path(
args.output_dir,
f"model={model}_forecast_start={fc_date}_{plot}.png",
),
)
4 changes: 4 additions & 0 deletions scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def eval_all_forecasts(
A pl.DataFrame with score name and score values, grouped by model, forecast start, quantile, and possibly other grouping factors
"""
forecast_starts = pred["forecast_start"].unique()

assert "score_funs" in config, (
f"`score_funs` not among config keys: {config.keys()}"
)
score_funs = [getattr(iup.eval, fun_name) for fun_name in config["score_funs"]]

assert config["groups"] is not None
Expand Down
55 changes: 29 additions & 26 deletions scripts/fit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import os

# silence Jax CPU warning
os.environ["JAX_PLATFORMS"] = "cpu"

import argparse
import datetime as dt
import pickle as pkl
from pathlib import Path
from typing import Any, Dict, List, Type

import numpyro
Expand All @@ -11,20 +17,16 @@
import iup.models


def fit_all_models(data, config) -> Dict[str, iup.models.UptakeModel]:
"""Run all forecasts
def fit_all_models(
data, forecast_start: dt.date, config
) -> Dict[str, iup.models.UptakeModel]:
"""
Run all forecasts

Returns:
pl.DataFrame: data frame of forecasts, organized by model and forecast date
"""

forecast_dates = pl.date_range(
config["forecasts"]["start_date"]["start"],
config["forecasts"]["start_date"]["end"],
config["forecasts"]["start_date"]["interval"],
eager=True,
)

all_models = {}

for config_model in config["models"]:
Expand All @@ -41,19 +43,18 @@ def fit_all_models(data, config) -> Dict[str, iup.models.UptakeModel]:
config["season"]["start_day"],
)

for forecast_date in forecast_dates:
fitted_model = fit_model(
data=augmented_data,
model_class=model_class,
seed=config_model["seed"],
params=config_model["params"],
mcmc=config["mcmc"],
grouping_factors=config["groups"],
forecast_start=forecast_date,
)
fitted_model = fit_model(
data=augmented_data,
model_class=model_class,
seed=config_model["seed"],
params=config_model["params"],
mcmc=config["mcmc"],
grouping_factors=config["groups"],
forecast_start=forecast_start,
)

label = (model_name, forecast_date)
all_models[label] = fitted_model
label = (model_name, forecast_start)
all_models[label] = fitted_model

return all_models

Expand Down Expand Up @@ -85,17 +86,19 @@ def fit_model(
p = argparse.ArgumentParser()
p.add_argument("--config", help="config file", required=True)
p.add_argument("--data", help="input data", required=True)
p.add_argument("--output", help="output directory", required=True)
p.add_argument("--forecast_start", required=True)
p.add_argument("--output", help="output pickle path", required=True)
args = p.parse_args()

with open(args.config, "r") as f:
config = yaml.safe_load(f)

numpyro.set_host_device_count(config["mcmc"]["num_chains"])

input_data = iup.CumulativeUptakeData(pl.read_parquet(args.data))
forecast_start = dt.date.fromisoformat(args.forecast_start)
data = iup.CumulativeUptakeData(pl.read_parquet(args.data))

all_models = fit_all_models(input_data, config)
numpyro.set_host_device_count(config["mcmc"]["num_chains"])
all_models = fit_all_models(data=data, forecast_start=forecast_start, config=config)

Path(args.output).parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "wb") as f:
pkl.dump(all_models, f)
Loading