|
2 | 2 | import datetime as dt |
3 | 3 | import pickle |
4 | 4 | from pathlib import Path |
5 | | -from typing import Any, Dict, Tuple |
6 | 5 |
|
7 | 6 | import yaml |
8 | 7 |
|
|
12 | 11 |
|
13 | 12 |
|
14 | 13 | def diagnostic_plot( |
15 | | - models: Dict[Tuple[str, dt.date], iup.models.UptakeModel], |
16 | | - config: Dict[str, Any], |
17 | | - output_dir, |
| 14 | + plot_name: str, fit: iup.models.UptakeModel, output_path: str | Path |
18 | 15 | ): |
19 | | - """select the fitted model using model name and training end date |
20 | | - and generate selected diagnostic plots""" |
21 | | - |
22 | | - sel_model_dict = select_model_to_diagnose(models, config) |
23 | | - |
24 | | - diagnose_plot_names = config["diagnostics"]["plot"] |
25 | | - |
26 | | - for key, model in sel_model_dict.items(): |
27 | | - for plot_name in diagnose_plot_names: |
28 | | - plot_func = getattr(iup.diagnostics, plot_name) |
29 | | - axes = plot_func(model) |
30 | | - fig = axes.ravel()[0].figure |
31 | | - fig.savefig( |
32 | | - Path( |
33 | | - output_dir, |
34 | | - f"model={key[0]}_forecast_start={str(key[1])}_{plot_name}.png", |
35 | | - ) |
36 | | - ) |
| 16 | + plot_func = getattr(iup.diagnostics, plot_name) |
| 17 | + axes = plot_func(fit) |
| 18 | + fig = axes.ravel()[0].figure |
| 19 | + fig.savefig(output_path) |
37 | 20 |
|
38 | 21 |
|
39 | 22 | def diagnostic_table( |
40 | | - models: Dict[Tuple[str, dt.date], iup.models.UptakeModel], |
41 | | - config: Dict[str, Any], |
42 | | - output_dir, |
| 23 | + table_name: str, fit: iup.models.UptakeModel, output_path: str | Path |
43 | 24 | ): |
44 | | - """select the fitted model using model name and training end date |
45 | | - and generate selected diagnostics: summary/posterior as parquet""" |
46 | | - |
47 | | - sel_model_dict = select_model_to_diagnose(models, config) |
48 | | - |
49 | | - diagnose_table_names = config["diagnostics"]["table"] |
50 | | - |
51 | | - for key, model in sel_model_dict.items(): |
52 | | - for table_name in diagnose_table_names: |
53 | | - table_func = getattr(iup.diagnostics, table_name) |
54 | | - output = table_func(model) |
55 | | - |
56 | | - output.write_parquet( |
57 | | - Path( |
58 | | - output_dir, |
59 | | - f"model={key[0]}_forecast_start={str(key[1])}_{table_name}.parquet", |
60 | | - ) |
61 | | - ) |
62 | | - |
63 | | - |
64 | | -def select_model_to_diagnose( |
65 | | - models: Dict[Tuple[str, dt.date], iup.models.UptakeModel], config |
66 | | -) -> dict: |
67 | | - """Select the model to diagnose based on the model name and the training end date""" |
68 | | - |
69 | | - forecast_dates = config["diagnostics"]["forecast_date"] |
70 | | - |
71 | | - if forecast_dates is None: |
72 | | - sel_keys = [ |
73 | | - (model, date) |
74 | | - for model, date in models.keys() |
75 | | - if model in config["diagnostics"]["model"] |
76 | | - ] |
77 | | - else: |
78 | | - assert isinstance(forecast_dates, list) |
79 | | - assert all(isinstance(x, dt.date) for x in forecast_dates) |
80 | | - sel_keys = [ |
81 | | - (model, date) |
82 | | - for model, date in models.keys() |
83 | | - if model in config["diagnostics"]["model"] and date in forecast_dates |
84 | | - ] |
85 | | - |
86 | | - return {key: models[key] for key in sel_keys} |
| 25 | + table_func = getattr(iup.diagnostics, table_name) |
| 26 | + output = table_func(fit) |
| 27 | + output.write_csv(output_path) |
87 | 28 |
|
88 | 29 |
|
89 | 30 | if __name__ == "__main__": |
90 | 31 | p = argparse.ArgumentParser() |
91 | | - p.add_argument("--config", help="config file") |
92 | | - p.add_argument("--fits", help="fits pickle") |
93 | | - p.add_argument( |
94 | | - "--output", help="output status file; other files put in the same directory" |
95 | | - ) |
| 32 | + p.add_argument("--config", help="config file", required=True) |
| 33 | + p.add_argument("--fits_dir", help="directory with fit pickles", required=True) |
| 34 | + p.add_argument("--output_dir", required=True) |
96 | 35 | args = p.parse_args() |
97 | 36 |
|
| 37 | + Path(args.output_dir).mkdir(parents=True, exist_ok=True) |
| 38 | + |
98 | 39 | with open(args.config, "r") as f: |
99 | 40 | config = yaml.safe_load(f) |
100 | 41 |
|
101 | | - with open(args.fits, "rb") as f: |
102 | | - models = pickle.load(f) |
| 42 | + for key in ["forecast_starts", "models", "tables", "plots"]: |
| 43 | + assert isinstance(config["diagnostics"][key], list), ( |
| 44 | + f"config['diagnostics']['{key}'] should be a list" |
| 45 | + ) |
103 | 46 |
|
104 | | - output_dir = Path(args.output).parent |
105 | | - output_dir.mkdir(parents=True, exist_ok=True) |
| 47 | + for forecast_start in config["diagnostics"]["forecast_starts"]: |
| 48 | + fc_date = dt.date.fromisoformat(forecast_start) |
106 | 49 |
|
107 | | - # write the other plots to the same folder |
108 | | - diagnostic_plot(models, config, output_dir) |
109 | | - diagnostic_table(models, config, output_dir) |
| 50 | + for model in config["diagnostics"]["models"]: |
| 51 | + with open(Path(args.fits_dir) / f"fit_{fc_date}.pkl", "rb") as f: |
| 52 | + fits = pickle.load(f) |
110 | 53 |
|
111 | | - # write the status file |
112 | | - with open(args.output, "w") as f: |
113 | | - f.write(dt.datetime.now().isoformat()) |
| 54 | + fit = fits[(model, fc_date)] |
| 55 | + |
| 56 | + for table in config["diagnostics"]["tables"]: |
| 57 | + diagnostic_table( |
| 58 | + table_name=table, |
| 59 | + fit=fit, |
| 60 | + output_path=Path( |
| 61 | + args.output_dir, |
| 62 | + f"model={model}_forecast_start={fc_date}_{table}.csv", |
| 63 | + ), |
| 64 | + ) |
| 65 | + |
| 66 | + for plot in config["diagnostics"]["plots"]: |
| 67 | + diagnostic_plot( |
| 68 | + plot_name=plot, |
| 69 | + fit=fit, |
| 70 | + output_path=Path( |
| 71 | + args.output_dir, |
| 72 | + f"model={model}_forecast_start={fc_date}_{plot}.png", |
| 73 | + ), |
| 74 | + ) |
0 commit comments