Skip to content

Commit 0d0698a

Browse files
replace SimpleNamespace configs with typed dataclasses and enums. update pyproject.toml to pin pandas to 2.* (pandas 3.* introduced breaking changes). update uv.lock to use latest iddata commit
1 parent 6a26685 commit 0d0698a

11 files changed

Lines changed: 288 additions & 183 deletions

File tree

CHANGELOG.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [1.3.0]
11+
12+
### Added
13+
- Concrete configuration dataclasses: `ModelConfig`, `RunConfig` (abstract bases), `SARIXModelConfig`, `SARIXRunConfig`, `GBQRModelConfig`, `GBQRRunConfig`
14+
- `SARIXFourierModelConfig` dataclass with `fourier_K` and `fourier_pooling` fields
15+
- Wave feature fields on `GBQRModelConfig` (`use_directional_waves`, `wave_directions`, etc.), disabled by default
16+
- Enum types: `DataSource`, `Disease`, `PowerTransform`, `PoolingStrategy`
17+
- Docstrings for `ModelConfig` and `RunConfig` base classes
18+
- All config types exported from `idmodels.__init__`
19+
20+
### Changed
21+
- **Breaking**: `model_config.sources` now expects `list[DataSource]` instead of `list[str]`
22+
- **Breaking**: `model_config.power_transform` now expects `PowerTransform` instead of `str`
23+
- **Breaking**: `model_config.disease` now expects `Disease` instead of `str`
24+
- Source validation in `sarix.py` and `gbqr.py` uses `DataSource` enums and set operations instead of `np.isin` with string arrays
25+
- All tests use concrete config dataclasses instead of `SimpleNamespace`
26+
- Updated `directional_wave_features.md` examples to use config dataclasses
27+
28+
### Removed
29+
- `SimpleNamespace` usage throughout tests and documentation
30+
- `model_class` field from model configurations (implied by the dataclass type)
31+
- `num_bags` from `GBQRRunConfig` test helper (it is a `GBQRModelConfig` field)
32+
- `save_feat_importance` from `SARIXRunConfig` test helper (not a SARIX field)
33+
1034
## [1.1.0] - 2025-12-08
1135

1236
### Added
@@ -71,7 +95,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
7195
### Changed
7296
- Updated to latest iddata API
7397

74-
[Unreleased]: https://github.com/reichlab/idmodels/compare/v1.1.0...HEAD
98+
[Unreleased]: https://github.com/reichlab/idmodels/compare/v1.3.0...HEAD
99+
[1.3.0]: https://github.com/reichlab/idmodels/compare/v1.1.0...v1.3.0
75100
[1.1.0]: https://github.com/reichlab/idmodels/compare/v1.0.0...v1.1.0
76101
[1.0.0]: https://github.com/reichlab/idmodels/compare/v0.1.0...v1.0.0
77102
[0.1.0]: https://github.com/reichlab/idmodels/compare/v0.0.1...v0.1.0

docs/directional_wave_features.md

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,13 @@ For each location and time point, the following features are generated:
4343
Directional wave features are **disabled by default** for backwards compatibility. To enable them, add the following parameters to your `model_config`:
4444

