Skip to content

Commit 938c8ab

Browse files
feat: horseshoe prior for BSTS regression (Kohns & Bhattacharjee 2022) (#29)
* feat: horseshoe prior for BSTS regression (Kohns & Bhattacharjee 2022) Add continuous shrinkage alternative to spike-and-slab via PriorType enum. Makalic & Schmidt (2015) auxiliary variable augmentation: β joint update through precision matrix, per-covariate λ²/ν and global τ²/ξ sampling. - HorseshoeState struct with kappa() shrinkage diagnostic - sample_horseshoe() in both seasonal and non-seasonal Gibbs paths - PriorType enum (SpikeSlab/Horseshoe) replacing string comparison - sample_inv_gamma guard against non-finite params (fixes panic on extreme-scale inputs with standardize_data=False) - kappa() uses same 1e-30 floor as precision diagonal for consistency - GibbsSamples.kappa_shrinkage field exposed via PyO3 arXiv:2011.00938, arXiv:1508.02502 * feat: expose horseshoe prior via ModelOptions(prior_type='horseshoe') - ModelOptions: prior_type field with spike_slab/horseshoe validation - _normalize_model_args: unknown dict key rejection (fixes silent typo fallback e.g. prior_typee), routes through ModelOptions.__post_init__ - posterior_shrinkage property: mean kappa_j per covariate (horseshoe only) - posterior_inclusion_probs returns None for horseshoe (not applicable) - horseshoe + dynamic_regression/retrospective → ValueError * test: horseshoe prior specs (47 tests) + options/sampler boundary tests - test_horseshoe.py: 10 classes, 47 tests covering output shape, shrinkage behavior, posterior_shrinkage, positivity, backward compat, numerical stability, validation, Python API, dense DGP, seasonal - test_options.py: TestPriorTypeBoundary (6 tests) for ModelOptions - test_rust_sampler.py: TestSamplerKappaShrinkage (3 tests) for PyO3 boundary verification of kappa_shrinkage field * docs: horseshoe prior usage, shrinkage diagnostics, and theory - docs/api.md: horseshoe section with usage, diagnostics table, refs - docs/theory.md: hierarchical model, conditional posteriors, kappa_j - README.md: horseshoe in feature tables and model arguments - CHANGELOG.md: feat + fix entries for horseshoe, InvGamma guard, dict validation, kappa floor consistency * docs: document rationale for implementation decisions not in papers Three horseshoe design choices lack prescriptive guidance in the reference papers (Kohns 2022, Makalic 2015, Carvalho 2010). Document the reasoning so reviewers can evaluate them independently: - tau0 heuristic: y_sd / (sqrt(k) * y_norm), chain forgets within warmup - Numerical clamping on derived precision only (not raw draws) to preserve posterior fidelity while protecting Cholesky decomposition - Gibbs ordering difference: horseshoe beta-first (Makalic Alg.1) vs spike-slab sigma2-first (cold-start avoidance) * fix: covariate length validation at PyO3 boundary + docs/test cleanup - lib.rs: reject covariate columns whose length != len(y) with a clear error message before passing data to the Rust sampler - api.md: separate mode parameter into its own subsection (mode is passed via dict, not ModelOptions) - test_options.py: verify ModelOptions rejects mode kwarg (TypeError) - test_rust_sampler.py: verify covariate length mismatch raises ValueError * style: fix E501 line-too-long in test docstrings (ruff 88-char limit)
1 parent f9e6d3b commit 938c8ab

13 files changed

Lines changed: 1617 additions & 32 deletions

File tree

CHANGELOG.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,24 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on Keep a Changelog.
66

7+
## [Unreleased]
8+
9+
### Added
10+
11+
- Horseshoe prior as alternative to spike-and-slab via `ModelOptions(prior_type='horseshoe')`
12+
(Kohns & Bhattacharjee 2022, arXiv:2011.00938). Recommended for dense DGP settings
13+
where many covariates have true effects.
14+
- `posterior_shrinkage` property: mean shrinkage factor kappa_j per covariate (horseshoe only).
15+
- `kappa_shrinkage` field in Rust sampler output for per-iteration shrinkage diagnostics.
16+
17+
### Fixed
18+
19+
- `sample_inv_gamma` no longer panics on non-finite parameters (e.g. extreme-scale
20+
inputs with `standardize_data=False`). Returns a small positive fallback instead.
21+
- `_normalize_model_args` now rejects unknown dict keys (e.g. typo `prior_typee`
22+
silently falling back to `spike_slab` is no longer possible).
23+
- `kappa()` diagnostic now uses the same floor as the precision diagonal for consistency.
24+
725
## [1.6.0] - 2026-03-25
826

927
### Added

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Posterior prob. of a causal effect: 99.90%
100100
| Algorithm | Gibbs (bsts/C++) | Gibbs (Rust) | TFP-based | VI default / HMC | MLE (statsmodels) |
101101
| Dependencies | R, bsts | numpy, pandas, matplotlib | TF, TFP (3 GB+) | TF, TFP (3 GB+) | statsmodels |
102102
| Spike-and-slab | Yes | Yes | Unknown | No | No |
103+
| Horseshoe prior | No | Yes (`prior_type='horseshoe'`) | No | No | No |
103104
| Seasonal component | Yes | Yes (`nseasons`, `season_duration`) | Unknown | Yes (TFP STS) | No |
104105
| Dynamic regression | Yes | Yes (`dynamic_regression=True`) | Unknown | No | No |
105106
| R numerical test | Reference | ±1% CI-enforced + TOST/ROPE | Not published | Visual comparison (~8% diff) | Not tested |
@@ -205,6 +206,7 @@ Evidence per implementation (all verified from source code, not documentation cl
205206
| DATE decomposition | Extended | Decomposes effects into spot/persistent/trend (arXiv:2602.00836) |
206207
| Retrospective mode | Extended | Treatment indicators as covariates; effects from beta posteriors (arXiv:2602.00836) |
207208
| Placebo test | Extended | Null distribution from pre-period splits |
209+
| Horseshoe prior | Extended | Continuous shrinkage alternative to spike-and-slab (Kohns & Bhattacharjee 2022) |
208210
| Conformal inference | Extended | Distribution-free prediction intervals |
209211
| DTW control selection | Extended | Automatic covariate selection via Dynamic Time Warping |
210212

@@ -223,6 +225,7 @@ Features that go beyond R's CausalImpact. These have no R equivalent.
223225
| Placebo test | `ci.run_placebo_test()` | Validates effect against null distribution from pre-period splits | |
224226
| Conformal inference | `ci.run_conformal_analysis()` | Distribution-free prediction intervals | Vovk et al. (2005) |
225227
| DTW control selection | `select_controls()` | Automatic covariate selection via Dynamic Time Warping | Sakoe & Chiba (1978) |
228+
| Horseshoe prior | `ModelOptions(prior_type='horseshoe')` | Continuous shrinkage alternative to spike-and-slab for dense DGP | Kohns & Bhattacharjee (2022), arXiv:2011.00938 |
226229

227230
## API
228231

@@ -251,6 +254,7 @@ Features that go beyond R's CausalImpact. These have no R equivalent.
251254
| `season_duration` | `None` | Optional duration of each seasonal block; defaults to `1` when `nseasons` is set |
252255
| `dynamic_regression` | `False` | Enable time-varying regression coefficients (random-walk beta) |
253256
| `state_model` | `"local_level"` | `"local_level"` or `"local_linear_trend"` |
257+
| `prior_type` | `"spike_slab"` | `"spike_slab"` or `"horseshoe"` (continuous shrinkage for dense DGP) |
254258
| `mode` | `"forward"` | `"forward"` (counterfactual prediction) or `"retrospective"` (treatment indicators as covariates) |
255259

256260
#### Methods and Properties
@@ -262,7 +266,8 @@ Features that go beyond R's CausalImpact. These have no R equivalent.
262266
| `plot(metrics=None)` | `Figure` | Matplotlib figure with original/pointwise/cumulative panels |
263267
| `inferences` | `DataFrame` | Per-timestep actuals, predictions, prediction s.d., and effect intervals |
264268
| `summary_stats` | `dict` | Aggregate statistics (effect mean, CI, p-value, etc.) |
265-
| `posterior_inclusion_probs` | `ndarray \| None` | Posterior inclusion probability per covariate |
269+
| `posterior_inclusion_probs` | `ndarray \| None` | Posterior inclusion probability per covariate (spike-and-slab only) |
270+
| `posterior_shrinkage` | `ndarray \| None` | Mean shrinkage factor per covariate (horseshoe only) |
266271
| `decompose(alpha=None)` | `DateDecomposition` | DATE decomposition into spot/persistent/trend components |
267272
| `run_placebo_test(...)` | `PlaceboTestResults` | Placebo test for effect validation |
268273
| `run_conformal_analysis(...)` | `ConformalResults` | Distribution-free conformal prediction intervals |

docs/api.md

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ ci = CausalImpact(data, pre_period, post_period, model_args=None, alpha=0.05)
3535
|---|---|---|
3636
| `inferences` | `DataFrame` | Per-timestep actuals, predictions, prediction s.d., and effect intervals |
3737
| `summary_stats` | `dict` | Aggregate statistics (effect mean, CI, p-value, etc.) |
38-
| `posterior_inclusion_probs` | `ndarray \| None` | Posterior inclusion probability per covariate (requires covariates) |
38+
| `posterior_inclusion_probs` | `ndarray \| None` | Posterior inclusion probability per covariate (spike-and-slab only; returns `None` for horseshoe) |
39+
| `posterior_shrinkage` | `ndarray \| None` | Mean shrinkage factor kappa_j per covariate (horseshoe only; returns `None` for spike-and-slab). Values near 0 = weakly shrunk (included), near 1 = strongly shrunk. |
3940

4041
## `ModelOptions`
4142

@@ -58,11 +59,24 @@ ci = CausalImpact(data, pre_period, post_period, model_args=opts)
5859
| `standardize_data` | `bool` | `True` | Standardize data before fitting |
5960
| `expected_model_size` | `int` | 2 | Expected number of active covariates for spike-and-slab prior |
6061
| `dynamic_regression` | `bool` | `False` | Enable time-varying regression coefficients |
62+
| `prior_type` | `str` | `"spike_slab"` | `"spike_slab"` (discrete variable selection) or `"horseshoe"` (continuous shrinkage). Horseshoe is recommended for dense DGP settings. |
6163
| `state_model` | `str` | `"local_level"` | `"local_level"` or `"local_linear_trend"` |
62-
| `mode` | `str` | `"forward"` | `"forward"` (counterfactual prediction) or `"retrospective"` (treatment indicators as covariates). Retrospective mode adds spot/persistent/trend columns to X and fits on the entire series. Effects are extracted from beta posteriors. |
6364
| `nseasons` | `int \| None` | `None` | Seasonal cycle count. `nseasons=1` is equivalent to no seasonal component. |
6465
| `season_duration` | `int \| None` | `None` | Duration of each seasonal block; defaults to 1 when `nseasons` is set. Requires `nseasons` to be set. |
6566

67+
### Analysis Mode
68+
69+
`mode` controls forward vs retrospective analysis. Pass via `model_args` dict (not `ModelOptions`).
70+
71+
| Value | Description |
72+
|---|---|
73+
| `"forward"` (default) | Counterfactual prediction: fit on pre-period, predict post-period |
74+
| `"retrospective"` | Treatment indicators as covariates: fit on entire series |
75+
76+
```python
77+
ci = CausalImpact(data, pre, post, model_args={"mode": "retrospective"})
78+
```
79+
6680
## `CausalImpactResults`
6781

6882
Returned by `ci._results`. A frozen dataclass containing all computed quantities.
@@ -86,6 +100,54 @@ Returned by `ci._results`. A frozen dataclass containing all computed quantities
86100
| `predictions_lower` | `ndarray` | Lower CI on counterfactual |
87101
| `predictions_upper` | `ndarray` | Upper CI on counterfactual |
88102

103+
## Horseshoe Prior (alternative to spike-and-slab)
104+
105+
CausalImpact supports the horseshoe prior (Carvalho, Polson & Scott 2010)
106+
applied to BSTS regression, following the formulation of
107+
Kohns & Bhattacharjee (2022) (arXiv:2011.00938).
108+
109+
### When to use horseshoe
110+
111+
| Scenario | Recommended prior |
112+
|---|---|
113+
| Few true covariates (sparse DGP) | `spike_slab` (default) |
114+
| Many true covariates (dense DGP) | `horseshoe` |
115+
116+
### Usage
117+
118+
```python
119+
from causal_impact import CausalImpact, ModelOptions
120+
121+
ci = CausalImpact(
122+
data, pre_period, post_period,
123+
model_args=ModelOptions(prior_type='horseshoe'),
124+
)
125+
print(ci.posterior_shrinkage) # mean(kappa_j), 0=included 1=shrunk
126+
# ci.posterior_inclusion_probs is None for horseshoe (spike-slab only)
127+
```
128+
129+
### Shrinkage diagnostics
130+
131+
| Property | prior_type | Meaning |
132+
|---|---|---|
133+
| `posterior_inclusion_probs` | `spike_slab` | E[gamma_j] — discrete inclusion probability |
134+
| `posterior_inclusion_probs` | `horseshoe` | `None` (not applicable) |
135+
| `posterior_shrinkage` | `horseshoe` | E[kappa_j] — continuous shrinkage factor kappa_j = 1/(1+lambda_j^2 * tau^2). Values close to 0 indicate the covariate is weakly shrunk (effectively included). |
136+
| `posterior_shrinkage` | `spike_slab` | `None` (not applicable) |
137+
138+
### Incompatible combinations
139+
140+
- `prior_type='horseshoe'` + `dynamic_regression=True` raises `ValueError`
141+
- `prior_type='horseshoe'` + `mode='retrospective'` raises `ValueError`
142+
143+
### References
144+
145+
- Kohns, D. & Bhattacharjee, A. (2022). Horseshoe Prior for Sparse Bayesian Structural Time Series. arXiv:2011.00938.
146+
- Makalic, E. & Schmidt, D.F. (2015). A simple sampler for the horseshoe estimator. IEEE Signal Processing Letters, 23(1), 179-182.
147+
- Carvalho, C.M., Polson, N.G. & Scott, J.G. (2010). The horseshoe estimator for sparse signals. Biometrika, 97(2), 465-480.
148+
149+
---
150+
89151
## Beyond R Extensions
90152

91153
### Retrospective Mode

docs/theory.md

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Five additional capabilities extend the analysis beyond what R offers.
1010
| Placebo test | `ci.run_placebo_test()` | Validate effects against null distribution |
1111
| Conformal inference | `ci.run_conformal_analysis()` | Distribution-free prediction intervals |
1212
| DTW control selection | `select_controls()` | Automatic covariate selection |
13+
| Horseshoe prior | `ModelOptions(prior_type='horseshoe')` | Continuous shrinkage for dense DGP |
1314

1415
---
1516

@@ -245,6 +246,166 @@ for spoken word recognition."
245246

246247
---
247248

249+
## Horseshoe Prior
250+
251+
### What it does
252+
253+
The horseshoe prior (Carvalho, Polson & Scott 2010) is a continuous shrinkage
254+
alternative to spike-and-slab variable selection. While spike-and-slab performs
255+
discrete inclusion/exclusion of covariates (gamma_j in {0,1}), the horseshoe
256+
applies adaptive shrinkage that can handle dense DGP settings where many
257+
covariates have true effects.
258+
259+
Reference: Kohns & Bhattacharjee (2022), "Horseshoe Prior for Sparse Bayesian
260+
Structural Time Series" (arXiv:2011.00938).
261+
262+
### Hierarchical model
263+
264+
The horseshoe hierarchy uses Half-Cauchy priors decomposed into InvGamma
265+
auxiliary variables (Makalic & Schmidt 2015):
266+
267+
```
268+
beta_j | lambda_j, tau, sigma2 ~ N(0, lambda_j^2 * tau^2 * sigma2_obs)
269+
lambda_j^2 | nu_j ~ InvGamma(1/2, 1/nu_j)
270+
nu_j ~ InvGamma(1/2, 1)
271+
tau^2 | xi ~ InvGamma(1/2, 1/xi)
272+
xi ~ InvGamma(1/2, 1)
273+
```
274+
275+
The conditional posteriors used in the Gibbs sampler:
276+
277+
```
278+
lambda_j^2 | . ~ InvGamma(1, 1/nu_j + beta_j^2 / (2 * tau^2 * sigma2))
279+
nu_j | . ~ InvGamma(1, 1 + 1/lambda_j^2)
280+
tau^2 | . ~ InvGamma((k+1)/2, 1/xi + sum(beta_j^2 / (2 * lambda_j^2 * sigma2)))
281+
xi | . ~ InvGamma(1, 1 + 1/tau^2)
282+
```
283+
284+
### Beta joint update
285+
286+
Unlike spike-and-slab (coordinate-wise), horseshoe uses a joint beta update:
287+
288+
```
289+
A = X'X + diag(1 / (lambda_j^2 * tau^2)) (precision matrix)
290+
b = X'(y - state - seasonal) (right-hand side)
291+
beta ~ N(A^{-1} b, sigma2_obs * A^{-1}) (sampled via Cholesky)
292+
```
293+
294+
### Shrinkage factor
295+
296+
The shrinkage factor kappa_j measures how much each covariate is shrunk:
297+
298+
```
299+
kappa_j = 1 / (1 + lambda_j^2 * tau^2)
300+
```
301+
302+
- kappa_j close to 1: strong shrinkage (covariate effectively excluded)
303+
- kappa_j close to 0: weak shrinkage (covariate effectively included)
304+
305+
The `posterior_shrinkage` property returns E[kappa_j] averaged over post-warmup
306+
MCMC iterations.
307+
308+
### When to use
309+
310+
| Scenario | Recommended prior |
311+
|---|---|
312+
| Few true covariates among many candidates (sparse DGP) | `spike_slab` (default) |
313+
| Many covariates with true effects (dense DGP) | `horseshoe` |
314+
| Time-varying coefficients | `spike_slab` (horseshoe + dynamic_regression not supported) |
315+
316+
### Usage
317+
318+
```python
319+
from causal_impact import CausalImpact, ModelOptions
320+
321+
ci = CausalImpact(
322+
data, pre_period, post_period,
323+
model_args=ModelOptions(prior_type='horseshoe', niter=2000, seed=42),
324+
)
325+
326+
# Shrinkage diagnostics
327+
print(ci.posterior_shrinkage) # E[kappa_j] per covariate
328+
# posterior_inclusion_probs is None for horseshoe
329+
```
330+
331+
### Implementation decisions not specified in the papers
332+
333+
The following design choices are not prescribed by the reference papers.
334+
Each choice is documented here with its rationale so that reviewers can
335+
evaluate them independently.
336+
337+
#### tau0 initialization
338+
339+
The global shrinkage parameter tau^2 requires an initial value for the
340+
Gibbs sampler. None of the three reference papers specify a concrete
341+
formula. This implementation uses a data-adaptive heuristic:
342+
343+
```
344+
y_norm = ||y_pre||_2 / sqrt(T_pre)
345+
tau0 = y_sd / (sqrt(k) * y_norm)
346+
tau^2_init = tau0^2
347+
```
348+
349+
Rationale: after standardization y_sd is approximately 1. Dividing by
350+
sqrt(k) prevents the global scale from growing with the number of
351+
covariates. Dividing by y_norm anchors the prior scale to the signal
352+
magnitude. Because tau^2 is resampled at every Gibbs iteration, the
353+
chain forgets the initial value within the warmup period. If y_norm
354+
is near zero (constant y), tau0 falls back to 1.0 so that the prior
355+
remains diffuse.
356+
357+
#### Numerical clamping on derived precision (not on raw draws)
358+
359+
Raw InvGamma draws for lambda_j^2 and tau^2 receive no floor. Clamping
360+
raw draws would distort the posterior distribution. Instead, the derived
361+
precision diagonal entry is clamped:
362+
363+
```
364+
lambda_tau_prod = max(lambda_j^2 * tau^2, 1e-30) -- prevents 0-division
365+
prior_prec = min(1 / lambda_tau_prod, 1e12) -- prevents inf diagonal
366+
```
367+
368+
This approach keeps the posterior intact while protecting the Cholesky
369+
decomposition from numerical failure. The kappa() diagnostic uses the
370+
same 1e-30 floor so that shrinkage values stay consistent with the
371+
precision matrix actually used in the beta update.
372+
373+
The fallback in sample_inv_gamma (returning 1e-30 for non-finite
374+
parameters) serves as a last-resort guard. It triggers only under
375+
extreme-scale inputs (e.g. standardize_data=False with y of order 1e200)
376+
where the scale parameter overflows to infinity.
377+
378+
#### Gibbs sampling order
379+
380+
Horseshoe and spike-and-slab use different orderings within the Gibbs loop:
381+
382+
```
383+
Horseshoe: state -> beta (joint) -> lambda2/nu -> tau2/xi -> sigma2_obs
384+
Spike-slab: state -> sigma2_obs -> beta (coordinate-wise)
385+
```
386+
387+
Horseshoe samples beta jointly via a precision matrix conditioned on the
388+
current sigma2_obs. After the joint beta update, sigma2_obs is resampled
389+
conditioned on the updated residual. This follows Algorithm 1 of Makalic
390+
& Schmidt (2015) where the regression step precedes the variance update.
391+
392+
Spike-and-slab samples sigma2_obs first because its coordinate-wise
393+
variable selection (gamma_j) is sensitive to cold-start: sampling beta
394+
with an uninformative sigma2_obs on the first iteration can cause the
395+
sampler to exclude all covariates. Sampling sigma2_obs first gives beta
396+
a reasonable scale to condition on.
397+
398+
### References
399+
400+
- Carvalho, C.M., Polson, N.G. & Scott, J.G. (2010). The horseshoe estimator
401+
for sparse signals. Biometrika, 97(2), 465-480.
402+
- Kohns, D. & Bhattacharjee, A. (2022). Horseshoe Prior for Sparse Bayesian
403+
Structural Time Series. arXiv:2011.00938.
404+
- Makalic, E. & Schmidt, D.F. (2015). A simple sampler for the horseshoe
405+
estimator. IEEE Signal Processing Letters, 23(1), 179-182.
406+
407+
---
408+
248409
## Citation
249410

250411
```bibtex

0 commit comments

Comments
 (0)