Skip to content

Commit b1e387d

Browse files
authored
Parallelize fits and predictions (#231)
- Parallelize the fits and predictions over forecast dates - Make an `output/fits/fit_DATE.pkl` for each forecast date (including all models) - Make an `output/pred/forecast_start=DATE/part-0.parquet` for each forecast date - Create a completion flag so that the parquets can be accessed Hive-style - Silence the JAX warnings about GPUs, since we're using CPUs - Make diagnostics optional, since they are only workable with a small number of geographies
1 parent 62ecc8d commit b1e387d

9 files changed

Lines changed: 146 additions & 142 deletions

File tree

Makefile

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,57 @@
1-
RUN_ID = test
1+
RUN_ID = full3
22

33
RAW_DATA = data/raw.parquet
44
CONFIG = scripts/config.yaml
55

66
OUTPUT_DIR = output/$(RUN_ID)
77
CONFIG_COPY = $(OUTPUT_DIR)/config.yaml
8-
DATA = $(OUTPUT_DIR)/nis.parquet
9-
FITS = $(OUTPUT_DIR)/fits.pkl
10-
PREDS = $(OUTPUT_DIR)/predictions.parquet
11-
DIAGNOSTICS = $(OUTPUT_DIR)/diagnostics/status.txt
8+
DATA = $(OUTPUT_DIR)/data.parquet
9+
PRED_DIR = $(OUTPUT_DIR)/pred
10+
PREDS_FLAG = $(PRED_DIR)/.checkpoint
1211
SCORES = $(OUTPUT_DIR)/scores.parquet
1312

1413
PLOT_DATA = $(OUTPUT_DIR)/plots/data_one_season_by_state.png
1514
PLOT_PREDS = $(OUTPUT_DIR)/plots/forecast_example.png
1615

16+
# Simply-expanded (:=) so the Python helper runs once at parse time instead of
# on every expansion of FORECAST_STARTS (it is referenced by PREDS and FITS).
FORECAST_STARTS := $(shell python scripts/get_forecast_starts.py --config=$(CONFIG))
17+
PREDS = $(foreach date,$(FORECAST_STARTS),$(PRED_DIR)/forecast_start=$(date)/part-0.parquet)
18+
FITS = $(foreach date,$(FORECAST_STARTS),$(OUTPUT_DIR)/fits/fit_$(date).pkl)
1719

18-
.PHONY: clean viz
20+
# This variable exists because the pattern `forecast_start=2020-01-01` confuses make.
21+
# It thinks `=%` is variable assignment, not pattern matching.
22+
# So we need `forecast_start$(EQ)%`.
23+
EQ = =
1924

20-
all: $(CONFIG_COPY) $(DATA) $(FITS) $(DIAGNOSTICS) $(PREDS) $(SCORES) $(PLOT_DATA) $(PLOT_PREDS)
25+
.PHONY: clean viz dx
26+
27+
all: $(CONFIG_COPY) $(DATA) $(FITS) $(PREDS) $(SCORES) $(PLOT_DATA) $(PLOT_PREDS)
2128

2229
viz:
2330
streamlit run scripts/viz.py -- \
24-
--data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --config=$(CONFIG)
25-
26-
$(SCORES): scripts/eval.py $(PREDS) $(DATA)
27-
python $< --preds=$(PREDS) --data=$(DATA) --config=$(CONFIG) --output=$@
31+
--data=$(DATA) --preds=$(PRED_DIR) --scores=$(SCORES) --config=$(CONFIG)
2832

29-
$(PLOT_PREDS): scripts/plot_preds.py $(CONFIG) $(DATA) $(PREDS) $(SCORES)
30-
python $< --config=$(CONFIG) --data=$(DATA) --preds=$(PREDS) --scores=$(SCORES) --output=$@
33+
# Optional diagnostics (not part of `all`): reads the per-date fit pickles and
# writes the configured tables and plots into $(OUTPUT_DIR)/diagnostics.
# The first prerequisite must be the script path because the recipe runs `$<`,
# and the flag must be --fits_dir to match the argument diagnostics.py declares.
dx: scripts/diagnostics.py $(FITS) $(CONFIG)
	python $< --fits_dir=$(OUTPUT_DIR)/fits --output_dir=$(OUTPUT_DIR)/diagnostics --config=$(CONFIG)
3135

32-
$(PREDS): scripts/predict.py $(DATA) $(FITS) $(CONFIG)
33-
python $< --data=$(DATA) --fits=$(FITS) --config=$(CONFIG) --output=$@
36+
# Completion flag: touched only once every per-date prediction parquet exists,
# so downstream steps can read $(PRED_DIR) as a single partitioned dataset.
# Without this rule make has no way to build $(PREDS_FLAG).
$(PREDS_FLAG): $(PREDS)
	touch $@

# Score all predictions; depends on the flag rather than on individual parquets.
$(SCORES): scripts/eval.py $(PREDS_FLAG) $(DATA) $(CONFIG)
	python $< --preds=$(PRED_DIR) --data=$(DATA) --config=$(CONFIG) --output=$@
3438

35-
$(DIAGNOSTICS): scripts/diagnostics.py $(FITS) $(CONFIG)
36-
python $< --fits=$(FITS) --config=$(CONFIG) --output=$@
37-
38-
$(FITS): scripts/fit.py $(DATA) $(CONFIG)
39-
python $< --data=$(DATA) --config=$(CONFIG) --output=$@
39+
$(PLOT_PREDS): scripts/plot_preds.py $(CONFIG) $(DATA) $(PREDS_FLAG) $(SCORES)
40+
python $< --config=$(CONFIG) --data=$(DATA) --preds=$(PRED_DIR) --scores=$(SCORES) --output=$@
4041

4142
# Plot the preprocessed data. Defined exactly once (a second definition makes
# make warn about an overridden recipe), and $(CONFIG) is a prerequisite
# because the recipe passes it to the script.
$(PLOT_DATA): scripts/plot_data.py $(DATA) $(CONFIG)
	python $< --config=$(CONFIG) --data=$(DATA) --output=$@

47+
48+
# output/run_id/pred/forecast_start=2021-01-01/part-0.parquet <== output/fits/fit_2021-01-01.pkl
49+
$(PRED_DIR)/forecast_start$(EQ)%/part-0.parquet: scripts/predict.py $(OUTPUT_DIR)/fits/fit_%.pkl $(DATA) $(CONFIG)
50+
python $< --data=$(DATA) --fits=$(OUTPUT_DIR)/fits/fit_$*.pkl --config=$(CONFIG) --output=$@
51+
52+
$(OUTPUT_DIR)/fits/fit_%.pkl: scripts/fit.py $(DATA) $(CONFIG)
53+
python $< --data=$(DATA) --forecast_start=$* --config=$(CONFIG) --output=$@
54+
4455
$(DATA): scripts/preprocess.py $(RAW_DATA) $(CONFIG)
4556
python $< --config=$(CONFIG) --input=$(RAW_DATA) --output=$@
4657

iup/models.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import os
2+
3+
# silence Jax CPU warning
4+
os.environ["JAX_PLATFORMS"] = "cpu"
5+
16
import abc
27
from typing import List
38

@@ -271,12 +276,7 @@ def fit(
271276
N_tot = data["N_tot"].to_numpy()
272277

273278
self.kernel = NUTS(self.model, init_strategy=init_to_sample)
274-
self.mcmc = MCMC(
275-
self.kernel,
276-
num_warmup=mcmc["num_warmup"],
277-
num_samples=mcmc["num_samples"],
278-
num_chains=mcmc["num_chains"],
279-
)
279+
self.mcmc = MCMC(self.kernel, **mcmc)
280280

281281
self.mcmc.run(
282282
self.fit_key,
@@ -300,7 +300,8 @@ def fit(
300300
d_rate=params["d_rate"],
301301
)
302302

303-
self.mcmc.print_summary()
303+
if "progress_bar" in mcmc and mcmc["progress_bar"]:
304+
self.mcmc.print_summary()
304305

305306
return self
306307

scripts/config_template.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ mcmc:
3131
num_warmup: 1000
3232
num_samples: 1000
3333
num_chains: 4
34+
progress_bar: false
3435

3536
forecasts:
3637
start_date:
@@ -49,7 +50,7 @@ forecast_plots:
4950
n_trajectories: 20
5051

5152
diagnostics:
52-
forecast_date:
53-
model: [LPLModel]
54-
plot: [posterior_density_plot, parameter_trace_plot]
55-
table: [print_posterior_dist, print_model_summary]
53+
forecast_starts: [2021-07-01]
54+
models: [LPLModel]
55+
plots: [posterior_density_plot, parameter_trace_plot]
56+
tables: [print_posterior_dist, print_model_summary]

scripts/diagnostics.py

Lines changed: 44 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import datetime as dt
33
import pickle
44
from pathlib import Path
5-
from typing import Any, Dict, Tuple
65

76
import yaml
87

@@ -12,102 +11,64 @@
1211

1312

1413
def diagnostic_plot(
15-
models: Dict[Tuple[str, dt.date], iup.models.UptakeModel],
16-
config: Dict[str, Any],
17-
output_dir,
14+
plot_name: str, fit: iup.models.UptakeModel, output_path: str | Path
1815
):
19-
"""select the fitted model using model name and training end date
20-
and generate selected diagnostic plots"""
21-
22-
sel_model_dict = select_model_to_diagnose(models, config)
23-
24-
diagnose_plot_names = config["diagnostics"]["plot"]
25-
26-
for key, model in sel_model_dict.items():
27-
for plot_name in diagnose_plot_names:
28-
plot_func = getattr(iup.diagnostics, plot_name)
29-
axes = plot_func(model)
30-
fig = axes.ravel()[0].figure
31-
fig.savefig(
32-
Path(
33-
output_dir,
34-
f"model={key[0]}_forecast_start={str(key[1])}_{plot_name}.png",
35-
)
36-
)
16+
plot_func = getattr(iup.diagnostics, plot_name)
17+
axes = plot_func(fit)
18+
fig = axes.ravel()[0].figure
19+
fig.savefig(output_path)
3720

3821

3922
def diagnostic_table(
40-
models: Dict[Tuple[str, dt.date], iup.models.UptakeModel],
41-
config: Dict[str, Any],
42-
output_dir,
23+
table_name: str, fit: iup.models.UptakeModel, output_path: str | Path
4324
):
44-
"""select the fitted model using model name and training end date
45-
and generate selected diagnostics: summary/posterior as parquet"""
46-
47-
sel_model_dict = select_model_to_diagnose(models, config)
48-
49-
diagnose_table_names = config["diagnostics"]["table"]
50-
51-
for key, model in sel_model_dict.items():
52-
for table_name in diagnose_table_names:
53-
table_func = getattr(iup.diagnostics, table_name)
54-
output = table_func(model)
55-
56-
output.write_parquet(
57-
Path(
58-
output_dir,
59-
f"model={key[0]}_forecast_start={str(key[1])}_{table_name}.parquet",
60-
)
61-
)
62-
63-
64-
def select_model_to_diagnose(
65-
models: Dict[Tuple[str, dt.date], iup.models.UptakeModel], config
66-
) -> dict:
67-
"""Select the model to diagnose based on the model name and the training end date"""
68-
69-
forecast_dates = config["diagnostics"]["forecast_date"]
70-
71-
if forecast_dates is None:
72-
sel_keys = [
73-
(model, date)
74-
for model, date in models.keys()
75-
if model in config["diagnostics"]["model"]
76-
]
77-
else:
78-
assert isinstance(forecast_dates, list)
79-
assert all(isinstance(x, dt.date) for x in forecast_dates)
80-
sel_keys = [
81-
(model, date)
82-
for model, date in models.keys()
83-
if model in config["diagnostics"]["model"] and date in forecast_dates
84-
]
85-
86-
return {key: models[key] for key in sel_keys}
25+
table_func = getattr(iup.diagnostics, table_name)
26+
output = table_func(fit)
27+
output.write_csv(output_path)
8728

8829

8930
if __name__ == "__main__":
9031
p = argparse.ArgumentParser()
91-
p.add_argument("--config", help="config file")
92-
p.add_argument("--fits", help="fits pickle")
93-
p.add_argument(
94-
"--output", help="output status file; other files put in the same directory"
95-
)
32+
p.add_argument("--config", help="config file", required=True)
33+
p.add_argument("--fits_dir", help="directory with fit pickles", required=True)
34+
p.add_argument("--output_dir", required=True)
9635
args = p.parse_args()
9736

37+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
38+
9839
with open(args.config, "r") as f:
9940
config = yaml.safe_load(f)
10041

101-
with open(args.fits, "rb") as f:
102-
models = pickle.load(f)
42+
for key in ["forecast_starts", "models", "tables", "plots"]:
43+
assert isinstance(config["diagnostics"][key], list), (
44+
f"config['diagnostics']['{key}'] should be a list"
45+
)
10346

104-
output_dir = Path(args.output).parent
105-
output_dir.mkdir(parents=True, exist_ok=True)
47+
for forecast_start in config["diagnostics"]["forecast_starts"]:
48+
# YAML already parses unquoted ISO dates (e.g. 2021-07-01) into datetime.date,
# and date.fromisoformat() rejects non-str input, so only parse actual strings.
fc_date = (
    forecast_start
    if isinstance(forecast_start, dt.date)
    else dt.date.fromisoformat(forecast_start)
)
10649

107-
# write the other plots to the same folder
108-
diagnostic_plot(models, config, output_dir)
109-
diagnostic_table(models, config, output_dir)
50+
for model in config["diagnostics"]["models"]:
51+
with open(Path(args.fits_dir) / f"fit_{fc_date}.pkl", "rb") as f:
52+
fits = pickle.load(f)
11053

111-
# write the status file
112-
with open(args.output, "w") as f:
113-
f.write(dt.datetime.now().isoformat())
54+
fit = fits[(model, fc_date)]
55+
56+
for table in config["diagnostics"]["tables"]:
57+
diagnostic_table(
58+
table_name=table,
59+
fit=fit,
60+
output_path=Path(
61+
args.output_dir,
62+
f"model={model}_forecast_start={fc_date}_{table}.csv",
63+
),
64+
)
65+
66+
for plot in config["diagnostics"]["plots"]:
67+
diagnostic_plot(
68+
plot_name=plot,
69+
fit=fit,
70+
output_path=Path(
71+
args.output_dir,
72+
f"model={model}_forecast_start={fc_date}_{plot}.png",
73+
),
74+
)

scripts/eval.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ def eval_all_forecasts(
2424
A pl.DataFrame with score name and score values, grouped by model, forecast start, quantile, and possibly other grouping factors
2525
"""
2626
forecast_starts = pred["forecast_start"].unique()
27+
28+
assert "score_funs" in config, (
29+
f"`score_funs` not among config keys: {config.keys()}"
30+
)
2731
score_funs = [getattr(iup.eval, fun_name) for fun_name in config["score_funs"]]
2832

2933
assert config["groups"] is not None

scripts/fit.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
import os
2+
3+
# silence Jax CPU warning
4+
os.environ["JAX_PLATFORMS"] = "cpu"
5+
16
import argparse
27
import datetime as dt
38
import pickle as pkl
9+
from pathlib import Path
410
from typing import Any, Dict, List, Type
511

612
import numpyro
@@ -11,20 +17,16 @@
1117
import iup.models
1218

1319

14-
def fit_all_models(data, config) -> Dict[str, iup.models.UptakeModel]:
15-
"""Run all forecasts
20+
def fit_all_models(
21+
data, forecast_start: dt.date, config
22+
) -> Dict[str, iup.models.UptakeModel]:
23+
"""
24+
Fit every configured model for a single forecast start date
1625
1726
Returns:
1827
dict: fitted models, keyed by (model name, forecast start date)
1928
"""
2029

21-
forecast_dates = pl.date_range(
22-
config["forecasts"]["start_date"]["start"],
23-
config["forecasts"]["start_date"]["end"],
24-
config["forecasts"]["start_date"]["interval"],
25-
eager=True,
26-
)
27-
2830
all_models = {}
2931

3032
for config_model in config["models"]:
@@ -41,19 +43,18 @@ def fit_all_models(data, config) -> Dict[str, iup.models.UptakeModel]:
4143
config["season"]["start_day"],
4244
)
4345

44-
for forecast_date in forecast_dates:
45-
fitted_model = fit_model(
46-
data=augmented_data,
47-
model_class=model_class,
48-
seed=config_model["seed"],
49-
params=config_model["params"],
50-
mcmc=config["mcmc"],
51-
grouping_factors=config["groups"],
52-
forecast_start=forecast_date,
53-
)
46+
fitted_model = fit_model(
47+
data=augmented_data,
48+
model_class=model_class,
49+
seed=config_model["seed"],
50+
params=config_model["params"],
51+
mcmc=config["mcmc"],
52+
grouping_factors=config["groups"],
53+
forecast_start=forecast_start,
54+
)
5455

55-
label = (model_name, forecast_date)
56-
all_models[label] = fitted_model
56+
label = (model_name, forecast_start)
57+
all_models[label] = fitted_model
5758

5859
return all_models
5960

@@ -85,17 +86,19 @@ def fit_model(
8586
p = argparse.ArgumentParser()
8687
p.add_argument("--config", help="config file", required=True)
8788
p.add_argument("--data", help="input data", required=True)
88-
p.add_argument("--output", help="output directory", required=True)
89+
p.add_argument("--forecast_start", required=True)
90+
p.add_argument("--output", help="output pickle path", required=True)
8991
args = p.parse_args()
9092

9193
with open(args.config, "r") as f:
9294
config = yaml.safe_load(f)
9395

94-
numpyro.set_host_device_count(config["mcmc"]["num_chains"])
95-
96-
input_data = iup.CumulativeUptakeData(pl.read_parquet(args.data))
96+
forecast_start = dt.date.fromisoformat(args.forecast_start)
97+
data = iup.CumulativeUptakeData(pl.read_parquet(args.data))
9798

98-
all_models = fit_all_models(input_data, config)
99+
numpyro.set_host_device_count(config["mcmc"]["num_chains"])
100+
all_models = fit_all_models(data=data, forecast_start=forecast_start, config=config)
99101

102+
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
100103
with open(args.output, "wb") as f:
101104
pkl.dump(all_models, f)

0 commit comments

Comments
 (0)