4545
```python
46-
from types import SimpleNamespace
46+
from idmodels.config import DataSource, GBQRModelConfig, PowerTransform
4747

48-
model_config = SimpleNamespace(
49-
# ... existing parameters ...
48+
model_config = GBQRModelConfig(
49+
model_name = "gbqr_with_waves",
50+
sources = [DataSource.NHSN],
51+
fit_locations_separately = False,
52+
power_transform = PowerTransform.FOURTH_ROOT,
5053

5154
# Directional wave features (disabled by default)
5255
use_directional_waves = True, # Set to True to enable
@@ -111,8 +114,8 @@ model_config = SimpleNamespace(
111114

112115
### Minimal Configuration (4 cardinal directions)
113116
```python
114-
model_config = SimpleNamespace(
115-
# ... other params ...
117+
model_config = GBQRModelConfig(
118+
# ... required base params ...
116119
use_directional_waves = True,
117120
wave_directions = ['N', 'S', 'E', 'W']
118121
)
@@ -121,8 +124,8 @@ Generates: 4 base + 4 aggregate + (4+1)×2 lags = **14 features**
121124

122125
### Standard Configuration (8 directions)
123126
```python
124-
model_config = SimpleNamespace(
125-
# ... other params ...
127+
model_config = GBQRModelConfig(
128+
# ... required base params ...
126129
use_directional_waves = True,
127130
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
128131
wave_temporal_lags = [1, 2]
@@ -132,8 +135,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags = **27 features**
132135

133136
### Maximum Information (all options)
134137
```python
135-
model_config = SimpleNamespace(
136-
# ... other params ...
138+
model_config = GBQRModelConfig(
139+
# ... required base params ...
137140
use_directional_waves = True,
138141
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
139142
wave_temporal_lags = [1, 2],
@@ -147,8 +150,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags + (8+1) velocity = **36 features
147150
### Hypothesis-Driven (specific directions)
148151
```python
149152
# If you suspect disease spreads along NE-SW axis
150-
model_config = SimpleNamespace(
151-
# ... other params ...
153+
model_config = GBQRModelConfig(
154+
# ... required base params ...
152155
use_directional_waves = True,
153156
wave_directions = ['NE', 'SW'],
154157
wave_temporal_lags = [1, 2, 3], # Longer lags for slower spread
@@ -240,22 +243,21 @@ The implementation includes validation that warns about:
240243
## Example: Complete GBQR Configuration
241244

242245
```python
243-
from types import SimpleNamespace
246+
import datetime
247+
from pathlib import Path
248+
from idmodels.config import DataSource, Disease, GBQRModelConfig, GBQRRunConfig, PowerTransform
244249
from idmodels.gbqr import GBQRModel
245250

246251
# Model configuration with directional wave features
247-
model_config = SimpleNamespace(
248-
model_class = "gbqr",
252+
model_config = GBQRModelConfig(
249253
model_name = "gbqr_with_waves",
250-
251-
# Standard GBQR parameters
254+
sources = [DataSource.NHSN],
255+
fit_locations_separately = False,
256+
power_transform = PowerTransform.FOURTH_ROOT,
252257
incl_level_feats = True,
253258
num_bags = 10,
254259
bag_frac_samples = 0.7,
255260
reporting_adj = False,
256-
sources = ["nhsn"],
257-
fit_locations_separately = False,
258-
power_transform = "4rt",
259261

260262
# Directional wave features
261263
use_directional_waves = True,
@@ -267,16 +269,17 @@ model_config = SimpleNamespace(
267269
)
268270

269271
# Run configuration
270-
run_config = SimpleNamespace(
271-
disease = "flu",
272+
run_config = GBQRRunConfig(
273+
disease = Disease.FLU,
272274
ref_date = datetime.date(2024, 1, 6),
273-
output_root = "output/",
274-
artifact_store_root = "artifacts/",
275-
save_feat_importance = True,
276-
locations = None, # All locations
275+
output_root = Path("output/"),
276+
artifact_store_root = Path("artifacts/"),
277277
max_horizon = 4,
278+
states = ["US", "01", "06", "13", "36", "48"],
279+
hsas = [],
278280
q_levels = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975],
279-
q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"]
281+
q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"],
282+
save_feat_importance = True
280283
)
281284

282285
# Run model

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ dependencies = [
1414
"iddata @ git+https://github.com/reichlab/iddata",
1515
"lightgbm",
1616
"numpy",
17-
"pandas",
18-
"sarix @ git+https://github.com/reichlab/sarix",
17+
"pandas~=2.0", # pandas 3.0 breaks compatibility; remove cap once validated
18+
"sarix @ git+https://github.com/reichlab/sarix?rev=35eea2379a9790e0457b1aed41d13509e5d5056f",
1919
"scikit-learn",
2020
"tqdm",
2121
"timeseriesutils @ git+https://github.com/reichlab/timeseriesutils"

src/idmodels/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,11 @@
1-
__version__ = "1.2.0"
1+
from idmodels.config import (DataSource, Disease, GBQRModelConfig, GBQRRunConfig, PoolingStrategy, PowerTransform,
2+
SARIXFourierModelConfig, SARIXModelConfig, SARIXRunConfig)
3+
from idmodels.gbqr import GBQRModel
4+
from idmodels.sarix import SARIXFourierModel, SARIXModel
25

6+
7+
__all__ = ['DataSource', 'Disease', 'GBQRModel', 'GBQRModelConfig', 'GBQRRunConfig', 'PoolingStrategy',
8+
'PowerTransform', 'SARIXFourierModel', 'SARIXFourierModelConfig', 'SARIXModel', 'SARIXModelConfig',
9+
'SARIXRunConfig']
10+
11+
__version__ = '1.3.0'

src/idmodels/config.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import datetime
2+
from abc import ABC
3+
from dataclasses import dataclass, field
4+
from enum import Enum
5+
from pathlib import Path
6+
7+
8+
class DataSource(str, Enum):
9+
NHSN = "nhsn"
10+
NSSP = "nssp"
11+
FLUSURVNET = "flusurvnet"
12+
ILINET = "ilinet"
13+
14+
15+
class Disease(str, Enum):
16+
FLU = "flu"
17+
COVID = "covid"
18+
19+
20+
class PowerTransform(str, Enum):
21+
FOURTH_ROOT = "4rt"
22+
23+
24+
class PoolingStrategy(str, Enum):
25+
NONE = "none"
26+
SHARED = "shared"
27+
28+
29+
@dataclass
30+
class ModelConfig(ABC):
31+
"""
32+
Abstract base for model configuration.
33+
34+
Holds settings that describe *what* model to run and how it processes data (sources, transforms, pooling).
35+
Not instantiated directly - use :class:`SARIXModelConfig` or :class:`GBQRModelConfig`.
36+
"""
37+
38+
model_name: str
39+
sources: list[DataSource]
40+
fit_locations_separately: bool
41+
power_transform: PowerTransform
42+
43+
def __post_init__(self):
44+
if type(self) is ModelConfig:
45+
raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig")
46+
47+
48+
@dataclass
49+
class RunConfig(ABC):
50+
"""
51+
Abstract base for run configuration.
52+
53+
Holds settings that describe a single execution: which disease, which locations, output paths, quantile levels, etc.
54+
Not instantiated directly - use :class:`SARIXRunConfig` or :class:`GBQRRunConfig`.
55+
"""
56+
57+
disease: Disease
58+
ref_date: datetime.date
59+
output_root: Path
60+
artifact_store_root: Path | None
61+
max_horizon: int
62+
states: list[str]
63+
hsas: list[str]
64+
q_levels: list[float]
65+
q_labels: list[str]
66+
67+
def __post_init__(self):
68+
if type(self) is RunConfig:
69+
raise TypeError("RunConfig is abstract - use SARIXRunConfig or GBQRRunConfig")
70+
71+
72+
@dataclass
73+
class SARIXModelConfig(ModelConfig):
74+
p: int = 0
75+
P: int = 0
76+
d: int = 0
77+
D: int = 0
78+
season_period: int = 1
79+
theta_pooling: PoolingStrategy = PoolingStrategy.NONE
80+
sigma_pooling: PoolingStrategy = PoolingStrategy.NONE
81+
x: list = field(default_factory=list)
82+
83+
84+
@dataclass
85+
class SARIXFourierModelConfig(SARIXModelConfig):
86+
fourier_K: int = 1
87+
fourier_pooling: PoolingStrategy = PoolingStrategy.NONE
88+
89+
90+
@dataclass
91+
class GBQRModelConfig(ModelConfig):
92+
incl_level_feats: bool = True
93+
num_bags: int = 100
94+
bag_frac_samples: float = 0.7
95+
reporting_adj: bool = False
96+
# directional wave features (disabled by default)
97+
use_directional_waves: bool = False
98+
wave_directions: list[str] = field(default_factory=lambda: ["N", "NE", "E", "SE", "S", "SW", "W", "NW"])
99+
wave_temporal_lags: list[int] = field(default_factory=lambda: [1, 2])
100+
wave_max_distance_km: float = 1000.0
101+
wave_include_velocity: bool = False
102+
wave_include_aggregate: bool = True
103+
104+
105+
@dataclass
106+
class SARIXRunConfig(RunConfig):
107+
num_warmup: int = 2000
108+
num_samples: int = 2000
109+
num_chains: int = 1
110+
111+
112+
@dataclass
113+
class GBQRRunConfig(RunConfig):
114+
save_feat_importance: bool = False

src/idmodels/gbqr.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from iddata.loader import DiseaseDataLoader
77
from tqdm.autonotebook import tqdm
88

9+
from idmodels.config import DataSource, Disease, PowerTransform
910
from idmodels.preprocess import create_directional_wave_features, create_features_and_targets
1011
from idmodels.utils import build_save_path
1112

@@ -31,22 +32,22 @@ def run(self, run_config):
3132
ilinet_kwargs = {"scale_to_positive": False}
3233
flusurvnet_kwargs = {"burden_adj": False}
3334

34-
valid_sources = ["flusurvnet", "nhsn", "ilinet", "nssp"]
35-
if not np.isin(np.array(self.model_config.sources), valid_sources).all():
35+
valid_sources = {DataSource.FLUSURVNET, DataSource.NHSN, DataSource.ILINET, DataSource.NSSP}
36+
if not set(self.model_config.sources) <= valid_sources:
3637
raise ValueError("For GBQR, the only supported data sources are 'nhsn', 'flusurvnet', 'ilinet', or 'nssp'.")
37-
38+
3839
# Check if both nhsn and nssp data are included as sources
39-
if all(src in self.model_config.sources for src in ["nhsn", "nssp"]):
40+
if (DataSource.NHSN in self.model_config.sources) and (DataSource.NSSP in self.model_config.sources):
4041
raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")
41-
42+
4243
fdl = DiseaseDataLoader()
43-
if "nhsn" in self.model_config.sources:
44+
if DataSource.NHSN in self.model_config.sources:
4445
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
4546
ilinet_kwargs=ilinet_kwargs,
4647
flusurvnet_kwargs=flusurvnet_kwargs,
4748
sources=self.model_config.sources,
4849
power_transform=self.model_config.power_transform)
49-
elif "nssp" in self.model_config.sources:
50+
elif DataSource.NSSP in self.model_config.sources:
5051
df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
5152
ilinet_kwargs=ilinet_kwargs,
5253
flusurvnet_kwargs=flusurvnet_kwargs,
@@ -62,20 +63,20 @@ def run(self, run_config):
6263
df["unique_id"] = df["agg_level"] + df["location"]
6364

6465
# augment data with features and target values
65-
if run_config.disease == "flu":
66+
if run_config.disease == Disease.FLU:
6667
init_feats = ["inc_trans_cs", "season_week", "log_pop"]
67-
elif run_config.disease == "covid":
68+
elif run_config.disease == Disease.COVID:
6869
init_feats = ["inc_trans_cs", "log_pop"]
6970

7071
# Create directional wave features if enabled
71-
if hasattr(self.model_config, "use_directional_waves") and self.model_config.use_directional_waves:
72+
if self.model_config.use_directional_waves:
7273
wave_config = {
7374
"enabled": True,
74-
"directions": getattr(self.model_config, "wave_directions", ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]),
75-
"temporal_lags": getattr(self.model_config, "wave_temporal_lags", [1, 2]),
76-
"max_distance_km": getattr(self.model_config, "wave_max_distance_km", 1000),
77-
"include_velocity": getattr(self.model_config, "wave_include_velocity", False),
78-
"include_aggregate": getattr(self.model_config, "wave_include_aggregate", True)
75+
"directions": self.model_config.wave_directions,
76+
"temporal_lags": self.model_config.wave_temporal_lags,
77+
"max_distance_km": self.model_config.wave_max_distance_km,
78+
"include_velocity": self.model_config.wave_include_velocity,
79+
"include_aggregate": self.model_config.wave_include_aggregate,
7980
}
8081
df, wave_feat_names = create_directional_wave_features(df, wave_config)
8182
init_feats = init_feats + wave_feat_names
@@ -87,7 +88,7 @@ def run(self, run_config):
8788
curr_feat_names=init_feats)
8889

8990
# keep only rows that are in-season
90-
if run_config.disease == "flu":
91+
if run_config.disease == Disease.FLU:
9192
df = df.query("season_week >= 5 and season_week <= 45")
9293

9394
# "test set" df used to generate look-ahead predictions
@@ -176,12 +177,10 @@ def _train_gbq_and_predict(self, run_config,
176177
# build data frame with predictions on the original scale
177178
preds_df["inc_trans_cs_target_hat"] = preds_df["inc_trans_cs"] + preds_df["delta_hat"]
178179
preds_df["inc_trans_target_hat"] = (preds_df["inc_trans_cs_target_hat"] + preds_df["inc_trans_center_factor"]) * (preds_df["inc_trans_scale_factor"] + 0.01)
179-
if self.model_config.power_transform == "4rt":
180+
if self.model_config.power_transform == PowerTransform.FOURTH_ROOT:
180181
inv_power = 4
181-
elif self.model_config.power_transform is None:
182-
inv_power = 1
183182
else:
184-
raise ValueError('unsupported power_transform: must be "4rt" or None')
183+
raise ValueError(f"unsupported power_transform: {self.model_config.power_transform!r}")
185184

186185
preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4)
187186

0 commit comments

Comments
 (0)