|
| 1 | +"""State-space seasonal smoother の Python 統合テスト. |
| 2 | +
|
| 3 | +nseasons > 1 の場合に状態空間 seasonal smoother が正しく動作することを検証する。 |
| 4 | +R bsts AddSeasonal() 互換の実装に対応。 |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +import pytest |
| 10 | +from causal_impact import CausalImpact |
| 11 | + |
| 12 | + |
| 13 | +MCMC_ARGS_FAST = {"niter": 500, "nwarmup": 200, "seed": 42, "prior_level_sd": 0.01} |
| 14 | +MCMC_ARGS_MEDIUM = {"niter": 2000, "nwarmup": 500, "seed": 42, "prior_level_sd": 0.01} |
| 15 | + |
| 16 | + |
| 17 | +def _make_seasonal_df( |
| 18 | + n: int = 84, |
| 19 | + pre_end: int = 56, |
| 20 | + nseasons: int = 7, |
| 21 | + effect: float = 5.0, |
| 22 | + noise_sd: float = 0.5, |
| 23 | + seed: int = 42, |
| 24 | +) -> tuple[pd.DataFrame, list[int], list[int]]: |
| 25 | + """Generate time series with seasonal pattern and post-period effect.""" |
| 26 | + rng = np.random.default_rng(seed) |
| 27 | + seasonal_pattern = np.sin(2 * np.pi * np.arange(nseasons) / nseasons) |
| 28 | + repeated = np.resize(seasonal_pattern, n) |
| 29 | + y = 20.0 + repeated + rng.normal(0, noise_sd, n) |
| 30 | + y[pre_end:] += effect |
| 31 | + df = pd.DataFrame({"y": y}) |
| 32 | + return df, [0, pre_end - 1], [pre_end, n - 1] |
| 33 | + |
| 34 | + |
| 35 | +class TestSeasonalSmootherIntegration: |
| 36 | + def test_sigma_seasonal_exists_when_nseasons_set(self): |
| 37 | + """nseasons > 1 のとき sigma_seasonal が非空であること.""" |
| 38 | + df, pre, post = _make_seasonal_df() |
| 39 | + ci = CausalImpact( |
| 40 | + df, pre, post, model_args={**MCMC_ARGS_FAST, "nseasons": 7} |
| 41 | + ) |
| 42 | + from causal_impact._core import run_gibbs_sampler |
| 43 | + |
| 44 | + result = run_gibbs_sampler( |
| 45 | + list(df["y"]), |
| 46 | + None, |
| 47 | + pre[1] + 1, |
| 48 | + 500, |
| 49 | + 200, |
| 50 | + 1, |
| 51 | + 42, |
| 52 | + 0.01, |
| 53 | + 1.0, |
| 54 | + 7.0, |
| 55 | + 1.0, |
| 56 | + False, |
| 57 | + "local_level", |
| 58 | + ) |
| 59 | + assert len(result.sigma_seasonal) > 0 |
| 60 | + |
| 61 | + def test_sigma_seasonal_empty_when_no_seasons(self): |
| 62 | + """nseasons=None のとき sigma_seasonal が空であること.""" |
| 63 | + from causal_impact._core import run_gibbs_sampler |
| 64 | + |
| 65 | + y = [10.0 + 0.1 * i for i in range(20)] |
| 66 | + result = run_gibbs_sampler( |
| 67 | + y, None, 15, 10, 5, 1, 42, 0.01, 1.0, None, None, False, "local_level" |
| 68 | + ) |
| 69 | + assert len(result.sigma_seasonal) == 0 |
| 70 | + |
| 71 | + def test_sigma_seasonal_positive_all_samples(self): |
| 72 | + """sigma_seasonal の全サンプルが正であること.""" |
| 73 | + from causal_impact._core import run_gibbs_sampler |
| 74 | + |
| 75 | + y = [20.0 + np.sin(2 * np.pi * i / 7) for i in range(84)] |
| 76 | + result = run_gibbs_sampler( |
| 77 | + y, None, 56, 200, 100, 1, 42, 0.01, 1.0, 7.0, 1.0, False, "local_level" |
| 78 | + ) |
| 79 | + assert all(s > 0 for s in result.sigma_seasonal) |
| 80 | + |
| 81 | + def test_sigma_seasonal_len_equals_post_warmup(self): |
| 82 | + """sigma_seasonal の長さが niter - nwarmup であること.""" |
| 83 | + from causal_impact._core import run_gibbs_sampler |
| 84 | + |
| 85 | + niter, nwarmup = 100, 30 |
| 86 | + y = [20.0 + np.sin(2 * np.pi * i / 7) for i in range(84)] |
| 87 | + result = run_gibbs_sampler( |
| 88 | + y, None, 56, niter, nwarmup, 1, 42, 0.01, 1.0, 7.0, 1.0, False, "local_level" |
| 89 | + ) |
| 90 | + assert len(result.sigma_seasonal) == niter |
| 91 | + |
| 92 | + def test_point_effect_finite_with_seasonal(self): |
| 93 | + """seasonal モデルの point_effect_mean が有限値であること.""" |
| 94 | + df, pre, post = _make_seasonal_df() |
| 95 | + ci = CausalImpact( |
| 96 | + df, pre, post, model_args={**MCMC_ARGS_FAST, "nseasons": 7} |
| 97 | + ) |
| 98 | + assert np.isfinite(ci.summary_stats["point_effect_mean"]) |
| 99 | + |
| 100 | + def test_ci_bounds_finite_with_seasonal(self): |
| 101 | + """seasonal モデルの CI bounds が有限値であること.""" |
| 102 | + df, pre, post = _make_seasonal_df() |
| 103 | + ci = CausalImpact( |
| 104 | + df, pre, post, model_args={**MCMC_ARGS_FAST, "nseasons": 7} |
| 105 | + ) |
| 106 | + assert np.isfinite(ci.summary_stats["ci_lower"]) |
| 107 | + assert np.isfinite(ci.summary_stats["ci_upper"]) |
| 108 | + assert ci.summary_stats["ci_lower"] < ci.summary_stats["ci_upper"] |
| 109 | + |
| 110 | + def test_seasonal_predictions_continue_in_post(self): |
| 111 | + """post period でも seasonal パターンが予測に反映されること.""" |
| 112 | + df, pre, post = _make_seasonal_df(effect=0.0, noise_sd=0.01) |
| 113 | + ci = CausalImpact( |
| 114 | + df, pre, post, model_args={**MCMC_ARGS_MEDIUM, "nseasons": 7} |
| 115 | + ) |
| 116 | + inf = ci.inferences |
| 117 | + post_predictions = inf["predicted_mean"] |
| 118 | + expected_post_len = post[1] - post[0] + 1 |
| 119 | + assert len(post_predictions) == expected_post_len |
| 120 | + assert all(np.isfinite(post_predictions)) |
| 121 | + |
| 122 | + def test_strong_seasonal_effect_detected(self): |
| 123 | + """強い因果効果 + seasonal → significant.""" |
| 124 | + df, pre, post = _make_seasonal_df(effect=10.0, noise_sd=0.5) |
| 125 | + ci = CausalImpact( |
| 126 | + df, pre, post, model_args={**MCMC_ARGS_MEDIUM, "nseasons": 7} |
| 127 | + ) |
| 128 | + assert ci.summary_stats["p_value"] < 0.05 |
| 129 | + |
| 130 | + def test_no_effect_seasonal_not_significant(self): |
| 131 | + """因果効果なし + seasonal → not significant.""" |
| 132 | + df, pre, post = _make_seasonal_df(effect=0.0, noise_sd=2.0) |
| 133 | + ci = CausalImpact( |
| 134 | + df, pre, post, model_args={**MCMC_ARGS_MEDIUM, "nseasons": 7} |
| 135 | + ) |
| 136 | + assert ci.summary_stats["p_value"] > 0.05 |
| 137 | + |
| 138 | + def test_seasonal_backward_compat_nseasons_none(self): |
| 139 | + """nseasons=None のとき既存動作と同一であること.""" |
| 140 | + rng = np.random.default_rng(99) |
| 141 | + y = 10.0 + rng.normal(0, 0.5, 30) |
| 142 | + y[20:] += 3.0 |
| 143 | + df = pd.DataFrame({"y": y}) |
| 144 | + ci = CausalImpact( |
| 145 | + df, |
| 146 | + [0, 19], |
| 147 | + [20, 29], |
| 148 | + model_args={"niter": 200, "nwarmup": 100, "seed": 99}, |
| 149 | + ) |
| 150 | + assert np.isfinite(ci.summary_stats["point_effect_mean"]) |
| 151 | + assert ci.summary_stats["p_value"] < 0.05 |
| 152 | + |
| 153 | + def test_nseasons_2_valid(self): |
| 154 | + """S=2(最小の seasonal)でエラーなし.""" |
| 155 | + df, pre, post = _make_seasonal_df(nseasons=2) |
| 156 | + ci = CausalImpact( |
| 157 | + df, pre, post, model_args={**MCMC_ARGS_FAST, "nseasons": 2} |
| 158 | + ) |
| 159 | + assert np.isfinite(ci.summary_stats["point_effect_mean"]) |
| 160 | + |
| 161 | + def test_nseasons_12_valid(self): |
| 162 | + """S=12(月次 seasonal)でエラーなし.""" |
| 163 | + df, pre, post = _make_seasonal_df(n=120, pre_end=84, nseasons=12) |
| 164 | + ci = CausalImpact( |
| 165 | + df, pre, post, model_args={**MCMC_ARGS_FAST, "nseasons": 12} |
| 166 | + ) |
| 167 | + assert np.isfinite(ci.summary_stats["point_effect_mean"]) |
| 168 | + |
| 169 | + def test_season_duration_7_valid(self): |
| 170 | + """season_duration=7(週次ブロック)でエラーなし.""" |
| 171 | + df, pre, post = _make_seasonal_df(n=168, pre_end=112, nseasons=4) |
| 172 | + ci = CausalImpact( |
| 173 | + df, |
| 174 | + pre, |
| 175 | + post, |
| 176 | + model_args={**MCMC_ARGS_FAST, "nseasons": 4, "season_duration": 7}, |
| 177 | + ) |
| 178 | + assert np.isfinite(ci.summary_stats["point_effect_mean"]) |
0 commit comments