@@ -43,10 +43,13 @@ For each location and time point, the following features are generated:
4343Directional wave features are ** disabled by default** for backwards compatibility. To enable them, add the following parameters to your ` model_config ` :
4444
4545``` python
46- from types import SimpleNamespace
46+ from idmodels.config import DataSource, GBQRModelConfig, PowerTransform
4747
48- model_config = SimpleNamespace(
49- # ... existing parameters ...
48+ model_config = GBQRModelConfig(
49+ model_name = " gbqr_with_waves" ,
50+ sources = [DataSource.NHSN ],
51+ fit_locations_separately = False ,
52+ power_transform = PowerTransform.FOURTH_ROOT ,
5053
5154 # Directional wave features (disabled by default)
5255 use_directional_waves = True , # Set to True to enable
@@ -111,8 +114,8 @@ model_config = SimpleNamespace(
111114
112115### Minimal Configuration (4 cardinal directions)
113116``` python
114- model_config = SimpleNamespace (
115- # ... other params ...
117+ model_config = GBQRModelConfig (
118+ # ... required base params ...
116119 use_directional_waves = True ,
117120 wave_directions = [' N' , ' S' , ' E' , ' W' ]
118121)
@@ -121,8 +124,8 @@ Generates: 4 base + 4 aggregate + (4+1)×2 lags = **14 features**
121124
122125### Standard Configuration (8 directions)
123126``` python
124- model_config = SimpleNamespace (
125- # ... other params ...
127+ model_config = GBQRModelConfig (
128+ # ... required base params ...
126129 use_directional_waves = True ,
127130 wave_directions = [' N' , ' NE' , ' E' , ' SE' , ' S' , ' SW' , ' W' , ' NW' ],
128131 wave_temporal_lags = [1 , 2 ]
@@ -132,8 +135,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags = **27 features**
132135
133136### Maximum Information (all options)
134137``` python
135- model_config = SimpleNamespace (
136- # ... other params ...
138+ model_config = GBQRModelConfig (
139+ # ... required base params ...
137140 use_directional_waves = True ,
138141 wave_directions = [' N' , ' NE' , ' E' , ' SE' , ' S' , ' SW' , ' W' , ' NW' ],
139142 wave_temporal_lags = [1 , 2 ],
@@ -147,8 +150,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags + (8+1) velocity = **36 features
147150### Hypothesis-Driven (specific directions)
148151``` python
149152# If you suspect disease spreads along NE-SW axis
150- model_config = SimpleNamespace (
151- # ... other params ...
153+ model_config = GBQRModelConfig (
154+ # ... required base params ...
152155 use_directional_waves = True ,
153156 wave_directions = [' NE' , ' SW' ],
154157 wave_temporal_lags = [1 , 2 , 3 ], # Longer lags for slower spread
@@ -240,22 +243,21 @@ The implementation includes validation that warns about:
240243## Example: Complete GBQR Configuration
241244
242245``` python
243- from types import SimpleNamespace
246+ import datetime
247+ from pathlib import Path
248+ from idmodels.config import DataSource, Disease, GBQRModelConfig, GBQRRunConfig, PowerTransform
244249from idmodels.gbqr import GBQRModel
245250
246251# Model configuration with directional wave features
247- model_config = SimpleNamespace(
248- model_class = " gbqr" ,
252+ model_config = GBQRModelConfig(
249253 model_name = " gbqr_with_waves" ,
250-
251- # Standard GBQR parameters
254+ sources = [DataSource.NHSN ],
255+ fit_locations_separately = False ,
256+ power_transform = PowerTransform.FOURTH_ROOT ,
252257 incl_level_feats = True ,
253258 num_bags = 10 ,
254259 bag_frac_samples = 0.7 ,
255260 reporting_adj = False ,
256- sources = [" nhsn" ],
257- fit_locations_separately = False ,
258- power_transform = " 4rt" ,
259261
260262 # Directional wave features
261263 use_directional_waves = True ,
@@ -267,16 +269,17 @@ model_config = SimpleNamespace(
267269)
268270
269271# Run configuration
270- run_config = SimpleNamespace (
271- disease = " flu " ,
272+ run_config = GBQRRunConfig (
273+ disease = Disease. FLU ,
272274 ref_date = datetime.date(2024 , 1 , 6 ),
273- output_root = " output/" ,
274- artifact_store_root = " artifacts/" ,
275- save_feat_importance = True ,
276- locations = None , # All locations
275+ output_root = Path(" output/" ),
276+ artifact_store_root = Path(" artifacts/" ),
277277 max_horizon = 4 ,
278+ states = [" US" , " 01" , " 06" , " 13" , " 36" , " 48" ],
279+ hsas = [],
278280 q_levels = [0.025 , 0.10 , 0.25 , 0.50 , 0.75 , 0.90 , 0.975 ],
279- q_labels = [" 0.025" , " 0.1" , " 0.25" , " 0.5" , " 0.75" , " 0.9" , " 0.975" ]
281+ q_labels = [" 0.025" , " 0.1" , " 0.25" , " 0.5" , " 0.75" , " 0.9" , " 0.975" ],
282+ save_feat_importance = True
280283)
281284
282285# Run model
0 commit comments