Skip to content

Commit f0160d4

Browse files
authored
Merge pull request #14 from reichlab/ls/sarix-gbqr-takes-nssp-data/9
SARIX and GBQR take NSSP Data
2 parents 6dc1e0b + 397827b commit f0160d4

14 files changed

Lines changed: 1314 additions & 197 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
"lightgbm",
1616
"numpy",
1717
"pandas",
18-
"sarix @ git+https://github.com/reichlab/sarix@35eea2379a9790e0457b1aed41d13509e5d5056f",
18+
"sarix @ git+https://github.com/reichlab/sarix",
1919
"scikit-learn",
2020
"tqdm",
2121
"timeseriesutils @ git+https://github.com/reichlab/timeseriesutils"

requirements/requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ frozenlist==1.5.0
3838
# aiosignal
3939
fsspec==2024.10.0
4040
# via s3fs
41-
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
41+
iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8
4242
# via idmodels (pyproject.toml)
4343
identify==2.6.1
4444
# via pre-commit

requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ frozenlist==1.5.0
3030
# aiosignal
3131
fsspec==2024.10.0
3232
# via s3fs
33-
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
33+
iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8
3434
# via idmodels (pyproject.toml)
3535
idna==3.10
3636
# via yarl

src/idmodels/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = "0.1.0"
1+
__version__ = "1.0.0"

src/idmodels/gbqr.py

Lines changed: 45 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -31,15 +31,38 @@ def run(self, run_config):
3131
ilinet_kwargs = {"scale_to_positive": False}
3232
flusurvnet_kwargs = {"burden_adj": False}
3333

34+
valid_sources = ["flusurvnet", "nhsn", "ilinet", "nssp"]
35+
if not np.isin(np.array(self.model_config.sources), valid_sources).all():
36+
raise ValueError("For GBQR, the only supported data sources are 'nhsn', 'flusurvnet', 'ilinet', or 'nssp'.")
37+
38+
# Check if both nhsn and nssp data are included as sources
39+
if all(src in self.model_config.sources for src in ["nhsn", "nssp"]):
40+
raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")
41+
3442
fdl = DiseaseDataLoader()
35-
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
36-
ilinet_kwargs=ilinet_kwargs,
37-
flusurvnet_kwargs=flusurvnet_kwargs,
38-
sources=self.model_config.sources,
39-
power_transform=self.model_config.power_transform)
40-
if run_config.locations is not None:
41-
df = df.loc[df["location"].isin(run_config.locations)]
43+
if "nhsn" in self.model_config.sources:
44+
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
45+
ilinet_kwargs=ilinet_kwargs,
46+
flusurvnet_kwargs=flusurvnet_kwargs,
47+
sources=self.model_config.sources,
48+
power_transform=self.model_config.power_transform)
49+
elif "nssp" in self.model_config.sources:
50+
df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
51+
ilinet_kwargs=ilinet_kwargs,
52+
flusurvnet_kwargs=flusurvnet_kwargs,
53+
sources=self.model_config.sources,
54+
power_transform=self.model_config.power_transform)
55+
56+
if (run_config.states == []) & (run_config.hsas == []):
57+
raise ValueError("User must request a non-empty set of locations to forecast for.")
58+
59+
if (run_config.states != []) & (run_config.hsas != []):
60+
raise NotImplementedError("Functionality for simultaneously forecasting state- and hsa-level locations is not yet implemented.")
4261

62+
df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
63+
df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
64+
df = pd.concat([df_states, df_hsas], join = "inner", axis = 0)
65+
4366
# augment data with features and target values
4467
if run_config.disease == "flu":
4568
init_feats = ["inc_trans_cs", "season_week", "log_pop"]
@@ -133,7 +156,7 @@ def _train_gbq_and_predict(self, run_config,
133156
"inc_trans_cs", "horizon",
134157
"inc_trans_center_factor", "inc_trans_scale_factor"]
135158
preds_df = df_test_w_preds[cols_to_keep + run_config.q_labels]
136-
preds_df = preds_df.loc[(preds_df["source"] == "nhsn")]
159+
preds_df = preds_df.loc[preds_df["source"].isin(["nhsn", "nssp"])]
137160
preds_df = pd.melt(preds_df,
138161
id_vars=cols_to_keep,
139162
var_name="quantile",
@@ -149,11 +172,20 @@ def _train_gbq_and_predict(self, run_config,
149172
else:
150173
raise ValueError('unsupported power_transform: must be "4rt" or None')
151174

152-
preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4) * preds_df["pop"] / 100000
153-
preds_df["value"] = np.maximum(preds_df["value"], 0.0)
175+
preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4)
154176

