@@ -17,12 +17,35 @@ def _get_extra_sarix_params(self, df):
         return {}
 
     def run(self, run_config):
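+        # validate that all requested data sources are supported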
+        valid_sources = np.array(["nhsn", "nssp"])
+        if not np.isin(np.array(self.model_config.sources), valid_sources).all():
+            raise ValueError("For SARIX, the only supported data sources are 'nhsn' and 'nssp'.")
+
+        # nhsn and nssp may not both be requested as sources
+        if all(src in self.model_config.sources for src in ["nhsn", "nssp"]):
+            raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")
+
         fdl = DiseaseDataLoader()
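+        # load data from the requested source and set the corresponding hub target name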
-        df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
-                           sources=self.model_config.sources,
-                           power_transform=self.model_config.power_transform)
-        if run_config.locations is not None:
-            df = df.loc[df["location"].isin(run_config.locations)]
+        if "nhsn" in self.model_config.sources:
+            df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
+                               sources=self.model_config.sources,
+                               power_transform=self.model_config.power_transform)
+            target_name = "wk inc " + run_config.disease + " hosp"
+        elif "nssp" in self.model_config.sources:
+            df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
+                               sources=self.model_config.sources,
+                               power_transform=self.model_config.power_transform)
+            target_name = "wk inc " + run_config.disease + " prop ed visits"
+
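+        # exactly one of states or hsas must be requested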
+        if (run_config.states == []) and (run_config.hsas == []):
+            raise ValueError("At least one state or hsa location must be requested for forecasting.")
+
+        if (run_config.states != []) and (run_config.hsas != []):
+            raise NotImplementedError("Simultaneously forecasting state- and hsa-level locations is not yet implemented.")
+
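+        # subset to the requested locations at the matching aggregation level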
+        df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
+        df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
+        df = pd.concat([df_states, df_hsas], join="inner", axis=0)
 
         # season week relative to christmas
         df = df.merge(
@@ -34,10 +57,12 @@ def run(self, run_config):
             on="season") \
             .assign(delta_xmas=lambda x: x["season_week"] - x["xmas_week"])
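+        # xmas_spike is a triangular bump that is nonzero within two weeks of christmas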
         df["xmas_spike"] = np.maximum(3 - np.abs(df["delta_xmas"]), 0)
-
+
+        # missing values are interpolated when possible
         xy_colnames = self.model_config.x + ["inc_trans_cs"]
         df = df.query("wk_end_date >= '2022-10-01'").interpolate()
-        batched_xy = df[xy_colnames].values.reshape(len(df["location"].unique()), -1, len(xy_colnames))
+        unique_locations = len(df_states["location"].unique()) + len(df_hsas["location"].unique())
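+        # reshape to (num_locations, num_timepoints, num_features) for batched model fitting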
+        batched_xy = df[xy_colnames].values.reshape(unique_locations, -1, len(xy_colnames))
 
         # Get any extra parameters for the SARIX constructor
         extra_params = self._get_extra_sarix_params(df)
@@ -62,18 +87,18 @@ def run(self, run_config):
         pred_qs = _np_percentile(sarix_fit_all_locs_theta_pooled.predictions[..., :, :, 0],
                                  np.array(run_config.q_levels) * 100, axis=0)
 
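+        # grab the most recent observed row for each location and aggregation level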
-        df_nhsn_last_obs = df.groupby(["location"]).tail(1)
+        df_data_last_obs = df.groupby(["location", "agg_level"]).tail(1)
 
         preds_df = pd.concat([
             pd.DataFrame(pred_qs[i, :, :]) \
-                .set_axis(df_nhsn_last_obs["location"], axis="index") \
+                .set_axis(df_data_last_obs["location"], axis="index") \
                 .set_axis(np.arange(1, run_config.max_horizon + 1), axis="columns") \
                 .assign(output_type_id=q_label) \
             for i, q_label in enumerate(run_config.q_labels)
         ]) \
             .reset_index() \
             .melt(["location", "output_type_id"], var_name="horizon") \
-            .merge(df_nhsn_last_obs, on="location", how="left")
+            .merge(df_data_last_obs, on="location", how="left")
 
         # build data frame with predictions on the original scale
         preds_df["value"] = (preds_df["value"] + preds_df["inc_trans_center_factor"]) * preds_df["inc_trans_scale_factor"]
@@ -82,19 +107,27 @@ def run(self, run_config):
         else:
             preds_df["value"] = np.maximum(preds_df["value"], 0.0) ** 2
 
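+        # subtract constant offsets, presumably undoing offsets added before the power transform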
-        preds_df["value"] = (preds_df["value"] - 0.01 - 0.75 ** 4) * preds_df["pop"] / 100000
+        preds_df["value"] = preds_df["value"] - 0.01 - 0.75 ** 4
         preds_df["value"] = np.maximum(preds_df["value"], 0.0)
 
+        if "nhsn" in preds_df["source"].unique():
+            # turn nhsn rates back into counts
+            preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000
+
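+        # nssp values are percentages of ed visits; convert to proportions and cap at 1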
+        if target_name == "wk inc " + run_config.disease + " prop ed visits":
+            preds_df["value"] = preds_df["value"] / 100  # percentage to proportion
+            preds_df["value"] = np.minimum(preds_df["value"], 1.0)
+
         # keep just required columns and rename to match hub format
         preds_df = preds_df[["location", "wk_end_date", "horizon", "output_type_id", "value"]]
 
         preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7 * preds_df["horizon"], unit="days")
         preds_df["reference_date"] = run_config.ref_date
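+        # recompute horizon in weeks relative to the reference date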
         preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - run_config.ref_date).dt.days / 7).astype(int)
         preds_df["output_type"] = "quantile"
-        preds_df["target"] = "wk inc " + run_config.disease + " hosp"
+        preds_df["target"] = target_name
         preds_df.drop(columns="wk_end_date", inplace=True)
-
+
         # save
         save_path = build_save_path(
             root=run_config.output_root,