| 1 | +# Directional Wave Features for GBQR Model |
| 2 | + |
| 3 | +## Overview |
| 4 | + |
| 5 | +Directional wave features capture spatial-temporal patterns in disease spread by computing distance-weighted averages of neighboring locations' incidence in specified directions (N, NE, E, SE, S, SW, W, NW). |
| 6 | + |
| 7 | +These features allow the GBQR model to learn how disease "waves" propagate geographically over time, improving forecast accuracy when spatial spread patterns are important. |
| 8 | + |
| 9 | +## Motivation |
| 10 | + |
| 11 | +Traditional forecasting models treat each location independently or use simple spatial averaging. Directional wave features enable the model to: |
| 12 | + |
| 13 | +1. **Capture directional spread patterns**: Disease may spread preferentially in certain directions (e.g., following travel corridors, climate patterns) |
| 14 | +2. **Learn wave propagation speed**: By including temporal lags, the model can learn how long it takes for a wave to travel between locations |
| 15 | +3. **Distinguish between spreading and receding waves**: Velocity features (current value minus the previous week's value) indicate whether incidence among neighbors in a given direction is rising or falling |
| 16 | + |
| 17 | +## Feature Types |
| 18 | + |
| 19 | +For each location and time point, the following features are generated: |
| 20 | + |
| 21 | +### 1. Base Directional Features |
| 22 | +- `inc_trans_cs_wave_N`: Distance-weighted average of northern neighbors' incidence |
| 23 | +- `inc_trans_cs_wave_NE`: Distance-weighted average of northeastern neighbors' incidence |
| 24 | +- `inc_trans_cs_wave_E`: Distance-weighted average of eastern neighbors' incidence |
| 25 | +- ... (one for each specified direction) |
| 26 | + |
| 27 | +### 2. Aggregate Feature |
| 28 | +- `inc_trans_cs_wave_avg`: Overall distance-weighted average of all neighbors (regardless of direction) |
| 29 | + |
| 30 | +### 3. Temporal Lag Features |
| 31 | +- `inc_trans_cs_wave_N_lag1`: Northern neighbors' incidence from 1 week ago |
| 32 | +- `inc_trans_cs_wave_N_lag2`: Northern neighbors' incidence from 2 weeks ago |
| 33 | +- ... (for each direction and lag) |
| 34 | + |
| 35 | +**Important**: `lag1` refers to time t-1, `lag2` refers to time t-2, etc. |
| 36 | + |
| 37 | +### 4. Velocity Features (optional) |
| 38 | +- `inc_trans_cs_wave_N_velocity`: Rate of change = current - lag1 |
| 39 | +- ... (one for each direction) |
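
The naming pattern above can be enumerated programmatically. The following is a minimal sketch, not the package's own code: `wave_feature_names` is a hypothetical helper, and it assumes that the aggregate feature also receives lagged and velocity variants (named `..._avg_lag1`, `..._avg_velocity`), which is what the (8+1) counts in the Example Configurations section imply.

```python
def wave_feature_names(directions, lags=(1, 2), include_velocity=False,
                       include_aggregate=True, prefix="inc_trans_cs_wave"):
    """List the wave feature columns implied by a configuration (illustrative)."""
    bases = [f"{prefix}_{d}" for d in directions]
    if include_aggregate:
        bases.append(f"{prefix}_avg")
    names = list(bases)
    # One lagged copy of every base feature per requested lag.
    names += [f"{b}_lag{lag}" for lag in lags for b in bases]
    if include_velocity:
        # Assumption: the aggregate gets a velocity feature too, matching the
        # (8+1) velocity count in the "Maximum Information" example.
        names += [f"{b}_velocity" for b in bases]
    return names


ALL_DIRECTIONS = ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]
default_names = wave_feature_names(ALL_DIRECTIONS)
full_names = wave_feature_names(ALL_DIRECTIONS, include_velocity=True)
print(len(default_names), len(full_names))
```

Listing the expected columns this way is a quick check that a trained model actually received the features a configuration was meant to produce.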
| 40 | + |
| 41 | +## Configuration |
| 42 | + |
| 43 | +Directional wave features are **disabled by default** for backwards compatibility. To enable them, add the following parameters to your `model_config`: |
| 44 | + |
| 45 | +```python |
| 46 | +from types import SimpleNamespace |
| 47 | + |
| 48 | +model_config = SimpleNamespace( |
| 49 | + # ... existing parameters ... |
| 50 | + |
| 51 | + # Directional wave features (disabled by default) |
| 52 | + use_directional_waves = True, # Set to True to enable |
| 53 | + |
| 54 | + # Which directions to compute (subset of: N, NE, E, SE, S, SW, W, NW) |
| 55 | + wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], # Default: all 8 |
| 56 | + |
| 57 | + # Temporal lags to include (lag1 = t-1, lag2 = t-2) |
| 58 | + wave_temporal_lags = [1, 2], # Default: [1, 2] |
| 59 | + |
| 60 | + # Maximum distance (km) to consider as neighbor |
| 61 | + wave_max_distance_km = 1000, # Default: 1000 |
| 62 | + |
| 63 | + # Include velocity (rate of change) features |
| 64 | + wave_include_velocity = False, # Default: False |
| 65 | + |
| 66 | + # Include aggregate weighted average feature |
| 67 | + wave_include_aggregate = True # Default: True |
| 68 | +) |
| 69 | +``` |
| 70 | + |
| 71 | +### Configuration Parameters Explained |
| 72 | + |
| 73 | +#### `use_directional_waves` (bool, default: False) |
| 74 | +- Master switch to enable/disable directional wave features |
| 75 | +- Must be set to `True` to generate wave features |
| 76 | + |
| 77 | +#### `wave_directions` (list of str, default: ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']) |
| 78 | +- Which directions to compute features for |
| 79 | +- Valid directions: N, NE, E, SE, S, SW, W, NW |
| 80 | +- Each direction has a 45° cone (±22.5° around center) |
| 81 | +- Examples: |
| 82 | + - `['N', 'S', 'E', 'W']` - Just cardinal directions (4 features) |
| 83 | + - `['NE', 'SW']` - Just diagonal directions (2 features) |
| 84 | + |
| 85 | +#### `wave_temporal_lags` (list of int, default: [1, 2]) |
| 86 | +- Which temporal lags to include |
| 87 | +- `lag1` means t-1 (last week), `lag2` means t-2 (two weeks ago) |
| 88 | +- Example: `[1, 2, 3]` includes 1, 2, and 3 week lags |
| 89 | + |
| 90 | +#### `wave_max_distance_km` (float, default: 1000) |
| 91 | +- Maximum distance (kilometers) to consider a location as a neighbor |
| 92 | +- Only locations within this distance are included in directional averages |
| 93 | +- Larger values include more distant neighbors (slower computation) |
| 94 | +- Typical values: |
| 95 | + - 500-1000 km for state-level analysis (immediate neighbors) |
| 96 | + - 2000-3000 km for regional patterns |
| 97 | + - 5000+ km for continent-wide patterns |
| 98 | + |
| 99 | +#### `wave_include_velocity` (bool, default: False) |
| 100 | +- Whether to include velocity features (rate of change) |
| 101 | +- Velocity = current - lag1 |
| 102 | +- Captures whether neighboring incidence is rising or falling in each direction |
| 103 | +- Adds roughly one feature per direction (about a third more features with the default two lags) |
| 104 | + |
| 105 | +#### `wave_include_aggregate` (bool, default: True) |
| 106 | +- Whether to include overall weighted average (all neighbors, any direction) |
| 107 | +- Provides general spatial context independent of direction |
| 108 | +- Recommended to keep enabled |
| 109 | + |
| 110 | +## Example Configurations |
| 111 | + |
| 112 | +### Minimal Configuration (4 cardinal directions) |
| 113 | +```python |
| 114 | +model_config = SimpleNamespace( |
| 115 | + # ... other params ... |
| 116 | + use_directional_waves = True, |
| 117 | + wave_directions = ['N', 'S', 'E', 'W'] |
| 118 | +) |
| 119 | +``` |
| 120 | +Generates (with the default aggregate and lags): 4 base + 1 aggregate + (4+1)×2 lags = **15 features** |
| 121 | + |
| 122 | +### Standard Configuration (8 directions) |
| 123 | +```python |
| 124 | +model_config = SimpleNamespace( |
| 125 | + # ... other params ... |
| 126 | + use_directional_waves = True, |
| 127 | + wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], |
| 128 | + wave_temporal_lags = [1, 2] |
| 129 | +) |
| 130 | +``` |
| 131 | +Generates: 8 base + 1 aggregate + (8+1)×2 lags = **27 features** |
| 132 | + |
| 133 | +### Maximum Information (all options) |
| 134 | +```python |
| 135 | +model_config = SimpleNamespace( |
| 136 | + # ... other params ... |
| 137 | + use_directional_waves = True, |
| 138 | + wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], |
| 139 | + wave_temporal_lags = [1, 2], |
| 140 | + wave_max_distance_km = 2000, |
| 141 | + wave_include_velocity = True, |
| 142 | + wave_include_aggregate = True |
| 143 | +) |
| 144 | +``` |
| 145 | +Generates: 8 base + 1 aggregate + (8+1)×2 lags + (8+1) velocity = **36 features** |
| 146 | + |
| 147 | +### Hypothesis-Driven (specific directions) |
| 148 | +```python |
| 149 | +# If you suspect disease spreads along NE-SW axis |
| 150 | +model_config = SimpleNamespace( |
| 151 | + # ... other params ... |
| 152 | + use_directional_waves = True, |
| 153 | + wave_directions = ['NE', 'SW'], |
| 154 | + wave_temporal_lags = [1, 2, 3], # Longer lags for slower spread |
| 155 | + wave_max_distance_km = 1500 |
| 156 | +) |
| 157 | +``` |
| 158 | + |
| 159 | +## Technical Details |
| 160 | + |
| 161 | +### Directional Cone Definition |
| 162 | +Each direction captures neighbors within a 45° cone: |
| 163 | +- **N (North)**: 0° ±22.5° (337.5° to 22.5°) |
| 164 | +- **NE (Northeast)**: 45° ±22.5° (22.5° to 67.5°) |
| 165 | +- **E (East)**: 90° ±22.5° (67.5° to 112.5°) |
| 166 | +- ... and so on |
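
The cone assignment can be sketched as a small bucketing function. This is an illustrative helper, not the function in `spatial_utils.py`; in particular, treating each cone's lower boundary as inclusive (so that exactly 22.5° falls in NE) is an assumption.

```python
DIRECTIONS = ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]


def bearing_to_direction(bearing_deg):
    """Map a compass bearing in degrees to one of eight 45-degree cones."""
    # Shift by half a cone so that N covers [337.5, 22.5), then bucket by 45.
    index = int(((bearing_deg % 360) + 22.5) // 45) % 8
    return DIRECTIONS[index]


print(bearing_to_direction(10))   # inside the northern cone
print(bearing_to_direction(200))  # inside the southern cone
```

The modulo at the start lets negative or wrapped bearings (for example -45° or 405°) land in the correct cone.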
| 167 | + |
| 168 | +### Distance Weighting |
| 169 | +Inverse distance weighting is used: |
| 170 | +``` |
| 171 | +weight = 1 / distance |
| 172 | +weighted_average = Σ(weight × neighbor_value) / Σ(weight) |
| 173 | +``` |
| 174 | + |
| 175 | +Closer neighbors have more influence on the feature value. |
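
As a worked instance of the formula, here is a minimal sketch in plain Python. The function name `idw_average` and the handling of zero distances and missing neighbor values are assumptions made for illustration.

```python
import math


def idw_average(neighbors):
    """Inverse-distance-weighted mean of (distance_km, value) pairs.

    Returns NaN when no neighbor contributes.
    """
    weighted_sum = 0.0
    weight_total = 0.0
    for distance_km, value in neighbors:
        if distance_km <= 0 or math.isnan(value):
            continue  # assumption: skip zero-distance and missing neighbors
        weight = 1.0 / distance_km
        weighted_sum += weight * value
        weight_total += weight
    return weighted_sum / weight_total if weight_total > 0 else math.nan


# A neighbor 100 km away with value 10 and one 200 km away with value 40:
# weights 0.01 and 0.005 give (0.1 + 0.2) / 0.015 = 20.
print(idw_average([(100.0, 10.0), (200.0, 40.0)]))
```

The unweighted mean of the two values would be 25; the weighted result is pulled toward the closer neighbor.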
| 176 | + |
| 177 | +### Lag Semantics |
| 178 | +- **Base feature** (no lag suffix): Uses current time t |
| 179 | +- **lag1**: Uses time t-1 (one week ago) |
| 180 | +- **lag2**: Uses time t-2 (two weeks ago) |
| 181 | + |
| 182 | +This allows the model to learn patterns like: "If northern neighbors had high incidence last week (lag1), expect it here this week." |
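
The lag and velocity definitions can be illustrated on a single location's weekly series. This sketch uses plain lists and hypothetical helper names; in a table holding several locations, the shift would need to be applied within each location separately so that values do not leak across locations.

```python
import math


def lag(series, k):
    """Shift a weekly series back by k steps; the first k entries have no history."""
    return ([math.nan] * k + list(series))[:len(series)]


def velocity(series):
    """First difference: value at time t minus value at time t-1."""
    return [current - previous for current, previous in zip(series, lag(series, 1))]


wave_n = [1.0, 2.0, 4.0, 3.0]   # northern-neighbor feature at t = 0..3
wave_n_lag1 = lag(wave_n, 1)    # [nan, 1.0, 2.0, 4.0]
wave_n_lag2 = lag(wave_n, 2)    # [nan, nan, 1.0, 2.0]
wave_n_velocity = velocity(wave_n)  # [nan, 1.0, 2.0, -1.0]
```

A positive velocity means the northern neighbors' incidence rose since last week; the final value of -1.0 marks the first week of decline.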
| 183 | + |
| 184 | +### Missing Values |
| 185 | +- If a location has no neighbors in a direction (within max_distance_km), the feature value is NaN |
| 186 | +- LightGBM treats NaN as a missing value by default and handles it during training, so no imputation step is required |
| 187 | +- Edge locations (e.g., coastal states) may have missing values for certain directions |
| 188 | + |
| 189 | +## Location Support |
| 190 | + |
| 191 | +Currently supported: |
| 192 | +- **State-level** (agg_level='state'): US states, DC, PR, and national level |
| 193 | + |
| 194 | +**Data Source:** State centroids are loaded from `src/idmodels/data/state_centroids.csv`, which contains geographic centroids computed from US Census Bureau TIGER/Line shapefiles. See `src/idmodels/data/README.md` for detailed source information. |
| 195 | + |
| 196 | +To add support for other aggregation levels (county, HSA, etc.): |
| 197 | +1. Create a CSV file (e.g., `county_centroids.csv`) with columns: `fips`, `name`, `latitude`, `longitude` |
| 198 | +2. Place it in `src/idmodels/data/` |
| 199 | +3. Update `_load_state_centroids()` function in `src/idmodels/spatial_utils.py` to support the new level |
| 200 | +4. Document the data source in `src/idmodels/data/README.md` |
| 201 | + |
| 202 | +## Performance Considerations |
| 203 | + |
| 204 | +### Computational Cost |
| 205 | +- Scales O(n²) with number of locations (n) |
| 206 | +- For 50 states: ~2,500 distance calculations (precomputed) |
| 207 | +- Feature computation is done once per training run |
| 208 | + |
| 209 | +### Feature Count Impact |
| 210 | +Feature count depends on configuration: |
| 211 | +- Base: n_directions + (1 if aggregate else 0) |
| 212 | +- With lags: base × (1 + len(temporal_lags)) |
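| 213 | +- With velocity: adds one velocity feature per base feature, as counted in the Maximum Information example (27 + 9 = 36, about a third more with two lags) |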
| 214 | + |
| 215 | +More features → longer training time, but potentially better predictions |
| 216 | + |
| 217 | +### Recommendations |
| 218 | +- Start with default configuration (8 directions, 2 lags) |
| 219 | +- Experiment with fewer directions if training is slow |
| 220 | +- Use `wave_include_velocity=False` unless you have evidence of acceleration patterns |
| 221 | + |
| 222 | +## Interpretation |
| 223 | + |
| 224 | +### Feature Importance |
| 225 | +After training, you can examine feature importance to understand: |
| 226 | +- Which directions are most predictive (e.g., is NE spread more important than SW?) |
| 227 | +- Whether lags matter (are lag1 features more important than current?) |
| 228 | +- Whether velocity features add value |
| 229 | + |
| 230 | +### Example Interpretations |
| 231 | +- **High importance for `inc_trans_cs_wave_N_lag1`**: Disease tends to arrive from the north with a 1-week delay |
| 232 | +- **High importance for `inc_trans_cs_wave_avg`**: General spatial clustering matters more than direction |
| 233 | +- **High importance for `inc_trans_cs_wave_NE_velocity`**: The week-over-week change in northeastern neighbors' incidence is predictive |
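
To compare directions rather than individual columns, importances can be summed over the base, lag and velocity variants of each direction. The sketch below assumes the importances are available as a plain name-to-score mapping; the format written by `save_feat_importance` is not described here, and the numbers are made up for illustration.

```python
from collections import defaultdict


def importance_by_direction(importances, prefix="inc_trans_cs_wave_"):
    """Sum importance over base, lag and velocity variants of each direction."""
    totals = defaultdict(float)
    for name, score in importances.items():
        if not name.startswith(prefix):
            continue  # not a wave feature
        direction = name[len(prefix):].split("_")[0]  # "N", "NE", ..., or "avg"
        totals[direction] += score
    # Most important direction first.
    return dict(sorted(totals.items(), key=lambda item: item[1], reverse=True))


# Hypothetical importance values, for illustration only.
example_importances = {
    "inc_trans_cs_wave_N": 5.0,
    "inc_trans_cs_wave_N_lag1": 12.0,
    "inc_trans_cs_wave_avg": 4.0,
    "inc_trans_cs_wave_S_velocity": 1.0,
    "some_other_feature": 30.0,
}
ranked = importance_by_direction(example_importances)
print(ranked)
```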
| 234 | + |
| 235 | +## Warnings and Validation |
| 236 | + |
| 237 | +The implementation includes validation that warns about: |
| 238 | +- **Opposite directions included**: If both N and S (or E and W, etc.) are included, they may be correlated in datasets with uniform spatial patterns. However, in typical epidemic scenarios with directional spread, opposite directions provide independent information. Tree-based models like LightGBM are also robust to multicollinearity, so this warning is informational rather than critical. |
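
A check of this kind could look like the sketch below. The function names and the warning text are hypothetical and are not taken from the implementation.

```python
import warnings

# Each opposite pair is listed once; the reverse pairs are implied.
OPPOSITES = {"N": "S", "NE": "SW", "E": "W", "SE": "NW"}


def opposite_pairs(directions):
    """Return the opposite-direction pairs that are both present."""
    chosen = set(directions)
    return [(a, b) for a, b in OPPOSITES.items() if a in chosen and b in chosen]


def warn_on_opposites(directions):
    """Emit an informational warning if opposite directions are configured."""
    pairs = opposite_pairs(directions)
    if pairs:
        warnings.warn(
            f"Opposite directions included: {pairs}; their features may be correlated."
        )
    return pairs


print(opposite_pairs(["N", "S", "E"]))
```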
| 239 | + |
| 240 | +## Example: Complete GBQR Configuration |
| 241 | + |
| 242 | +```python |
| 243 | +import datetime |
| | +from types import SimpleNamespace |
| 244 | +from idmodels.gbqr import GBQRModel |
| 245 | + |
| 246 | +# Model configuration with directional wave features |
| 247 | +model_config = SimpleNamespace( |
| 248 | + model_class = "gbqr", |
| 249 | + model_name = "gbqr_with_waves", |
| 250 | + |
| 251 | + # Standard GBQR parameters |
| 252 | + incl_level_feats = True, |
| 253 | + num_bags = 10, |
| 254 | + bag_frac_samples = 0.7, |
| 255 | + reporting_adj = False, |
| 256 | + sources = ["nhsn"], |
| 257 | + fit_locations_separately = False, |
| 258 | + power_transform = "4rt", |
| 259 | + |
| 260 | + # Directional wave features |
| 261 | + use_directional_waves = True, |
| 262 | + wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], |
| 263 | + wave_temporal_lags = [1, 2], |
| 264 | + wave_max_distance_km = 1500, |
| 265 | + wave_include_velocity = False, |
| 266 | + wave_include_aggregate = True |
| 267 | +) |
| 268 | + |
| 269 | +# Run configuration |
| 270 | +run_config = SimpleNamespace( |
| 271 | + disease = "flu", |
| 272 | + ref_date = datetime.date(2024, 1, 6), |
| 273 | + output_root = "output/", |
| 274 | + artifact_store_root = "artifacts/", |
| 275 | + save_feat_importance = True, |
| 276 | + locations = None, # All locations |
| 277 | + max_horizon = 4, |
| 278 | + q_levels = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975], |
| 279 | + q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"] |
| 280 | +) |
| 281 | + |
| 282 | +# Run model |
| 283 | +model = GBQRModel(model_config) |
| 284 | +model.run(run_config) |
| 285 | +``` |
| 286 | + |
| 287 | +## Backwards Compatibility |
| 288 | + |
| 289 | +The implementation is fully backwards compatible: |
| 290 | +- Disabled by default (`use_directional_waves = False`) |
| 291 | +- Existing configurations without wave parameters work unchanged |
| 292 | +- Uses `hasattr()` checks to gracefully handle missing attributes |
| 293 | + |
| 294 | +Old configurations will continue to work without modification. |
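
The fallback behavior can be sketched as follows. This version uses `getattr()` with a default, which has the same effect as a `hasattr()` check followed by attribute access; `WAVE_DEFAULTS` and `resolve_wave_config` are hypothetical names, and the default values are the ones listed in the Configuration section.

```python
from types import SimpleNamespace

WAVE_DEFAULTS = {
    "use_directional_waves": False,
    "wave_directions": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"],
    "wave_temporal_lags": [1, 2],
    "wave_max_distance_km": 1000,
    "wave_include_velocity": False,
    "wave_include_aggregate": True,
}


def resolve_wave_config(model_config):
    """Read wave settings, falling back to defaults for missing attributes."""
    return {
        key: getattr(model_config, key, default)
        for key, default in WAVE_DEFAULTS.items()
    }


old_config = SimpleNamespace(model_class="gbqr")  # no wave parameters at all
new_config = SimpleNamespace(use_directional_waves=True, wave_directions=["NE", "SW"])
print(resolve_wave_config(old_config)["use_directional_waves"])
print(resolve_wave_config(new_config)["wave_directions"])
```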
| 295 | + |
| 296 | +## Testing |
| 297 | + |
| 298 | +The implementation includes comprehensive tests: |
| 299 | + |
| 300 | +### Unit Tests |
| 301 | +- `tests/unit/test_spatial_utils.py`: Tests for spatial calculations (distance, bearing, neighbors) |
| 302 | +- `tests/unit/test_directional_wave_features.py`: Tests for feature generation logic |
| 303 | + |
| 304 | +### Integration Tests |
| 305 | +- `tests/integration/test_gbqr_wave_features.py`: End-to-end tests with realistic data |
| 306 | + |
| 307 | +Run tests with: |
| 308 | +```bash |
| 309 | +uv run pytest tests/unit/test_spatial_utils.py -v |
| 310 | +uv run pytest tests/unit/test_directional_wave_features.py -v |
| 311 | +uv run pytest tests/integration/test_gbqr_wave_features.py -v |
| 312 | +``` |
| 313 | + |
| 314 | +## References |
| 315 | + |
| 316 | +### Epidemiological Motivation |
| 317 | +- Spatial spread of infectious diseases often follows directional patterns |
| 318 | +- Travel corridors, population density gradients, and climate patterns create anisotropic spread |
| 319 | +- Historical examples: 1918 flu pandemic, COVID-19 spread in US |
| 320 | + |
| 321 | +### Implementation Details |
| 322 | +- Haversine distance formula for great circle distance |
| 323 | +- Bearing calculation using spherical trigonometry |
| 324 | +- Inverse distance weighting for spatial interpolation |
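
For reference, the first two calculations can be written with the standard library alone. This is a sketch of the textbook formulas, not a copy of `spatial_utils.py`; the function names and the mean Earth radius of 6371 km are assumptions.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (assumed value)


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in kilometers between two points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    # Clamp to guard against rounding slightly above 1 for near-antipodal points.
    return 2 * EARTH_RADIUS_KM * math.asin(min(1.0, math.sqrt(a)))


def initial_bearing_deg(lat1, lon1, lat2, lon2):
    """Initial compass bearing from point 1 to point 2, in degrees [0, 360)."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dlam = math.radians(lon2 - lon1)
    y = math.sin(dlam) * math.cos(phi2)
    x = math.cos(phi1) * math.sin(phi2) - math.sin(phi1) * math.cos(phi2) * math.cos(dlam)
    return math.degrees(math.atan2(y, x)) % 360


# One degree of longitude along the equator is about 111.19 km, due east.
print(haversine_km(0.0, 0.0, 0.0, 1.0), initial_bearing_deg(0.0, 0.0, 0.0, 1.0))
```

The bearing is measured clockwise from north, so it can be compared directly against the cone boundaries listed under Directional Cone Definition.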
| 325 | + |
| 326 | +## Future Enhancements |
| 327 | + |
| 328 | +Potential extensions: |
| 329 | +1. **Additional aggregation levels**: County, HSA, HRR support |
| 330 | +2. **Custom distance weighting**: Gaussian kernel, exponential decay |
| 331 | +3. **Population-weighted features**: Weight by neighbor population, not just distance |
| 332 | +4. **Temporal smoothing**: Moving averages of wave features |
| 333 | +5. **Asymmetric cones**: Different cone widths for different directions |
| 334 | + |
| 335 | +## Troubleshooting |
| 336 | + |
| 337 | +### "Missing coordinates for locations" error |
| 338 | +- Ensure all locations in your data have entries in `STATE_CENTROIDS` (spatial_utils.py) |
| 339 | +- Check that `agg_level` in your data matches supported levels ('state') |
| 340 | + |
| 341 | +### Wave features are all NaN |
| 342 | +- Check `wave_max_distance_km` - may be too small |
| 343 | +- Verify location codes match those in `STATE_CENTROIDS` |
| 344 | +- Some edge locations (islands, coastal states) naturally have fewer neighbors |
| 345 | + |
| 346 | +### Training is slow |
| 347 | +- Reduce number of directions (try just ['N', 'S', 'E', 'W']) |
| 348 | +- Reduce `wave_max_distance_km` |
| 349 | +- Disable velocity features |
| 350 | +- Use `fit_locations_separately=True` in model_config |
| 351 | + |
| 352 | +## Contact |
| 353 | + |
| 354 | +For questions or issues, please file an issue on the GitHub repository. |