155177
# get predictions into the format needed for FluSight hub submission
156-
preds_df = self._format_as_flusight_output(preds_df, run_config.ref_date, run_config.disease)
178+
if "nhsn" in preds_df["source"].unique():
179+
# turn nhsn rates back into counts
180+
preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000
181+
target_name = "wk inc " + run_config.disease + " hosp"
182+
elif "nssp" in preds_df["source"].unique():
183+
preds_df["value"] = preds_df["value"] / 100 # percentage to proportion
184+
preds_df["value"] = np.minimum(preds_df["value"], 1.0)
185+
target_name = "wk inc " + run_config.disease + " prop ed visits"
186+
187+
preds_df["value"] = np.maximum(preds_df["value"], 0.0)
188+
preds_df = self._format_as_flusight_output(preds_df, run_config.ref_date, target_name)
157189

158190
# sort quantiles to avoid quantile crossing
159191
preds_df = self._quantile_noncrossing(
@@ -248,15 +280,15 @@ def _get_test_quantile_predictions(self, run_config,
248280
return test_pred_qs_df
249281

250282

251-
def _format_as_flusight_output(self, preds_df, ref_date, disease):
283+
def _format_as_flusight_output(self, preds_df, ref_date, target_name):
252284
# keep just required columns and rename to match hub format
253285
preds_df = preds_df[["location", "wk_end_date", "horizon", "quantile", "value"]] \
254286
.rename(columns={"quantile": "output_type_id"})
255287

256288
preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
257289
preds_df["reference_date"] = ref_date
258290
preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - ref_date).dt.days / 7).astype(int)
259-
preds_df["target"] = "wk inc " + disease + " hosp"
291+
preds_df["target"] = target_name
260292

261293
preds_df["output_type"] = "quantile"
262294
preds_df.drop(columns="wk_end_date", inplace=True)

src/idmodels/sarix.py

Lines changed: 46 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -17,12 +17,35 @@ def _get_extra_sarix_params(self, df):
1717
return {}
1818

1919
def run(self, run_config):
20+
valid_sources = np.array(["nhsn", "nssp"])
21+
if not np.isin(np.array(self.model_config.sources), valid_sources).all():
22+
raise ValueError("For SARIX, the only supported data sources are 'nhsn' or 'nssp'.")
23+
24+
# Check if both nhsn and nssp data are included as sources
25+
if all(src in self.model_config.sources for src in ["nhsn", "nssp"]):
26+
raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")
27+
2028
fdl = DiseaseDataLoader()
21-
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
22-
sources=self.model_config.sources,
23-
power_transform=self.model_config.power_transform)
24-
if run_config.locations is not None:
25-
df = df.loc[df["location"].isin(run_config.locations)]
29+
if "nhsn" in self.model_config.sources:
30+
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
31+
sources=self.model_config.sources,
32+
power_transform=self.model_config.power_transform)
33+
target_name = "wk inc " + run_config.disease + " hosp"
34+
elif "nssp" in self.model_config.sources:
35+
df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
36+
sources=self.model_config.sources,
37+
power_transform=self.model_config.power_transform)
38+
target_name = "wk inc " + run_config.disease + " prop ed visits"
39+
40+
if (run_config.states == []) & (run_config.hsas == []):
41+
raise ValueError("User must request a non-empty set of locations to forecast for.")
42+
43+
if (run_config.states != []) & (run_config.hsas != []):
44+
raise NotImplementedError("Functionality for simultaneously forecasting state- and hsa-level locations is not yet implemented.")
45+
46+
df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
47+
df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
48+
df = pd.concat([df_states, df_hsas], join = "inner", axis = 0)
2649

2750
# season week relative to christmas
2851
df = df.merge(
@@ -34,10 +57,12 @@ def run(self, run_config):
3457
on="season") \
3558
.assign(delta_xmas = lambda x: x["season_week"] - x["xmas_week"])
3659
df["xmas_spike"] = np.maximum(3 - np.abs(df["delta_xmas"]), 0)
37-
60+
61+
# missing values are interpolated when possible
3862
xy_colnames = self.model_config.x + ["inc_trans_cs"]
3963
df = df.query("wk_end_date >= '2022-10-01'").interpolate()
40-
batched_xy = df[xy_colnames].values.reshape(len(df["location"].unique()), -1, len(xy_colnames))
64+
unique_locations = len(df_states["location"].unique()) + len(df_hsas["location"].unique())
65+
batched_xy = df[xy_colnames].values.reshape(unique_locations, -1, len(xy_colnames))
4166

