|
| 1 | +"""Tests for dynamic regression (time-varying coefficients). |
| 2 | +
|
| 3 | +Dynamic regression allows β_t to vary over time as a random walk, |
| 4 | +unlike static regression where β is constant. This is the key feature |
| 5 | +for capturing structural changes in the pre-intervention relationship. |
| 6 | +""" |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import pytest |
| 10 | +from causal_impact import CausalImpact, ModelOptions |
| 11 | +from causal_impact._core import run_gibbs_sampler |
| 12 | + |
| 13 | +# --------------------------------------------------------------------------- |
| 14 | +# Helpers |
| 15 | +# --------------------------------------------------------------------------- |
| 16 | + |
| 17 | + |
def _make_data_dynamic_k1_constant_beta(n=100, pre_frac=0.7, seed=42):
    """Simulate one covariate whose true coefficient stays fixed at 2.0.

    Returns a ``(y, covariates, pre_end)`` triple: ``y`` is a NumPy array of
    length ``n``, ``covariates`` is a one-element list holding the covariate
    as a plain Python list, and ``pre_end`` is ``int(n * pre_frac)``.
    """
    generator = np.random.default_rng(seed)
    split_index = int(n * pre_frac)
    # Draw order matters for reproducibility: covariate first, then noise.
    covariate = generator.normal(0, 1, n)
    noise = generator.normal(0, 0.3, n)
    response = 2.0 * covariate + noise
    return response, [covariate.tolist()], split_index
