1111
1212
1313def test_sarix_nhsn (tmp_path ):
14- model_config = SimpleNamespace (
15- model_class = "sarix" ,
16- model_name = "sarix_nhsn_p6_4rt_thetashared_sigmanone" ,
17-
18- # data sources and adjustments for reporting issues
19- sources = ["nhsn" ],
20-
21- # fit locations separately or jointly
22- fit_locations_separately = False ,
23-
24- # SARI model parameters
25- p = 6 ,
26- P = 0 ,
27- d = 0 ,
28- D = 0 ,
29- season_period = 1 ,
30-
31- # power transform applied to surveillance signals
32- power_transform = "4rt" ,
14+ date = datetime .date .fromisoformat ("2024-01-06" )
15+ fips_codes = ["US" , "01" , "02" , "04" , "05" , "06" , "08" , "09" , "10" , "11" ,
16+ "12" , "13" , "15" , "16" , "17" , "18" , "19" , "20" , "21" , "22" ,
17+ "23" , "24" , "25" , "26" , "27" , "28" , "29" , "30" , "31" , "32" ,
18+ "33" , "34" , "35" , "36" , "37" , "38" , "39" , "40" , "41" , "42" ,
19+ "44" , "45" , "46" , "47" , "48" , "49" , "50" , "51" , "53" , "54" ,
20+ "55" , "56" , "72" ]
21+ model_config = create_test_sarix_model_config (main_source = ["nhsn" ])
22+ run_config = create_test_sarix_run_config (ref_date = date , states = fips_codes , hsas = [], tmp_path = tmp_path )
23+
24+ # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs
25+ with patch ("idmodels.sarix._np_percentile" , return_value = _np_percentile_val ()):
26+ model = SARIXModel (model_config )
27+ model .run (run_config )
3328
34- # sharing of information about parameters
35- theta_pooling = "shared" ,
36- sigma_pooling = "none" ,
37-
38- # covariates
39- x = []
29+ actual_df = pd .read_csv (
30+ run_config .output_root / f"UMass-{ model_config .model_name } " /
31+ f"{ str (run_config .ref_date )} -UMass-{ model_config .model_name } .csv"
4032 )
41-
42- run_config = SimpleNamespace (
43- disease = "flu" ,
44- ref_date = datetime .date .fromisoformat ("2024-01-06" ),
45- output_root = tmp_path / "model-output" ,
46- artifact_store_root = tmp_path / "artifact-store" ,
47- save_feat_importance = False ,
48- locations = ["US" , "01" , "02" , "04" , "05" , "06" , "08" , "09" , "10" , "11" ,
49- "12" , "13" , "15" , "16" , "17" , "18" , "19" , "20" , "21" , "22" ,
50- "23" , "24" , "25" , "26" , "27" , "28" , "29" , "30" , "31" , "32" ,
51- "33" , "34" , "35" , "36" , "37" , "38" , "39" , "40" , "41" , "42" ,
52- "44" , "45" , "46" , "47" , "48" , "49" , "50" , "51" , "53" , "54" ,
53- "55" , "56" , "72" ],
54- max_horizon = 3 ,
55- q_levels = [0.025 , 0.50 , 0.975 ],
56- q_labels = ["0.025" , "0.5" , "0.975" ],
57- num_warmup = 200 ,
58- num_samples = 200 ,
59- num_chains = 1
33+ expected_df = pd .read_csv (
34+ Path ("tests" ) / "integration" / "data" /
35+ f"UMass-{ model_config .model_name } " /
36+ f"{ str (run_config .ref_date )} -UMass-{ model_config .model_name } .csv"
6037 )
38+ assert_frame_equal (actual_df , expected_df )
39+
6140
41+ def test_sarix_nssp (tmp_path ):
42+ date = datetime .date .fromisoformat ("2025-09-27" )
43+ # Missouri (29) does not submit to NSSP
44+ fips_codes = ["US" , "01" , "02" , "04" , "05" , "06" , "08" , "09" , "10" , "11" ,
45+ "12" , "13" , "15" , "16" , "17" , "18" , "19" , "20" , "21" , "22" ,
46+ "23" , "24" , "25" , "26" , "27" , "28" , "30" , "31" , "32" ,
47+ "33" , "34" , "35" , "36" , "37" , "38" , "39" , "40" , "41" , "42" ,
48+ "44" , "45" , "46" , "47" , "48" , "49" , "50" , "51" , "53" , "54" ,
49+ "55" , "56" ]
50+ model_config = create_test_sarix_model_config (main_source = ["nssp" ])
51+ run_config = create_test_sarix_run_config (ref_date = date , states = fips_codes , hsas = [], tmp_path = tmp_path )
52+
6253 # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs
63- with patch ("idmodels.sarix._np_percentile" , return_value = _np_percentile_val ()):
54+ # nssp data only covers 51 locations
55+ with patch ("idmodels.sarix._np_percentile" , return_value = _np_percentile_val ()[:, 0 :51 , :]):
6456 model = SARIXModel (model_config )
6557 model .run (run_config )
6658
@@ -75,14 +67,15 @@ def test_sarix_nhsn(tmp_path):
7567 )
7668 assert_frame_equal (actual_df , expected_df )
7769
70+ # hsas=["25", "150"]
7871
79- def test_sarix_nssp ( tmp_path ):
72+ def create_test_sarix_model_config ( main_source ):
8073 model_config = SimpleNamespace (
8174 model_class = "sarix" ,
82- model_name = "sarix_nssp_p6_4rt_thetashared_sigmanone " ,
75+ model_name = "sarix_" + main_source [ 0 ] + "_p6_4rt_thetashared_sigmanone " ,
8376
8477 # data sources and adjustments for reporting issues
85- sources = [ "nssp" ] ,
78+ sources = main_source ,
8679
8780 # fit locations separately or jointly
8881 fit_locations_separately = False ,
@@ -104,44 +97,26 @@ def test_sarix_nssp(tmp_path):
10497 # covariates
10598 x = []
10699 )
100+ return model_config
107101
102+ def create_test_sarix_run_config (ref_date , states , hsas , tmp_path ):
108103 run_config = SimpleNamespace (
109104 disease = "flu" ,
110- ref_date = datetime . date . fromisoformat ( "2025-09-27" ) ,
105+ ref_date = ref_date ,
111106 output_root = tmp_path / "model-output" ,
112107 artifact_store_root = tmp_path / "artifact-store" ,
113108 save_feat_importance = False ,
114- locations = ["US" , "01" , "02" , "04" , "05" , "06" , "08" , "09" , "10" , "11" ,
115- "12" , "13" , "15" , "16" , "17" , "18" , "19" , "20" , "21" , "22" ,
116- "23" , "24" , "25" , "26" , "27" , "28" , "29" , "30" , "31" , "32" ,
117- "33" , "34" , "35" , "36" , "37" , "38" , "39" , "40" , "41" , "42" ,
118- "44" , "45" , "46" , "47" , "48" , "49" , "50" , "51" , "53" , "54" ,
119- "55" , "56" , "72" ],
109+ states = states ,
110+ hsas = hsas ,
120111 max_horizon = 3 ,
121112 q_levels = [0.025 , 0.50 , 0.975 ],
122113 q_labels = ["0.025" , "0.5" , "0.975" ],
123114 num_warmup = 200 ,
124115 num_samples = 200 ,
125116 num_chains = 1
126117 )
127-
128- # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs
129- # nssp data only covers 51 locations
130- with patch ("idmodels.sarix._np_percentile" , return_value = _np_percentile_val ()[:, 0 :51 , :]):
131- model = SARIXModel (model_config )
132- model .run (run_config )
133-
134- actual_df = pd .read_csv (
135- run_config .output_root / f"UMass-{ model_config .model_name } " /
136- f"{ str (run_config .ref_date )} -UMass-{ model_config .model_name } .csv"
137- )
138- expected_df = pd .read_csv (
139- Path ("tests" ) / "integration" / "data" /
140- f"UMass-{ model_config .model_name } " /
141- f"{ str (run_config .ref_date )} -UMass-{ model_config .model_name } .csv"
142- )
143- assert_frame_equal (actual_df , expected_df )
144-
118+ return run_config
119+
145120def _np_percentile_val ():
146121 return numpy .array (
147122 [[[2.22541624e-01 , 1.82324940e-01 , 1.27709944e-01 ],
0 commit comments