Skip to content

Commit 706ec05

Browse files
nickreich and claude committed
Add support for t-distributed innovations in SARIXModel
- Extend _get_extra_sarix_params() to pass innovation_dist and innovation_df_prior_scale
- Update SARIXFourierModel to properly merge base and Fourier parameters
- Add test_sarix_tdist_innovations() integration test
- Validates model runs successfully with t-distributed errors
- Verifies output structure and prediction quality

Depends on sarix library support for t-distributed innovations.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 6dc1e0b commit 706ec05

2 files changed

Lines changed: 89 additions & 3 deletions

File tree

src/idmodels/sarix.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,15 @@ def __init__(self, model_config):
1414

1515
def _get_extra_sarix_params(self, df):
1616
"""Return extra parameters to pass to SARIX constructor. Returns empty dict by default."""
17-
return {}
17+
extra_params = {}
18+
19+
# Add innovation distribution parameters if specified
20+
if hasattr(self.model_config, 'innovation_dist'):
21+
extra_params['innovation_dist'] = self.model_config.innovation_dist
22+
if hasattr(self.model_config, 'innovation_df_prior_scale'):
23+
extra_params['innovation_df_prior_scale'] = self.model_config.innovation_df_prior_scale
24+
25+
return extra_params
1826

1927
def run(self, run_config):
2028
fdl = DiseaseDataLoader()
@@ -118,15 +126,21 @@ class SARIXFourierModel(SARIXModel):
118126
"""
119127
def _get_extra_sarix_params(self, df):
120128
"""Return Fourier-specific parameters for SARIX constructor."""
129+
# Get base parameters (includes innovation_dist if specified)
130+
extra_params = super()._get_extra_sarix_params(df)
131+
121132
# Extract day-of-year from dates for Fourier features
122133
# Take the first location's dates (same for all locations after reshaping)
123134
day_of_year = df.groupby("location")["wk_end_date"].apply(lambda x: x.dt.dayofyear.values).iloc[0]
124135

125-
return {
136+
# Add Fourier-specific parameters
137+
extra_params.update({
126138
"day_of_year": day_of_year,
127139
"fourier_K": self.model_config.fourier_K,
128140
"fourier_pooling": self.model_config.fourier_pooling
129-
}
141+
})
142+
143+
return extra_params
130144

131145

132146
def _np_percentile(predictions, q_levels, axis):

tests/integration/test_sarix.py

Lines changed: 72 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,78 @@ def test_sarix_fourier_missing_pooling_parameter():
333333
f"Error should mention fourier_pooling, got: {str(e)}"
334334

335335

336+
def test_sarix_tdist_innovations(tmp_path):
337+
"""Test SARIX model with t-distributed innovations."""
338+
model_config = SimpleNamespace(
339+
model_class="sarix",
340+
model_name="sarix_p6_4rt_thetashared_sigmanone_tdist",
341+
342+
# data sources and adjustments for reporting issues
343+
sources=["nhsn"],
344+
345+
# fit locations separately or jointly
346+
fit_locations_separately=False,
347+
348+
# SARI model parameters
349+
p=6,
350+
P=0,
351+
d=0,
352+
D=0,
353+
season_period=1,
354+
355+
# power transform applied to surveillance signals
356+
power_transform="4rt",
357+
358+
# sharing of information about parameters
359+
theta_pooling="shared",
360+
sigma_pooling="none",
361+
362+
# innovation distribution parameters
363+
innovation_dist="t",
364+
innovation_df_prior_scale=10.0,
365+
366+
# covariates
367+
x=[]
368+
)
369+
370+
run_config = SimpleNamespace(
371+
disease="flu",
372+
ref_date=datetime.date.fromisoformat("2024-01-06"),
373+
output_root=tmp_path / "model-output",
374+
artifact_store_root=tmp_path / "artifact-store",
375+
save_feat_importance=False,
376+
locations=["US", "01", "02", "04", "05"], # Reduced for faster testing
377+
max_horizon=2, # Reduced for faster testing
378+
q_levels=[0.025, 0.50, 0.975],
379+
q_labels=["0.025", "0.5", "0.975"],
380+
num_warmup=100, # Reduced for faster testing
381+
num_samples=100,
382+
num_chains=1
383+
)
384+
385+
model = SARIXModel(model_config)
386+
model.run(run_config)
387+
388+
# Verify output structure
389+
actual_df = pd.read_csv(
390+
run_config.output_root / "UMass-sarix_p6_4rt_thetashared_sigmanone_tdist" /
391+
"2024-01-06-UMass-sarix_p6_4rt_thetashared_sigmanone_tdist.csv"
392+
)
393+
394+
# Assertions
395+
assert len(actual_df) > 0, "Output dataframe should not be empty"
396+
assert set(actual_df["location"].unique()) == set(run_config.locations), \
397+
"Output should contain predictions for all input locations"
398+
assert all(actual_df["output_type"] == "quantile"), \
399+
"All outputs should be quantiles"
400+
assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels), \
401+
"Output should contain all specified quantile levels"
402+
assert actual_df["value"].notna().all(), \
403+
"All predictions should be non-null"
404+
assert (actual_df["value"] >= 0).all(), \
405+
"All predictions should be non-negative"
406+
407+
336408
def _np_percentile_val():
337409
return numpy.array(
338410
[[[2.22541624e-01, 1.82324940e-01, 1.27709944e-01],

0 commit comments

Comments
 (0)