| 25 | + |
| 26 | + |
def _make_data_dynamic_k1_structural_break(n=100, pre_frac=0.7, seed=42):
    """Simulate one covariate whose coefficient jumps from 1.0 to 3.0.

    The coefficient is 1.0 for indices below ``n // 2`` and 3.0 from there
    on. Returns ``(y, covariates, pre_end)`` with ``y`` a NumPy array,
    ``covariates`` a one-element list of Python lists, and ``pre_end`` equal
    to ``int(n * pre_frac)``.
    """
    generator = np.random.default_rng(seed)
    split_index = int(n * pre_frac)
    covariate = generator.normal(0, 1, n)
    # Start with the post-break value everywhere, then overwrite the first half.
    coefficients = np.full(n, 3.0)
    coefficients[: n // 2] = 1.0
    noise = generator.normal(0, 0.3, n)
    response = coefficients * covariate + noise
    return response, [covariate.tolist()], split_index
| 35 | + |
| 36 | + |
def _make_data_dynamic_k2(n=100, pre_frac=0.7, seed=42):
    """Simulate two covariates: one fixed coefficient, one drifting linearly.

    The first coefficient is 1.0 throughout; the second moves evenly from
    0.5 to 2.5 across the ``n`` observations. Returns
    ``(y, covariates, pre_end)`` where ``covariates`` holds both covariate
    series as Python lists and ``pre_end`` is ``int(n * pre_frac)``.
    """
    generator = np.random.default_rng(seed)
    split_index = int(n * pre_frac)
    # Draw order (first covariate, second covariate, noise) fixes the output.
    first = generator.normal(0, 1, n)
    second = generator.normal(0, 1, n)
    drifting_coef = np.linspace(0.5, 2.5, n)
    signal = 1.0 * first + drifting_coef * second
    response = signal + generator.normal(0, 0.3, n)
    return response, [first.tolist(), second.tolist()], split_index
| 46 | + |
| 47 | + |
def _run_sampler(y, x, pre_end, niter, nwarmup, seed, *, dynamic_regression):
    """Call ``run_gibbs_sampler`` with the settings shared by these tests.

    Only ``dynamic_regression`` differs between the static and dynamic
    wrappers below; keeping the rest in one place stops the two calls from
    drifting apart.

    Args:
        y: Response series; converted with ``.tolist()`` when available,
            otherwise with ``list()``.
        x: List of covariate series. Any falsy value (e.g. ``[]``) is passed
            to the sampler as ``None``.
        pre_end: Index separating the pre-period from the post-period.
        niter: Forwarded as the sampler's ``niter``.
        nwarmup: Forwarded as the sampler's ``nwarmup``.
        seed: Forwarded as the sampler's ``seed``.
        dynamic_regression: Forwarded as the sampler's ``dynamic_regression``.

    Returns:
        Whatever ``run_gibbs_sampler`` returns.
    """
    return run_gibbs_sampler(
        y=y.tolist() if hasattr(y, "tolist") else list(y),
        x=x if x else None,
        pre_end=pre_end,
        niter=niter,
        nwarmup=nwarmup,
        nchains=1,
        seed=seed,
        prior_level_sd=0.01,
        expected_model_size=1.0,
        nseasons=None,
        season_duration=None,
        dynamic_regression=dynamic_regression,
    )


def _run_sampler_dynamic(y, x, pre_end, niter=500, nwarmup=250, seed=42):
    """Call run_gibbs_sampler with dynamic_regression=True."""
    return _run_sampler(
        y, x, pre_end, niter, nwarmup, seed, dynamic_regression=True
    )


def _run_sampler_static(y, x, pre_end, niter=500, nwarmup=250, seed=42):
    """Call run_gibbs_sampler with dynamic_regression=False."""
    return _run_sampler(
        y, x, pre_end, niter, nwarmup, seed, dynamic_regression=False
    )
| 82 | + |
| 83 | + |
| 84 | +# --------------------------------------------------------------------------- |
| 85 | +# Option validation (2 tests) |
| 86 | +# --------------------------------------------------------------------------- |
| 87 | + |
| 88 | + |
class TestDynamicRegressionOptions:
    """``ModelOptions`` exposes ``dynamic_regression`` as an off-by-default flag."""

    def test_default_is_false(self):
        """A default-constructed ``ModelOptions`` has the flag set to False."""
        assert ModelOptions().dynamic_regression is False

    def test_true_accepted(self):
        """Passing ``dynamic_regression=True`` is stored as True."""
        enabled = ModelOptions(dynamic_regression=True)
        assert enabled.dynamic_regression is True
| 97 | + |
| 98 | + |
| 99 | +# --------------------------------------------------------------------------- |
| 100 | +# Basic behavior (3 tests) |
| 101 | +# --------------------------------------------------------------------------- |
| 102 | + |
| 103 | + |
class TestDynamicRegressionBasic:
    """Basic sampler-output checks with the dynamic flag on and off."""

    def test_predictions_shape_unchanged(self):
        """Predictions have shape (niter, number of post-period points)."""
        y, x, pre_end = _make_data_dynamic_k1_constant_beta()
        sampled = _run_sampler_dynamic(y, x, pre_end, niter=50, nwarmup=25)
        n_post = len(y) - pre_end
        assert np.array(sampled.predictions).shape == (50, n_post)

    def test_posterior_inclusion_probs_none_when_dynamic(self):
        """Spike-and-slab is disabled when dynamic_regression=True."""
        y, x, pre_end = _make_data_dynamic_k1_constant_beta()
        sampled = _run_sampler_dynamic(y, x, pre_end, niter=50, nwarmup=25)
        # With spike-and-slab off, gamma is either empty or holds only
        # empty entries.
        gamma = sampled.gamma
        assert gamma == [] or all(len(entry) == 0 for entry in gamma)

    def test_false_matches_existing_behavior(self):
        """An explicit ``dynamic_regression=False`` equals omitting the flag."""
        y, x, pre_end = _make_data_dynamic_k1_constant_beta()
        explicit_static = _run_sampler_static(y, x, pre_end, niter=50, nwarmup=25)
        # Same call, but without the dynamic_regression/seasonal keywords.
        default_kwargs = {
            "y": y.tolist(),
            "x": x,
            "pre_end": pre_end,
            "niter": 50,
            "nwarmup": 25,
            "nchains": 1,
            "seed": 42,
            "prior_level_sd": 0.01,
            "expected_model_size": 1.0,
        }
        implicit_static = run_gibbs_sampler(**default_kwargs)
        # Predictions must be identical
        np.testing.assert_array_equal(
            explicit_static.predictions, implicit_static.predictions
        )
| 137 | + |
| 138 | + |
| 139 | +# --------------------------------------------------------------------------- |
| 140 | +# Boundary and edge cases (5 tests) |
| 141 | +# --------------------------------------------------------------------------- |
| 142 | + |
| 143 | + |
class TestDynamicRegressionBoundary:
    """Edge-case inputs must still produce finite predictions."""

    def test_k0_no_covariates_falls_back_gracefully(self):
        """k=0 with dynamic_regression=True should run like static."""
        generator = np.random.default_rng(42)
        series = generator.normal(10, 1, 30)
        sampled = _run_sampler_dynamic(series, [], 20, niter=20, nwarmup=10)
        draws = np.array(sampled.predictions)
        # 20 draws over the 10 post-period points (30 - 20).
        assert draws.shape == (20, 10)
        assert np.all(np.isfinite(draws))

    def test_k1_single_covariate_runs_without_error(self):
        """A single covariate yields only finite predictions."""
        y, x, pre_end = _make_data_dynamic_k1_constant_beta()
        sampled = _run_sampler_dynamic(y, x, pre_end, niter=100, nwarmup=50)
        assert np.all(np.isfinite(np.array(sampled.predictions)))

    def test_k2_multiple_covariates_runs_without_error(self):
        """Two covariates yield only finite predictions."""
        y, x, pre_end = _make_data_dynamic_k2()
        sampled = _run_sampler_dynamic(y, x, pre_end, niter=100, nwarmup=50)
        assert np.all(np.isfinite(np.array(sampled.predictions)))

    def test_minimum_tpre_2_does_not_crash(self):
        """T_pre=2 is the minimum for random walk (needs at least 1 diff)."""
        generator = np.random.default_rng(42)
        # Response is drawn before the covariate, keeping the RNG stream fixed.
        series = generator.normal(5, 0.5, 5)
        covariates = [generator.normal(0, 1, 5).tolist()]
        sampled = _run_sampler_dynamic(series, covariates, 2, niter=20, nwarmup=10)
        assert np.all(np.isfinite(np.array(sampled.predictions)))

    def test_very_large_k_relative_to_tpre_runs_without_nan(self):
        """k=T_pre-1 (k=9, T_pre=10) for numerical stability check."""
        generator = np.random.default_rng(42)
        n_covariates = 9
        pre_length = 10
        series = generator.normal(0, 1, 15)
        covariates = []
        for _ in range(n_covariates):
            covariates.append(generator.normal(0, 1, 15).tolist())
        sampled = _run_sampler_dynamic(
            series, covariates, pre_length, niter=20, nwarmup=10
        )
        assert np.all(np.isfinite(np.array(sampled.predictions)))
| 184 | + |
| 185 | + |
| 186 | +# --------------------------------------------------------------------------- |
| 187 | +# Statistical quality (3 tests) |
| 188 | +# --------------------------------------------------------------------------- |
| 189 | + |
| 190 | + |
class TestDynamicRegressionStatistical:
    """Loose statistical sanity checks on dynamic-regression predictions."""

    def test_constant_beta_predictions_reasonable(self):
        """Constant beta=2.0: mean prediction is close to the mean of y_post.

        The allowed gap is 20% of ``|mean(y_post)|`` plus an absolute slack
        of 0.5.
        """
        y, x, pre_end = _make_data_dynamic_k1_constant_beta(n=200, seed=123)
        sampled = _run_sampler_dynamic(
            y, x, pre_end, niter=500, nwarmup=250, seed=123
        )
        observed_mean = np.array(y[pre_end:]).mean()
        # Average over draws first, then over time, as a single summary value.
        predicted_mean = np.array(sampled.predictions).mean(axis=0).mean()
        tolerance = 0.2 * abs(observed_mean) + 0.5
        assert abs(predicted_mean - observed_mean) < tolerance

    def test_structural_break_predictions_differ_from_static(self):
        """Structural break data: dynamic and static RMSE should differ."""
        y, x, pre_end = _make_data_dynamic_k1_structural_break(n=200, seed=99)
        dynamic_fit = _run_sampler_dynamic(
            y, x, pre_end, niter=500, nwarmup=250, seed=99
        )
        static_fit = _run_sampler_static(
            y, x, pre_end, niter=500, nwarmup=250, seed=99
        )
        observed = np.array(y[pre_end:])

        def posterior_mean_rmse(fit):
            # RMSE of the per-time-point posterior mean against observed y_post.
            residual = np.array(fit.predictions).mean(axis=0) - observed
            return np.sqrt((residual**2).mean())

        rmse_dynamic = posterior_mean_rmse(dynamic_fit)
        rmse_static = posterior_mean_rmse(static_fit)
        # They should not be equal (dynamic adapts, static doesn't)
        assert rmse_dynamic != pytest.approx(rmse_static, rel=0.01)

    def test_post_period_predictions_no_nan(self):
        """No post-period prediction is NaN."""
        y, x, pre_end = _make_data_dynamic_k1_constant_beta()
        sampled = _run_sampler_dynamic(y, x, pre_end, niter=200, nwarmup=100)
        draws = np.array(sampled.predictions)
        assert not np.any(np.isnan(draws))
| 227 | + |
| 228 | + |
| 229 | +# --------------------------------------------------------------------------- |
| 230 | +# Integration tests (3 tests) |
| 231 | +# --------------------------------------------------------------------------- |
| 232 | + |
| 233 | + |
class TestDynamicRegressionIntegration:
    """End-to-end checks of ``CausalImpact`` with ``dynamic_regression=True``.

    All three tests use the same simulated data set and model settings,
    built once per test by ``_fit_dynamic_impact``.
    """

    @staticmethod
    def _fit_dynamic_impact():
        """Build the shared 80-day fixture and fit ``CausalImpact`` on it.

        The response is ``2.0 * x`` plus noise, with 3.0 added from index 56
        onwards. Index 56 is 2020-02-26, the first day of the post-period
        passed to ``CausalImpact``.

        Returns:
            The constructed ``CausalImpact`` object.
        """
        # Imported locally, as in the original tests, so module import does
        # not require pandas.
        import pandas as pd

        rng = np.random.default_rng(42)
        n = 80
        dates = pd.date_range("2020-01-01", periods=n, freq="D")
        x = rng.normal(0, 1, n)
        y = 2.0 * x + rng.normal(0, 0.3, n)
        y[56:] += 3.0
        df = pd.DataFrame({"y": y, "x": x}, index=dates)
        return CausalImpact(
            df,
            ["2020-01-01", "2020-02-25"],
            ["2020-02-26", "2020-03-20"],
            model_args={"dynamic_regression": True, "niter": 200, "nwarmup": 100},
        )

    def test_causal_impact_end_to_end(self):
        """Fitting succeeds and ``summary()`` returns something."""
        ci = self._fit_dynamic_impact()
        assert ci.summary() is not None

    def test_causal_impact_summary_and_inferences(self):
        """``summary()`` is a non-empty string; ``inferences`` a non-empty frame."""
        import pandas as pd

        ci = self._fit_dynamic_impact()
        summary = ci.summary()
        assert isinstance(summary, str)
        assert len(summary) > 0

        inferences = ci.inferences
        assert isinstance(inferences, pd.DataFrame)
        assert len(inferences) > 0

    def test_causal_impact_plot_runs(self):
        """``plot()`` returns a non-None object under the Agg backend."""
        import matplotlib

        # Select the non-interactive backend before any plotting happens.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        ci = self._fit_dynamic_impact()
        try:
            fig = ci.plot()
            assert fig is not None
        finally:
            # NOTE(review): presumably plot() creates pyplot-managed figures;
            # close them so repeated runs do not accumulate open figures.
            plt.close("all")
0 commit comments