Skip to content

Commit 6dc1e0b

Browse files
authored
Merge pull request #13 from reichlab/ngr/update-sarix-fourier
update sarix
2 parents 35b7752 + db01be9 commit 6dc1e0b

8 files changed

Lines changed: 2367 additions & 25 deletions

File tree

CHANGELOG.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Changelog
2+
3+
All notable changes to this project will be documented in this file.
4+
5+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6+
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7+
8+
## [Unreleased]
9+
10+
## [0.1.0] - 2025-11-03
11+
12+
### Added
13+
- Support for Fourier pooling option in SARIX models
14+
- `SARIXFourierModel` as a proper subclass of `SARIXModel`
15+
- `uv.lock` for reproducible builds
16+
- Location filter option for models
17+
18+
### Changed
19+
- Pinned sarix dependency in `pyproject.toml` to commit 35eea237 of reichlab/sarix (previously unpinned elray1/sarix) for stability
20+
- Updated pinned dependency versions (iddata, sarix, jax, jaxlib, numpyro)
21+
- Refactored `SARIXFourierModel` implementation as a subclass of `SARIXModel`
22+
23+
### Fixed
24+
- SARIX `sigma_pooling='shared'` bug with multiple batches
25+
- Import sorting issues (ruff compliance)
26+
- Covariate ordering issue (#10)
27+
- Test determinism across different operating systems
28+
- Horizon handling for SARIX and GBQR models
29+
- Missing value handling in input NHSN data
30+
31+
## [0.0.1] - 2024
32+
33+
### Added
34+
- Initial package setup and structure
35+
- SARIX model support for infectious disease forecasting
36+
- GBQR (Gradient Boosted Quantile Regression) model
37+
- COVID-19 disease support
38+
- Integration tests for core models
39+
- Basic model implementations sourced from flusion-experiments
40+
- Pre-commit hooks with ruff linting
41+
- Python 3.11+ support
42+
43+
### Changed
44+
- Updated to latest iddata API
45+
46+
[Unreleased]: https://github.com/reichlab/idmodels/compare/v0.1.0...HEAD
47+
[0.1.0]: https://github.com/reichlab/idmodels/compare/v0.0.1...v0.1.0
48+
[0.0.1]: https://github.com/reichlab/idmodels/releases/tag/v0.0.1

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "idmodels"
33
description = "A Python module for modeling infectious disease."
44
license = {text = "MIT License"}
55
readme = "README.md"
6-
requires-python = ">=3.9"
6+
requires-python = ">=3.11"
77
classifiers = [
88
"Programming Language :: Python :: 3",
99
"License :: OSI Approved :: MIT License",
@@ -15,7 +15,7 @@ dependencies = [
1515
"lightgbm",
1616
"numpy",
1717
"pandas",
18-
"sarix @ git+https://github.com/elray1/sarix",
18+
"sarix @ git+https://github.com/reichlab/sarix@35eea2379a9790e0457b1aed41d13509e5d5056f",
1919
"scikit-learn",
2020
"tqdm",
2121
"timeseriesutils @ git+https://github.com/reichlab/timeseriesutils"

requirements/requirements-dev.txt

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ botocore==1.35.36
2020
# via aiobotocore
2121
cfgv==3.4.0
2222
# via pre-commit
23-
colorama==0.4.6
24-
# via
25-
# pytest
26-
# tqdm
2723
contourpy==1.3.0
2824
# via matplotlib
2925
coverage==7.6.4
@@ -42,19 +38,19 @@ frozenlist==1.5.0
4238
# aiosignal
4339
fsspec==2024.10.0
4440
# via s3fs
45-
iddata @ git+https://github.com/reichlab/iddata@3ad0ac0dc6d7f14488628a49bfb5228ca0643e1b
41+
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
4642
# via idmodels (pyproject.toml)
4743
identify==2.6.1
4844
# via pre-commit
4945
idna==3.10
5046
# via yarl
5147
iniconfig==2.0.0
5248
# via pytest
53-
jax==0.4.35
49+
jax==0.8.0
5450
# via
5551
# numpyro
5652
# sarix
57-
jaxlib==0.4.35
53+
jaxlib==0.8.0
5854
# via
5955
# jax
6056
# numpyro
@@ -100,7 +96,7 @@ numpy==2.1.3
10096
# scikit-learn
10197
# scipy
10298
# timeseriesutils
103-
numpyro==0.15.3
99+
numpyro==0.19.0
104100
# via sarix
105101
opt-einsum==3.4.0
106102
# via jax
@@ -112,7 +108,6 @@ pandas==2.2.3
112108
# via
113109
# idmodels (pyproject.toml)
114110
# iddata
115-
# sarix
116111
# timeseriesutils
117112
pillow==11.0.0
118113
# via matplotlib
@@ -147,7 +142,7 @@ ruff==0.7.2
147142
# via idmodels (pyproject.toml)
148143
s3fs==2024.10.0
149144
# via iddata
150-
sarix @ git+https://github.com/elray1/sarix@1c8995942d49afbb66637f0dc0662e1248606af4
145+
sarix @ git+https://github.com/reichlab/sarix@35eea2379a9790e0457b1aed41d13509e5d5056f
151146
# via idmodels (pyproject.toml)
152147
scikit-learn==1.5.2
153148
# via idmodels (pyproject.toml)

requirements/requirements.txt

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ attrs==24.2.0
1818
# pymmwr
1919
botocore==1.35.36
2020
# via aiobotocore
21-
colorama==0.4.6
22-
# via tqdm
2321
contourpy==1.3.0
2422
# via matplotlib
2523
cycler==0.12.1
@@ -32,15 +30,15 @@ frozenlist==1.5.0
3230
# aiosignal
3331
fsspec==2024.10.0
3432
# via s3fs
35-
iddata @ git+https://github.com/reichlab/iddata@3ad0ac0dc6d7f14488628a49bfb5228ca0643e1b
33+
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
3634
# via idmodels (pyproject.toml)
3735
idna==3.10
3836
# via yarl
39-
jax==0.4.35
37+
jax==0.8.0
4038
# via
4139
# numpyro
4240
# sarix
43-
jaxlib==0.4.35
41+
jaxlib==0.8.0
4442
# via
4543
# jax
4644
# numpyro
@@ -84,7 +82,7 @@ numpy==2.1.3
8482
# scikit-learn
8583
# scipy
8684
# timeseriesutils
87-
numpyro==0.15.3
85+
numpyro==0.19.0
8886
# via sarix
8987
opt-einsum==3.4.0
9088
# via jax
@@ -94,7 +92,6 @@ pandas==2.2.3
9492
# via
9593
# idmodels (pyproject.toml)
9694
# iddata
97-
# sarix
9895
# timeseriesutils
9996
pillow==11.0.0
10097
# via matplotlib
@@ -117,7 +114,7 @@ rich==13.9.4
117114
# via iddata
118115
s3fs==2024.10.0
119116
# via iddata
120-
sarix @ git+https://github.com/elray1/sarix@1c8995942d49afbb66637f0dc0662e1248606af4
117+
sarix @ git+https://github.com/reichlab/sarix@35eea2379a9790e0457b1aed41d13509e5d5056f
121118
# via idmodels (pyproject.toml)
122119
scikit-learn==1.5.2
123120
# via idmodels (pyproject.toml)

src/idmodels/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.1"
1+
__version__ = "0.1.0"

src/idmodels/sarix.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class SARIXModel():
1212
def __init__(self, model_config):
1313
self.model_config = model_config
1414

15+
def _get_extra_sarix_params(self, df):
16+
"""Return extra parameters to pass to SARIX constructor. Returns empty dict by default."""
17+
return {}
18+
1519
def run(self, run_config):
1620
fdl = DiseaseDataLoader()
1721
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
@@ -30,11 +34,14 @@ def run(self, run_config):
3034
on="season") \
3135
.assign(delta_xmas = lambda x: x["season_week"] - x["xmas_week"])
3236
df["xmas_spike"] = np.maximum(3 - np.abs(df["delta_xmas"]), 0)
33-
37+
3438
xy_colnames = self.model_config.x + ["inc_trans_cs"]
3539
df = df.query("wk_end_date >= '2022-10-01'").interpolate()
3640
batched_xy = df[xy_colnames].values.reshape(len(df["location"].unique()), -1, len(xy_colnames))
37-
41+
42+
# Get any extra parameters for the SARIX constructor
43+
extra_params = self._get_extra_sarix_params(df)
44+
3845
sarix_fit_all_locs_theta_pooled = sarix.SARIX(
3946
xy = batched_xy,
4047
p = self.model_config.p,
@@ -48,7 +55,8 @@ def run(self, run_config):
4855
forecast_horizon = run_config.max_horizon,
4956
num_warmup = run_config.num_warmup,
5057
num_samples = run_config.num_samples,
51-
num_chains = run_config.num_chains
58+
num_chains = run_config.num_chains,
59+
**extra_params
5260
)
5361

5462
pred_qs = _np_percentile(sarix_fit_all_locs_theta_pooled.predictions[..., :, :, 0],
@@ -93,9 +101,34 @@ def run(self, run_config):
93101
run_config=run_config,
94102
model_config=self.model_config
95103
)
104+
# Ensure output_type_id is string to avoid pandas inferring it as float when reading
105+
preds_df["output_type_id"] = preds_df["output_type_id"].astype(str)
96106
preds_df.to_csv(save_path, index=False)
97107

98108

109+
class SARIXFourierModel(SARIXModel):
110+
"""
111+
SARIX model with Fourier seasonality terms.
112+
113+
Adds annual seasonal patterns using Fourier harmonics to the base SARIX model.
114+
115+
Required model_config parameters:
116+
- fourier_K: Number of Fourier harmonic pairs (int)
117+
- fourier_pooling: How to share Fourier coefficients across locations ('none' or 'shared')
118+
"""
119+
def _get_extra_sarix_params(self, df):
120+
"""Return Fourier-specific parameters for SARIX constructor."""
121+
# Extract day-of-year from dates for Fourier features
122+
# Take the first location's dates (assumes every location shares the same sequence of dates, as the per-location reshape of the data also does)
123+
day_of_year = df.groupby("location")["wk_end_date"].apply(lambda x: x.dt.dayofyear.values).iloc[0]
124+
125+
return {
126+
"day_of_year": day_of_year,
127+
"fourier_K": self.model_config.fourier_K,
128+
"fourier_pooling": self.model_config.fourier_pooling
129+
}
130+
131+
99132
def _np_percentile(predictions, q_levels, axis):
100133
"""
101134
Simple helper function to ease patching from unit tests.

0 commit comments

Comments
 (0)