Skip to content

Commit b870a07

Browse files
committed
adding directional wave features documentation
1 parent 20597ce commit b870a07

1 file changed

Lines changed: 354 additions & 0 deletions

File tree

docs/directional_wave_features.md

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
# Directional Wave Features for GBQR Model
2+
3+
## Overview
4+
5+
Directional wave features capture spatial-temporal patterns in disease spread by computing distance-weighted averages of neighboring locations' incidence in specified directions (N, NE, E, SE, S, SW, W, NW).
6+
7+
These features allow the GBQR model to learn how disease "waves" propagate geographically over time, improving forecast accuracy when spatial spread patterns are important.
8+
9+
## Motivation
10+
11+
Traditional forecasting models treat each location independently or use simple spatial averaging. Directional wave features enable the model to:
12+
13+
1. **Capture directional spread patterns**: Disease may spread preferentially in certain directions (e.g., following travel corridors, climate patterns)
14+
2. **Learn wave propagation speed**: By including temporal lags, the model can learn how long it takes for a wave to travel between locations
15+
3. **Distinguish between spreading and receding waves**: Velocity features capture acceleration/deceleration of spread
16+
17+
## Feature Types
18+
19+
For each location and time point, the following features are generated:
20+
21+
### 1. Base Directional Features
22+
- `inc_trans_cs_wave_N`: Distance-weighted average of northern neighbors' incidence
23+
- `inc_trans_cs_wave_NE`: Distance-weighted average of northeastern neighbors' incidence
24+
- `inc_trans_cs_wave_E`: Distance-weighted average of eastern neighbors' incidence
25+
- ... (one for each specified direction)
26+
27+
### 2. Aggregate Feature
28+
- `inc_trans_cs_wave_avg`: Overall distance-weighted average of all neighbors (regardless of direction)
29+
30+
### 3. Temporal Lag Features
31+
- `inc_trans_cs_wave_N_lag1`: Northern neighbors' incidence from 1 week ago
32+
- `inc_trans_cs_wave_N_lag2`: Northern neighbors' incidence from 2 weeks ago
33+
- ... (for each direction and lag)
34+
35+
**Important**: `lag1` refers to time t-1, `lag2` refers to time t-2, etc.
36+
37+
### 4. Velocity Features (optional)
38+
- `inc_trans_cs_wave_N_velocity`: Rate of change = current - lag1
39+
- ... (one for each direction)
40+
41+
## Configuration
42+
43+
Directional wave features are **disabled by default** for backwards compatibility. To enable them, add the following parameters to your `model_config`:
44+
45+
```python
46+
from types import SimpleNamespace
47+
48+
model_config = SimpleNamespace(
49+
# ... existing parameters ...
50+
51+
# Directional wave features (disabled by default)
52+
use_directional_waves = True, # Set to True to enable
53+
54+
# Which directions to compute (subset of: N, NE, E, SE, S, SW, W, NW)
55+
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], # Default: all 8
56+
57+
# Temporal lags to include (lag1 = t-1, lag2 = t-2)
58+
wave_temporal_lags = [1, 2], # Default: [1, 2]
59+
60+
# Maximum distance (km) to consider as neighbor
61+
wave_max_distance_km = 1000, # Default: 1000
62+
63+
# Include velocity (rate of change) features
64+
wave_include_velocity = False, # Default: False
65+
66+
# Include aggregate weighted average feature
67+
wave_include_aggregate = True # Default: True
68+
)
69+
```
70+
71+
### Configuration Parameters Explained
72+
73+
#### `use_directional_waves` (bool, default: False)
74+
- Master switch to enable/disable directional wave features
75+
- Must be set to `True` to generate wave features
76+
77+
#### `wave_directions` (list of str, default: ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'])
78+
- Which directions to compute features for
79+
- Valid directions: N, NE, E, SE, S, SW, W, NW
80+
- Each direction has a 45° cone (±22.5° around center)
81+
- Examples:
82+
- `['N', 'S', 'E', 'W']` - Just cardinal directions (4 features)
83+
- `['NE', 'SW']` - Just diagonal directions (2 features)
84+
85+
#### `wave_temporal_lags` (list of int, default: [1, 2])
86+
- Which temporal lags to include
87+
- `lag1` means t-1 (last week), `lag2` means t-2 (two weeks ago)
88+
- Example: `[1, 2, 3]` includes 1, 2, and 3 week lags
89+
90+
#### `wave_max_distance_km` (float, default: 1000)
91+
- Maximum distance (kilometers) to consider a location as a neighbor
92+
- Only locations within this distance are included in directional averages
93+
- Larger values include more distant neighbors (slower computation)
94+
- Typical values:
95+
- 500-1000 km for state-level analysis (immediate neighbors)
96+
- 2000-3000 km for regional patterns
97+
- 5000+ km for continent-wide patterns
98+
99+
#### `wave_include_velocity` (bool, default: False)
100+
- Whether to include velocity features (rate of change)
101+
- Velocity = current - lag1
102+
- Captures acceleration/deceleration of wave spread
103+
- Increases feature count by ~50% (one velocity per direction)
104+
105+
#### `wave_include_aggregate` (bool, default: True)
106+
- Whether to include overall weighted average (all neighbors, any direction)
107+
- Provides general spatial context independent of direction
108+
- Recommended to keep enabled
109+
110+
## Example Configurations
111+
112+
### Minimal Configuration (4 cardinal directions)
113+
```python
114+
model_config = SimpleNamespace(
115+
# ... other params ...
116+
use_directional_waves = True,
117+
wave_directions = ['N', 'S', 'E', 'W']
118+
)
119+
```
120+
Generates: 4 base + 4 aggregate + (4+1)×2 lags = **14 features**
121+
122+
### Standard Configuration (8 directions)
123+
```python
124+
model_config = SimpleNamespace(
125+
# ... other params ...
126+
use_directional_waves = True,
127+
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
128+
wave_temporal_lags = [1, 2]
129+
)
130+
```
131+
Generates: 8 base + 1 aggregate + (8+1)×2 lags = **27 features**
132+
133+
### Maximum Information (all options)
134+
```python
135+
model_config = SimpleNamespace(
136+
# ... other params ...
137+
use_directional_waves = True,
138+
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
139+
wave_temporal_lags = [1, 2],
140+
wave_max_distance_km = 2000,
141+
wave_include_velocity = True,
142+
wave_include_aggregate = True
143+
)
144+
```
145+
Generates: 8 base + 1 aggregate + (8+1)×2 lags + (8+1) velocity = **36 features**
146+
147+
### Hypothesis-Driven (specific directions)
148+
```python
149+
# If you suspect disease spreads along NE-SW axis
150+
model_config = SimpleNamespace(
151+
# ... other params ...
152+
use_directional_waves = True,
153+
wave_directions = ['NE', 'SW'],
154+
wave_temporal_lags = [1, 2, 3], # Longer lags for slower spread
155+
wave_max_distance_km = 1500
156+
)
157+
```
158+
159+
## Technical Details
160+
161+
### Directional Cone Definition
162+
Each direction captures neighbors within a 45° cone:
163+
- **N (North)**: 0° ±22.5° (337.5° to 22.5°)
164+
- **NE (Northeast)**: 45° ±22.5° (22.5° to 67.5°)
165+
- **E (East)**: 90° ±22.5° (67.5° to 112.5°)
166+
- ... and so on
167+
168+
### Distance Weighting
169+
Inverse distance weighting is used:
170+
```
171+
weight = 1 / distance
172+
weighted_average = Σ(weight × neighbor_value) / Σ(weight)
173+
```
174+
175+
Closer neighbors have more influence on the feature value.
176+
177+
### Lag Semantics
178+
- **Base feature** (no lag suffix): Uses current time t
179+
- **lag1**: Uses time t-1 (one week ago)
180+
- **lag2**: Uses time t-2 (two weeks ago)
181+
182+
This allows the model to learn patterns like: "If northern neighbors had high incidence last week (lag1), expect it here this week."
183+
184+
### Missing Values
185+
- If a location has no neighbors in a direction (within max_distance_km), the feature value is NaN
186+
- These are handled by LightGBM during training
187+
- Edge locations (e.g., coastal states) may have missing values for certain directions
188+
189+
## Location Support
190+
191+
Currently supported:
192+
- **State-level** (agg_level='state'): US states, DC, PR, and national level
193+
194+
**Data Source:** State centroids are loaded from `src/idmodels/data/state_centroids.csv`, which contains geographic centroids computed from US Census Bureau TIGER/Line shapefiles. See `src/idmodels/data/README.md` for detailed source information.
195+
196+
To add support for other aggregation levels (county, HSA, etc.):
197+
1. Create a CSV file (e.g., `county_centroids.csv`) with columns: `fips`, `name`, `latitude`, `longitude`
198+
2. Place it in `src/idmodels/data/`
199+
3. Update `_load_state_centroids()` function in `src/idmodels/spatial_utils.py` to support the new level
200+
4. Document the data source in `src/idmodels/data/README.md`
201+
202+
## Performance Considerations
203+
204+
### Computational Cost
205+
- Scales O(n²) with number of locations (n)
206+
- For 50 states: ~2,500 distance calculations (precomputed)
207+
- Feature computation is done once per training run
208+
209+
### Feature Count Impact
210+
Feature count depends on configuration:
211+
- Base: n_directions + (1 if aggregate else 0)
212+
- With lags: base × (1 + len(temporal_lags))
213+
- With velocity: total × 1.5 (approximately)
214+
215+
More features → longer training time, but potentially better predictions
216+
217+
### Recommendations
218+
- Start with default configuration (8 directions, 2 lags)
219+
- Experiment with fewer directions if training is slow
220+
- Use `wave_include_velocity=False` unless you have evidence of acceleration patterns
221+
222+
## Interpretation
223+
224+
### Feature Importance
225+
After training, you can examine feature importance to understand:
226+
- Which directions are most predictive (e.g., is NE spread more important than SW?)
227+
- Whether lags matter (are lag1 features more important than current?)
228+
- Whether velocity features add value
229+
230+
### Example Interpretations
231+
- **High importance for `wave_N_lag1`**: Disease tends to arrive from the north with 1-week delay
232+
- **High importance for `wave_avg`**: General spatial clustering matters more than direction
233+
- **High importance for `wave_NE_velocity`**: Acceleration of northeastern spread is predictive
234+
235+
## Warnings and Validation
236+
237+
The implementation includes validation that warns about:
238+
- **Opposite directions included**: If both N and S (or E and W, etc.) are included, they may be correlated in datasets with uniform spatial patterns. However, in typical epidemic scenarios with directional spread, opposite directions provide independent information. Tree-based models like LightGBM are also robust to multicollinearity, so this warning is informational rather than critical.
239+
240+
## Example: Complete GBQR Configuration
241+
242+
```python
243+
from types import SimpleNamespace
244+
from idmodels.gbqr import GBQRModel
245+
246+
# Model configuration with directional wave features
247+
model_config = SimpleNamespace(
248+
model_class = "gbqr",
249+
model_name = "gbqr_with_waves",
250+
251+
# Standard GBQR parameters
252+
incl_level_feats = True,
253+
num_bags = 10,
254+
bag_frac_samples = 0.7,
255+
reporting_adj = False,
256+
sources = ["nhsn"],
257+
fit_locations_separately = False,
258+
power_transform = "4rt",
259+
260+
# Directional wave features
261+
use_directional_waves = True,
262+
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
263+
wave_temporal_lags = [1, 2],
264+
wave_max_distance_km = 1500,
265+
wave_include_velocity = False,
266+
wave_include_aggregate = True
267+
)
268+
269+
# Run configuration
270+
run_config = SimpleNamespace(
271+
disease = "flu",
272+
ref_date = datetime.date(2024, 1, 6),
273+
output_root = "output/",
274+
artifact_store_root = "artifacts/",
275+
save_feat_importance = True,
276+
locations = None, # All locations
277+
max_horizon = 4,
278+
q_levels = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975],
279+
q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"]
280+
)
281+
282+
# Run model
283+
model = GBQRModel(model_config)
284+
model.run(run_config)
285+
```
286+
287+
## Backwards Compatibility
288+
289+
The implementation is fully backwards compatible:
290+
- Disabled by default (`use_directional_waves = False`)
291+
- Existing configurations without wave parameters work unchanged
292+
- Uses `hasattr()` checks to gracefully handle missing attributes
293+
294+
Old configurations will continue to work without modification.
295+
296+
## Testing
297+
298+
The implementation includes comprehensive tests:
299+
300+
### Unit Tests
301+
- `tests/unit/test_spatial_utils.py`: Tests for spatial calculations (distance, bearing, neighbors)
302+
- `tests/unit/test_directional_wave_features.py`: Tests for feature generation logic
303+
304+
### Integration Tests
305+
- `tests/integration/test_gbqr_wave_features.py`: End-to-end tests with realistic data
306+
307+
Run tests with:
308+
```bash
309+
uv run pytest tests/unit/test_spatial_utils.py -v
310+
uv run pytest tests/unit/test_directional_wave_features.py -v
311+
uv run pytest tests/integration/test_gbqr_wave_features.py -v
312+
```
313+
314+
## References
315+
316+
### Epidemiological Motivation
317+
- Spatial spread of infectious diseases often follows directional patterns
318+
- Travel corridors, population density gradients, and climate patterns create anisotropic spread
319+
- Historical examples: 1918 flu pandemic, COVID-19 spread in US
320+
321+
### Implementation Details
322+
- Haversine distance formula for great circle distance
323+
- Bearing calculation using spherical trigonometry
324+
- Inverse distance weighting for spatial interpolation
325+
326+
## Future Enhancements
327+
328+
Potential extensions:
329+
1. **Additional aggregation levels**: County, HSA, HRR support
330+
2. **Custom distance weighting**: Gaussian kernel, exponential decay
331+
3. **Population-weighted features**: Weight by neighbor population, not just distance
332+
4. **Temporal smoothing**: Moving averages of wave features
333+
5. **Asymmetric cones**: Different cone widths for different directions
334+
335+
## Troubleshooting
336+
337+
### "Missing coordinates for locations" error
338+
- Ensure all locations in your data have entries in `STATE_CENTROIDS` (spatial_utils.py)
339+
- Check that `agg_level` in your data matches supported levels ('state')
340+
341+
### Wave features are all NaN
342+
- Check `wave_max_distance_km` - may be too small
343+
- Verify location codes match those in `STATE_CENTROIDS`
344+
- Some edge locations (islands, coastal states) naturally have fewer neighbors
345+
346+
### Training is slow
347+
- Reduce number of directions (try just ['N', 'S', 'E', 'W'])
348+
- Reduce `wave_max_distance_km`
349+
- Disable velocity features
350+
- Use `fit_locations_separately=True` in model_config
351+
352+
## Contact
353+
354+
For questions or issues, please file an issue on the GitHub repository.

0 commit comments

Comments
 (0)