Skip to content

Commit 2db039f

Browse files
committed
Refactor sarix tests to create configs with helpers
1 parent efa0ff5 commit 2db039f

1 file changed

Lines changed: 48 additions & 73 deletions

File tree

tests/integration/test_sarix.py

Lines changed: 48 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,56 +11,48 @@
1111

1212

1313
def test_sarix_nhsn(tmp_path):
14-
model_config = SimpleNamespace(
15-
model_class = "sarix",
16-
model_name = "sarix_nhsn_p6_4rt_thetashared_sigmanone",
17-
18-
# data sources and adjustments for reporting issues
19-
sources = ["nhsn"],
20-
21-
# fit locations separately or jointly
22-
fit_locations_separately = False,
23-
24-
# SARI model parameters
25-
p = 6,
26-
P = 0,
27-
d = 0,
28-
D = 0,
29-
season_period = 1,
30-
31-
# power transform applied to surveillance signals
32-
power_transform = "4rt",
14+
date = datetime.date.fromisoformat("2024-01-06")
15+
fips_codes = ["US", "01", "02", "04", "05", "06", "08", "09", "10", "11",
16+
"12", "13", "15", "16", "17", "18", "19", "20", "21", "22",
17+
"23", "24", "25", "26", "27", "28", "29", "30", "31", "32",
18+
"33", "34", "35", "36", "37", "38", "39", "40", "41", "42",
19+
"44", "45", "46", "47", "48", "49", "50", "51", "53", "54",
20+
"55", "56", "72"]
21+
model_config = create_test_sarix_model_config(main_source=["nhsn"])
22+
run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path)
23+
24+
# patch the `_np_percentile()` helper function to return the same values to make the tests reproducible across OSs
25+
with patch("idmodels.sarix._np_percentile", return_value=_np_percentile_val()):
26+
model = SARIXModel(model_config)
27+
model.run(run_config)
3328

34-
# sharing of information about parameters
35-
theta_pooling="shared",
36-
sigma_pooling="none",
37-
38-
# covariates
39-
x = []
29+
actual_df = pd.read_csv(
30+
run_config.output_root / f"UMass-{model_config.model_name}" /
31+
f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv"
4032
)
41-
42-
run_config = SimpleNamespace(
43-
disease="flu",
44-
ref_date=datetime.date.fromisoformat("2024-01-06"),
45-
output_root=tmp_path / "model-output",
46-
artifact_store_root=tmp_path / "artifact-store",
47-
save_feat_importance=False,
48-
locations=["US", "01", "02", "04", "05", "06", "08", "09", "10", "11",
49-
"12", "13", "15", "16", "17", "18", "19", "20", "21", "22",
50-
"23", "24", "25", "26", "27", "28", "29", "30", "31", "32",
51-
"33", "34", "35", "36", "37", "38", "39", "40", "41", "42",
52-
"44", "45", "46", "47", "48", "49", "50", "51", "53", "54",
53-
"55", "56", "72"],
54-
max_horizon=3,
55-
q_levels = [0.025, 0.50, 0.975],
56-
q_labels = ["0.025", "0.5", "0.975"],
57-
num_warmup = 200,
58-
num_samples = 200,
59-
num_chains = 1
33+
expected_df = pd.read_csv(
34+
Path("tests") / "integration" / "data" /
35+
f"UMass-{model_config.model_name}" /
36+
f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv"
6037
)
38+
assert_frame_equal(actual_df, expected_df)
39+
6140

41+
def test_sarix_nssp(tmp_path):
42+
date = datetime.date.fromisoformat("2025-09-27")
43+
# Missouri (29) does not submit to NSSP
44+
fips_codes = ["US", "01", "02", "04", "05", "06", "08", "09", "10", "11",
45+
"12", "13", "15", "16", "17", "18", "19", "20", "21", "22",
46+
"23", "24", "25", "26", "27", "28", "30", "31", "32",
47+
"33", "34", "35", "36", "37", "38", "39", "40", "41", "42",
48+
"44", "45", "46", "47", "48", "49", "50", "51", "53", "54",
49+
"55", "56"]
50+
model_config = create_test_sarix_model_config(main_source=["nssp"])
51+
run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path)
52+
6253
# patch the `_np_percentile()` helper function to return the same values to make the tests reproducible across OSs
63-
with patch("idmodels.sarix._np_percentile", return_value=_np_percentile_val()):
54+
# nssp data only covers 51 locations
55+
with patch("idmodels.sarix._np_percentile", return_value=_np_percentile_val()[:, 0:51, :]):
6456
model = SARIXModel(model_config)
6557
model.run(run_config)
6658

@@ -75,14 +67,15 @@ def test_sarix_nhsn(tmp_path):
7567
)
7668
assert_frame_equal(actual_df, expected_df)
7769

70+
# hsas=["25", "150"]
7871

79-
def test_sarix_nssp(tmp_path):
72+
def create_test_sarix_model_config(main_source):
8073
model_config = SimpleNamespace(
8174
model_class = "sarix",
82-
model_name = "sarix_nssp_p6_4rt_thetashared_sigmanone",
75+
model_name = "sarix_" + main_source[0] + "_p6_4rt_thetashared_sigmanone",
8376

8477
# data sources and adjustments for reporting issues
85-
sources = ["nssp"],
78+
sources = main_source,
8679

8780
# fit locations separately or jointly
8881
fit_locations_separately = False,
@@ -104,44 +97,26 @@ def test_sarix_nssp(tmp_path):
10497
# covariates
10598
x = []
10699
)
100+
return model_config
107101

102+
def create_test_sarix_run_config(ref_date, states, hsas, tmp_path):
108103
run_config = SimpleNamespace(
109104
disease="flu",
110-
ref_date=datetime.date.fromisoformat("2025-09-27"),
105+
ref_date=ref_date,
111106
output_root=tmp_path / "model-output",
112107
artifact_store_root=tmp_path / "artifact-store",
113108
save_feat_importance=False,
114-
locations=["US", "01", "02", "04", "05", "06", "08", "09", "10", "11",
115-
"12", "13", "15", "16", "17", "18", "19", "20", "21", "22",
116-
"23", "24", "25", "26", "27", "28", "29", "30", "31", "32",
117-
"33", "34", "35", "36", "37", "38", "39", "40", "41", "42",
118-
"44", "45", "46", "47", "48", "49", "50", "51", "53", "54",
119-
"55", "56", "72"],
109+
states=states,
110+
hsas = hsas,
120111
max_horizon=3,
121112
q_levels = [0.025, 0.50, 0.975],
122113
q_labels = ["0.025", "0.5", "0.975"],
123114
num_warmup = 200,
124115
num_samples = 200,
125116
num_chains = 1
126117
)
127-
128-
# patch the `_np_percentile()` helper function to return the same values to make the tests reproducible across OSs
129-
# nssp data only covers 51 locations
130-
with patch("idmodels.sarix._np_percentile", return_value=_np_percentile_val()[:, 0:51, :]):
131-
model = SARIXModel(model_config)
132-
model.run(run_config)
133-
134-
actual_df = pd.read_csv(
135-
run_config.output_root / f"UMass-{model_config.model_name}" /
136-
f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv"
137-
)
138-
expected_df = pd.read_csv(
139-
Path("tests") / "integration" / "data" /
140-
f"UMass-{model_config.model_name}" /
141-
f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv"
142-
)
143-
assert_frame_equal(actual_df, expected_df)
144-
118+
return run_config
119+
145120
def _np_percentile_val():
146121
return numpy.array(
147122
[[[2.22541624e-01, 1.82324940e-01, 1.27709944e-01],

0 commit comments

Comments
 (0)