Skip to content

Commit 96157e3

Browse files
committed
add support for fourier pooling option
1 parent 8e515d4 commit 96157e3

2 files changed

Lines changed: 197 additions & 6 deletions

File tree

src/idmodels/sarix.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,14 @@ class SARIXFourierModel(SARIXModel):
116116
SARIX model with Fourier seasonality terms.
117117
118118
Adds annual seasonal patterns using Fourier harmonics to the base SARIX model.
119-
Requires fourier_K parameter in model_config to specify number of harmonic pairs.
119+
120+
Required model_config parameters:
121+
- fourier_K: Number of Fourier harmonic pairs (int)
122+
- fourier_pooling: How to share Fourier coefficients across locations ('none' or 'shared')
120123
"""
121124
def _get_sarix_module(self):
122-
"""Return the sarix_fourier module for Fourier-enhanced fitting."""
123-
from sarixfourier import sarix_fourier
124-
return sarix_fourier
125+
"""Return the sarix module (same module, but with Fourier parameters)."""
126+
return sarix
125127

126128
def _get_extra_sarix_params(self, df):
127129
"""Return Fourier-specific parameters for SARIX constructor."""
@@ -131,7 +133,8 @@ def _get_extra_sarix_params(self, df):
131133

132134
return {
133135
"day_of_year": day_of_year,
134-
"fourier_K": self.model_config.fourier_K
136+
"fourier_K": self.model_config.fourier_K,
137+
"fourier_pooling": self.model_config.fourier_pooling
135138
}
136139

137140

tests/integration/test_sarix.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from pandas.testing import assert_frame_equal
99

10-
from idmodels.sarix import SARIXModel
10+
from idmodels.sarix import SARIXModel, SARIXFourierModel
1111

1212

1313
def test_sarix(tmp_path):
@@ -145,6 +145,194 @@ def test_sarix_shared_sigma_pooling_multiple_batches(tmp_path):
145145
"All predictions should be non-negative"
146146

147147

148+
def test_sarix_fourier_none_pooling(tmp_path):
149+
"""Test SARIXFourierModel with fourier_pooling='none' (unpooled)."""
150+
model_config = SimpleNamespace(
151+
model_class="sarix_fourier",
152+
model_name="sarix_p2_fourier_K2_none",
153+
154+
# data sources
155+
sources=["nhsn"],
156+
157+
# fit locations separately or jointly
158+
fit_locations_separately=False,
159+
160+
# SARIX model parameters
161+
p=2,
162+
P=0,
163+
d=0,
164+
D=0,
165+
season_period=1,
166+
167+
# power transform
168+
power_transform="4rt",
169+
170+
# parameter pooling
171+
theta_pooling="shared",
172+
sigma_pooling="shared",
173+
174+
# Fourier parameters
175+
fourier_K=2,
176+
fourier_pooling="none", # Unpooled Fourier coefficients
177+
178+
# covariates
179+
x=[]
180+
)
181+
182+
# Use subset of locations for faster testing
183+
run_config = SimpleNamespace(
184+
disease="flu",
185+
ref_date=datetime.date.fromisoformat("2024-01-06"),
186+
output_root=tmp_path / "model-output",
187+
artifact_store_root=tmp_path / "artifact-store",
188+
save_feat_importance=False,
189+
locations=["US", "01", "02", "04", "05"],
190+
max_horizon=2, # Reduced for faster testing
191+
q_levels=[0.025, 0.50, 0.975],
192+
q_labels=["0.025", "0.5", "0.975"],
193+
num_warmup=50, # Reduced for faster testing
194+
num_samples=50,
195+
num_chains=1
196+
)
197+
198+
model = SARIXFourierModel(model_config)
199+
model.run(run_config)
200+
201+
# Verify output structure
202+
actual_df = pd.read_csv(
203+
run_config.output_root / "UMass-sarix_p2_fourier_K2_none" /
204+
"2024-01-06-UMass-sarix_p2_fourier_K2_none.csv"
205+
)
206+
207+
# Assertions
208+
assert len(actual_df) > 0, "Output dataframe should not be empty"
209+
assert set(actual_df["location"].unique()) == set(run_config.locations), \
210+
"Output should contain predictions for all input locations"
211+
assert all(actual_df["output_type"] == "quantile"), \
212+
"All outputs should be quantiles"
213+
assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels), \
214+
"Output should contain all specified quantile levels"
215+
assert actual_df["value"].notna().all(), \
216+
"All predictions should be non-null"
217+
assert (actual_df["value"] >= 0).all(), \
218+
"All predictions should be non-negative"
219+
220+
221+
def test_sarix_fourier_shared_pooling(tmp_path):
222+
"""Test SARIXFourierModel with fourier_pooling='shared' (pooled across locations)."""
223+
model_config = SimpleNamespace(
224+
model_class="sarix_fourier",
225+
model_name="sarix_p2_fourier_K2_shared",
226+
227+
# data sources
228+
sources=["nhsn"],
229+
230+
# fit locations separately or jointly
231+
fit_locations_separately=False,
232+
233+
# SARIX model parameters
234+
p=2,
235+
P=0,
236+
d=0,
237+
D=0,
238+
season_period=1,
239+
240+
# power transform
241+
power_transform="4rt",
242+
243+
# parameter pooling
244+
theta_pooling="shared",
245+
sigma_pooling="shared",
246+
247+
# Fourier parameters
248+
fourier_K=2,
249+
fourier_pooling="shared", # Shared Fourier coefficients
250+
251+
# covariates
252+
x=[]
253+
)
254+
255+
# Use subset of locations for faster testing
256+
run_config = SimpleNamespace(
257+
disease="flu",
258+
ref_date=datetime.date.fromisoformat("2024-01-06"),
259+
output_root=tmp_path / "model-output",
260+
artifact_store_root=tmp_path / "artifact-store",
261+
save_feat_importance=False,
262+
locations=["US", "01", "02", "04", "05"],
263+
max_horizon=2,
264+
q_levels=[0.025, 0.50, 0.975],
265+
q_labels=["0.025", "0.5", "0.975"],
266+
num_warmup=50,
267+
num_samples=50,
268+
num_chains=1
269+
)
270+
271+
model = SARIXFourierModel(model_config)
272+
model.run(run_config)
273+
274+
# Verify output structure
275+
actual_df = pd.read_csv(
276+
run_config.output_root / "UMass-sarix_p2_fourier_K2_shared" /
277+
"2024-01-06-UMass-sarix_p2_fourier_K2_shared.csv"
278+
)
279+
280+
# Assertions
281+
assert len(actual_df) > 0, "Output dataframe should not be empty"
282+
assert set(actual_df["location"].unique()) == set(run_config.locations), \
283+
"Output should contain predictions for all input locations"
284+
assert all(actual_df["output_type"] == "quantile"), \
285+
"All outputs should be quantiles"
286+
assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels), \
287+
"Output should contain all specified quantile levels"
288+
assert actual_df["value"].notna().all(), \
289+
"All predictions should be non-null"
290+
assert (actual_df["value"] >= 0).all(), \
291+
"All predictions should be non-negative"
292+
293+
294+
def test_sarix_fourier_missing_pooling_parameter():
295+
"""Test that SARIXFourierModel raises error when fourier_pooling is missing."""
296+
model_config = SimpleNamespace(
297+
model_class="sarix_fourier",
298+
model_name="sarix_p2_fourier_K2_nopooling",
299+
sources=["nhsn"],
300+
fit_locations_separately=False,
301+
p=2, P=0, d=0, D=0, season_period=1,
302+
power_transform="4rt",
303+
theta_pooling="shared",
304+
sigma_pooling="shared",
305+
fourier_K=2,
306+
# fourier_pooling is MISSING - should cause error
307+
x=[]
308+
)
309+
310+
run_config = SimpleNamespace(
311+
disease="flu",
312+
ref_date=datetime.date.fromisoformat("2024-01-06"),
313+
output_root=Path("/tmp") / "model-output",
314+
artifact_store_root=Path("/tmp") / "artifact-store",
315+
save_feat_importance=False,
316+
locations=["US"],
317+
max_horizon=1,
318+
q_levels=[0.5],
319+
q_labels=["0.5"],
320+
num_warmup=10,
321+
num_samples=10,
322+
num_chains=1
323+
)
324+
325+
model = SARIXFourierModel(model_config)
326+
327+
# Should raise AttributeError when trying to access missing fourier_pooling
328+
try:
329+
model.run(run_config)
330+
assert False, "Should have raised AttributeError for missing fourier_pooling"
331+
except AttributeError as e:
332+
assert "fourier_pooling" in str(e), \
333+
f"Error should mention fourier_pooling, got: {str(e)}"
334+
335+
148336
def _np_percentile_val():
149337
return numpy.array(
150338
[[[2.22541624e-01, 1.82324940e-01, 1.27709944e-01],

0 commit comments

Comments
 (0)