4267
# Get any extra parameters for the SARIX constructor
4368
extra_params = self._get_extra_sarix_params(df)
@@ -62,18 +87,18 @@ def run(self, run_config):
6287
pred_qs = _np_percentile(sarix_fit_all_locs_theta_pooled.predictions[..., :, :, 0],
6388
np.array(run_config.q_levels) * 100, axis=0)
6489

65-
df_nhsn_last_obs = df.groupby(["location"]).tail(1)
90+
df_data_last_obs = df.groupby(["location", "agg_level"]).tail(1)
6691

6792
preds_df = pd.concat([
6893
pd.DataFrame(pred_qs[i, :, :]) \
69-
.set_axis(df_nhsn_last_obs["location"], axis="index") \
94+
.set_axis(df_data_last_obs["location"], axis="index") \
7095
.set_axis(np.arange(1, run_config.max_horizon+1), axis="columns") \
7196
.assign(output_type_id = q_label) \
7297
for i, q_label in enumerate(run_config.q_labels)
7398
]) \
7499
.reset_index() \
75100
.melt(["location", "output_type_id"], var_name="horizon") \
76-
.merge(df_nhsn_last_obs, on="location", how="left")
101+
.merge(df_data_last_obs, on="location", how="left")
77102

78103
# build data frame with predictions on the original scale
79104
preds_df["value"] = (preds_df["value"] + preds_df["inc_trans_center_factor"]) * preds_df["inc_trans_scale_factor"]
@@ -82,19 +107,27 @@ def run(self, run_config):
82107
else:
83108
preds_df["value"] = np.maximum(preds_df["value"], 0.0) ** 2
84109

85-
preds_df["value"] = (preds_df["value"] - 0.01 - 0.75**4) * preds_df["pop"] / 100000
110+
preds_df["value"] = (preds_df["value"] - 0.01 - 0.75**4)
86111
preds_df["value"] = np.maximum(preds_df["value"], 0.0)
87112

113+
if "nhsn" in preds_df["source"].unique():
114+
# turn nhsn rates back into counts
115+
preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000
116+
117+
if target_name == "wk inc " + run_config.disease + " prop ed visits":
118+
preds_df["value"] = preds_df["value"] / 100 # percentage to proportion
119+
preds_df["value"] = np.minimum(preds_df["value"], 1.0)
120+
88121
# keep just required columns and rename to match hub format
89122
preds_df = preds_df[["location", "wk_end_date", "horizon", "output_type_id", "value"]]
90123

91124
preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
92125
preds_df["reference_date"] = run_config.ref_date
93126
preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - run_config.ref_date).dt.days / 7).astype(int)
94127
preds_df["output_type"] = "quantile"
95-
preds_df["target"] = "wk inc " + run_config.disease + " hosp"
128+
preds_df["target"] = target_name
96129
preds_df.drop(columns="wk_end_date", inplace=True)
97-
130+
98131
# save
99132
save_path = build_save_path(
100133
root=run_config.output_root,

tests/integration/data/UMass-gbqr_no_reporting_adj/2024-01-06-UMass-gbqr_no_reporting_adj.csv renamed to tests/integration/data/UMass-gbqr_nhsn_no_reporting_adj/2024-01-06-UMass-gbqr_nhsn_no_reporting_adj.csv

File renamed without changes.
Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,28 @@
1+
location,reference_date,horizon,target_end_date,target,output_type,output_type_id,value
2+
1,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.025,0.0
3+
25,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.025,0.0
4+
99,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.025,0.0
5+
1,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.025,0.0
6+
25,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.025,0.0
7+
99,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.025,0.0
8+
1,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.025,0.0
9+
25,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.025,0.0
10+
99,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.025,0.0
11+
1,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.5,0.0
12+
25,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.5,0.0
13+
99,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.5,0.0
14+
1,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.5,0.0
15+
25,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.5,0.0
16+
99,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.5,0.0
17+
1,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.5,0.0
18+
25,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.5,0.0
19+
99,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.5,0.0
20+
1,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.975,0.0
21+
25,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.975,0.0
22+
99,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.975,0.0
23+
1,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.975,0.0
24+
25,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.975,0.0
25+
99,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.975,0.0
26+
1,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.975,0.0
27+
25,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.975,0.0
28+
99,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.975,0.0

0 commit comments

Comments (0)