Skip to content

Commit 8687ff0

Browse files
feat: extend CausalImpactResults with summary statistics and R-compatible summary format
CausalImpactResults に13個の新フィールドを追加し、summary() 出力を R CausalImpact と同等のフォーマットに改修。

Why: v0.1.0 の summary() 出力には「-」プレースホルダーが残っており、Actual 行や Prediction (s.d.) 行が表示できない状態だった。

Changes:
- analysis.py: actual, predictions_sd, average_prediction_sd/lower/upper, cumulative_prediction_sd/lower/upper, average_effect_sd, cumulative_effect_sd, relative_effect_sd/lower/upper の13フィールド追加
- analysis.py: compute_effects() に cross-sample 集約計算を追加 (n_samples=1 の場合は ddof=1 の標準偏差が NaN になるため 0 を返す)
- summary.py: R互換フォーマット実装 (Actual 行、Prediction (s.d.) 行、3つの CI 行、Posterior prob. 常時表示)
- test_summary.py: フィクスチャに新フィールドを追加し、R互換フォーマットの各行を検証するテストを追加。Absolute effect の 95% CI 行は行インデックス 8 で直接参照するよう変更 (新フォーマットでは「95% CI」行が3つあり、先頭一致検索では Prediction の CI 行に当たるため)
1 parent fc62fef commit 8687ff0

3 files changed

Lines changed: 199 additions & 23 deletions

File tree

python/causal_impact/analysis.py

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,45 @@
1111
class CausalImpactResults:
1212
"""Results of causal impact analysis."""
1313

14+
# Observed data
15+
actual: np.ndarray # (T_post,) observed y in post period
16+
17+
# Pointwise effects
1418
point_effects: np.ndarray # (T_post,) mean effect per time point
1519
point_effect_lower: np.ndarray # (T_post,) lower CI per time point
1620
point_effect_upper: np.ndarray # (T_post,) upper CI per time point
1721
ci_lower: float # lower CI bound on average effect
1822
ci_upper: float # upper CI bound on average effect
1923
point_effect_mean: float # mean of point effects across time
24+
average_effect_sd: float # std of per-sample average effects
25+
26+
# Cumulative effects
2027
cumulative_effect: np.ndarray # (T_post,) cumulative point effects
2128
cumulative_effect_lower: np.ndarray # (T_post,) lower cumulative CI
2229
cumulative_effect_upper: np.ndarray # (T_post,) upper cumulative CI
2330
cumulative_effect_total: float # total cumulative effect
31+
cumulative_effect_sd: float # std of per-sample cumulative effects
32+
33+
# Relative effects
2434
relative_effect_mean: float # relative effect (effect / predicted)
35+
relative_effect_sd: float # std of per-sample relative effects
36+
relative_effect_lower: float # lower CI on relative effect
37+
relative_effect_upper: float # upper CI on relative effect
38+
39+
# Significance
2540
p_value: float # Bayesian one-sided tail probability
41+
42+
# Counterfactual predictions
2643
predictions_mean: np.ndarray # (T_post,) mean counterfactual
44+
predictions_sd: np.ndarray # (T_post,) std of predictions per time point
2745
predictions_lower: np.ndarray # (T_post,) lower CI counterfactual
2846
predictions_upper: np.ndarray # (T_post,) upper CI counterfactual
47+
average_prediction_sd: float # std of per-sample average predictions
48+
average_prediction_lower: float # lower CI on average prediction
49+
average_prediction_upper: float # upper CI on average prediction
50+
cumulative_prediction_sd: float # std of per-sample cumulative predictions
51+
cumulative_prediction_lower: float # lower CI on cumulative prediction
52+
cumulative_prediction_upper: float # upper CI on cumulative prediction
2953

3054

3155
class CausalAnalysis:
@@ -78,12 +102,74 @@ def compute_effects(
78102
)
79103
cumulative_effect_total = float(cumulative_effect[-1])
80104

81-
# Relative effect
82-
pred_mean_total = predictions.mean()
83-
if abs(pred_mean_total) > 1e-10:
84-
relative_effect_mean = point_effect_mean / pred_mean_total
105+
# Actual observed values
106+
actual = y_post.copy()
107+
108+
# Per-time-point std of predictions across samples
109+
if n_samples == 1:
110+
predictions_sd_arr = np.zeros(predictions.shape[1])
85111
else:
86-
relative_effect_mean = 0.0
112+
predictions_sd_arr = np.std(predictions, axis=0, ddof=1)
113+
114+
# Prediction scalars (cross-sample aggregates)
115+
avg_pred_per_sample = predictions.mean(axis=1) # (n_samples,)
116+
cum_pred_per_sample = predictions.sum(axis=1) # (n_samples,)
117+
118+
if n_samples == 1:
119+
average_prediction_sd = 0.0
120+
cumulative_prediction_sd = 0.0
121+
else:
122+
average_prediction_sd = float(np.std(avg_pred_per_sample, ddof=1))
123+
cumulative_prediction_sd = float(np.std(cum_pred_per_sample, ddof=1))
124+
125+
average_prediction_lower = float(
126+
np.percentile(avg_pred_per_sample, 100 * lower_q)
127+
)
128+
average_prediction_upper = float(
129+
np.percentile(avg_pred_per_sample, 100 * upper_q)
130+
)
131+
cumulative_prediction_lower = float(
132+
np.percentile(cum_pred_per_sample, 100 * lower_q)
133+
)
134+
cumulative_prediction_upper = float(
135+
np.percentile(cum_pred_per_sample, 100 * upper_q)
136+
)
137+
138+
# Effect s.d. scalars
139+
cum_effects_per_sample = effects.sum(axis=1) # (n_samples,)
140+
141+
if n_samples == 1:
142+
average_effect_sd = 0.0
143+
cumulative_effect_sd = 0.0
144+
else:
145+
average_effect_sd = float(np.std(avg_effects, ddof=1))
146+
cumulative_effect_sd = float(np.std(cum_effects_per_sample, ddof=1))
147+
148+
# Relative effect per sample
149+
avg_pred_per_sample_safe = np.where(
150+
np.abs(avg_pred_per_sample) > 1e-10,
151+
avg_pred_per_sample,
152+
np.nan,
153+
)
154+
rel_effects_per_sample = np.where(
155+
np.abs(avg_pred_per_sample) > 1e-10,
156+
avg_effects / avg_pred_per_sample_safe,
157+
0.0,
158+
)
159+
160+
relative_effect_mean = float(rel_effects_per_sample.mean())
161+
162+
if n_samples == 1:
163+
relative_effect_sd = 0.0
164+
else:
165+
relative_effect_sd = float(np.std(rel_effects_per_sample, ddof=1))
166+
167+
relative_effect_lower = float(
168+
np.percentile(rel_effects_per_sample, 100 * lower_q)
169+
)
170+
relative_effect_upper = float(
171+
np.percentile(rel_effects_per_sample, 100 * upper_q)
172+
)
87173

88174
# p-value: proportion of samples where average effect has opposite sign
89175
if point_effect_mean >= 0:
@@ -99,19 +185,32 @@ def compute_effects(
99185
predictions_upper = np.percentile(predictions, 100 * upper_q, axis=0)
100186

101187
return CausalImpactResults(
188+
actual=actual,
102189
point_effects=point_effects,
103190
point_effect_lower=point_effect_lower,
104191
point_effect_upper=point_effect_upper,
105192
ci_lower=ci_lower,
106193
ci_upper=ci_upper,
107194
point_effect_mean=point_effect_mean,
195+
average_effect_sd=average_effect_sd,
108196
cumulative_effect=cumulative_effect,
109197
cumulative_effect_lower=cumulative_effect_lower,
110198
cumulative_effect_upper=cumulative_effect_upper,
111199
cumulative_effect_total=cumulative_effect_total,
200+
cumulative_effect_sd=cumulative_effect_sd,
112201
relative_effect_mean=relative_effect_mean,
202+
relative_effect_sd=relative_effect_sd,
203+
relative_effect_lower=relative_effect_lower,
204+
relative_effect_upper=relative_effect_upper,
113205
p_value=p_value,
114206
predictions_mean=predictions_mean,
207+
predictions_sd=predictions_sd_arr,
115208
predictions_lower=predictions_lower,
116209
predictions_upper=predictions_upper,
210+
average_prediction_sd=average_prediction_sd,
211+
average_prediction_lower=average_prediction_lower,
212+
average_prediction_upper=average_prediction_upper,
213+
cumulative_prediction_sd=cumulative_prediction_sd,
214+
cumulative_prediction_lower=cumulative_prediction_lower,
215+
cumulative_prediction_upper=cumulative_prediction_upper,
117216
)

python/causal_impact/summary.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,83 @@ class SummaryFormatter:
1212
def summary(results: CausalImpactResults, digits: int = 2) -> str:
1313
fmt = f".{digits}f"
1414

15-
avg_effect = format(results.point_effect_mean, fmt)
16-
avg_ci = f"[{format(results.ci_lower, fmt)}, {format(results.ci_upper, fmt)}]"
17-
cum_effect = format(results.cumulative_effect_total, fmt)
18-
cum_ci = (
15+
# Actual
16+
avg_actual = format(results.actual.mean(), fmt)
17+
cum_actual = format(results.actual.sum(), fmt)
18+
19+
# Prediction
20+
avg_pred = format(results.predictions_mean.mean(), fmt)
21+
avg_pred_sd = format(results.average_prediction_sd, fmt)
22+
cum_pred = format(results.predictions_mean.sum(), fmt)
23+
cum_pred_sd = format(results.cumulative_prediction_sd, fmt)
24+
25+
# Prediction CI
26+
avg_pred_ci = (
27+
f"[{format(results.average_prediction_lower, fmt)}, "
28+
f"{format(results.average_prediction_upper, fmt)}]"
29+
)
30+
cum_pred_ci = (
31+
f"[{format(results.cumulative_prediction_lower, fmt)}, "
32+
f"{format(results.cumulative_prediction_upper, fmt)}]"
33+
)
34+
35+
# Absolute effect
36+
avg_eff = format(results.point_effect_mean, fmt)
37+
avg_eff_sd = format(results.average_effect_sd, fmt)
38+
cum_eff = format(results.cumulative_effect_total, fmt)
39+
cum_eff_sd = format(results.cumulative_effect_sd, fmt)
40+
41+
# Absolute effect CI
42+
avg_eff_ci = (
43+
f"[{format(results.ci_lower, fmt)}, {format(results.ci_upper, fmt)}]"
44+
)
45+
cum_eff_ci = (
1946
f"[{format(results.cumulative_effect_lower[-1], fmt)}, "
2047
f"{format(results.cumulative_effect_upper[-1], fmt)}]"
2148
)
22-
rel_effect = format(results.relative_effect_mean * 100, fmt)
49+
50+
# Relative effect
51+
rel_m = format(results.relative_effect_mean * 100, fmt)
52+
rel_sd = format(results.relative_effect_sd * 100, fmt)
53+
rel_lo = format(results.relative_effect_lower * 100, fmt)
54+
rel_hi = format(results.relative_effect_upper * 100, fmt)
55+
2356
p_val = format(results.p_value, f".{max(digits, 3)}f")
57+
prob = format((1 - results.p_value) * 100, fmt)
58+
59+
pred_row = (
60+
f"Prediction (s.d.) "
61+
f"{avg_pred} ({avg_pred_sd}) "
62+
f"{cum_pred} ({cum_pred_sd})"
63+
)
64+
eff_row = (
65+
f"Absolute effect (s.d.) "
66+
f"{avg_eff} ({avg_eff_sd}) "
67+
f"{cum_eff} ({cum_eff_sd})"
68+
)
69+
rel_row = f"Relative effect (s.d.) {rel_m}% ({rel_sd}%) {rel_m}% ({rel_sd}%)"
70+
rel_ci_row = (
71+
f"95% CI [{rel_lo}%, {rel_hi}%] [{rel_lo}%, {rel_hi}%]"
72+
)
2473

2574
lines = [
2675
"Posterior inference {CausalImpact}",
2776
"",
2877
" Average Cumulative",
29-
"Actual - -",
30-
"Prediction (s.d.) - -",
31-
f"95% CI {avg_ci} {cum_ci}",
78+
f"Actual {avg_actual} {cum_actual}",
79+
pred_row,
80+
f"95% CI {avg_pred_ci} {cum_pred_ci}",
81+
"",
82+
eff_row,
83+
f"95% CI {avg_eff_ci} {cum_eff_ci}",
3284
"",
33-
f"Absolute effect (mean) {avg_effect} {cum_effect}",
34-
f"Relative effect {rel_effect}%",
85+
rel_row,
86+
rel_ci_row,
3587
"",
3688
f"Posterior tail-area probability p: {p_val}",
89+
f"Posterior prob. of a causal effect: {prob}%",
3790
]
3891

39-
if results.p_value < 0.05:
40-
lines.append("Posterior prob. of a causal effect: "
41-
f"{format((1 - results.p_value) * 100, fmt)}%")
42-
else:
43-
lines.append("The effect is not statistically significant.")
44-
4592
return "\n".join(lines)
4693

4794
@staticmethod

tests/test_summary.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,34 @@ def _make_results(effect=2.0, p_value=0.01):
99
"""Create a CausalImpactResults fixture."""
1010
t_post = 10
1111
return CausalImpactResults(
12+
actual=np.full(t_post, 12.0),
1213
point_effects=np.full(t_post, effect),
1314
point_effect_lower=np.full(t_post, effect * 0.75),
1415
point_effect_upper=np.full(t_post, effect * 1.25),
1516
ci_lower=effect * 0.5,
1617
ci_upper=effect * 1.5,
1718
point_effect_mean=effect,
19+
average_effect_sd=effect * 0.1,
1820
cumulative_effect=np.cumsum(np.full(t_post, effect)),
1921
cumulative_effect_lower=np.cumsum(np.full(t_post, effect * 0.75)),
2022
cumulative_effect_upper=np.cumsum(np.full(t_post, effect * 1.25)),
2123
cumulative_effect_total=effect * t_post,
24+
cumulative_effect_sd=effect,
2225
relative_effect_mean=effect / 10.0,
26+
relative_effect_sd=effect / 100.0,
27+
relative_effect_lower=effect / 20.0,
28+
relative_effect_upper=effect / 5.0,
2329
p_value=p_value,
2430
predictions_mean=np.full(t_post, 10.0),
31+
predictions_sd=np.full(t_post, 0.5),
2532
predictions_lower=np.full(t_post, 9.0),
2633
predictions_upper=np.full(t_post, 11.0),
34+
average_prediction_sd=0.5,
35+
average_prediction_lower=9.0,
36+
average_prediction_upper=11.0,
37+
cumulative_prediction_sd=5.0,
38+
cumulative_prediction_lower=90.0,
39+
cumulative_prediction_upper=110.0,
2740
)
2841

2942

@@ -37,6 +50,23 @@ def test_summary_default_format(self):
3750
assert "Cumulative" in text
3851
assert "2.0" in text or "2.00" in text
3952

53+
def test_summary_includes_r_style_actual_prediction_and_effect_sections_because_placeholder_rows_hide_valid_results(
54+
self,
55+
):
56+
"""R互換の summary では Actual/Prediction/Absolute/Relative の各行を欠かさない."""
57+
result = _make_results(effect=2.0, p_value=0.01)
58+
59+
text = SummaryFormatter.summary(result, digits=2)
60+
lines = text.split("\n")
61+
62+
assert "Actual 12.00 120.00" in lines
63+
assert "Prediction (s.d.) 10.00 (0.50) 100.00 (5.00)" in lines
64+
assert "95% CI [9.00, 11.00] [90.00, 110.00]" in lines
65+
assert "Absolute effect (s.d.) 2.00 (0.20) 20.00 (2.00)" in lines
66+
assert "95% CI [1.00, 3.00] [15.00, 25.00]" in lines
67+
assert "Relative effect (s.d.) 20.00% (2.00%) 20.00% (2.00%)" in lines
68+
assert "95% CI [10.00%, 40.00%] [10.00%, 40.00%]" in lines
69+
4070
def test_summary_report_format(self):
4171
result = _make_results(effect=2.0, p_value=0.01)
4272
text = SummaryFormatter.report(result)
@@ -55,10 +85,10 @@ def test_summary_digits_10(self):
5585
assert isinstance(text, str)
5686

5787
def test_summary_shows_cumulative_ci_in_95_percent_ci_row(self):
58-
"""95% CI 行の cumulative 列には最終時点の累積CIを表示する."""
88+
"""Absolute effect の 95% CI cumulative 列には最終時点の累積CIを表示する."""
5989
result = _make_results(effect=2.0, p_value=0.01)
6090
text = SummaryFormatter.summary(result, digits=2)
61-
ci_line = next(line for line in text.split("\n") if "95% CI" in line)
91+
ci_line = text.split("\n")[8]
6292
assert "15.00" in ci_line
6393
assert "25.00" in ci_line
6494

0 commit comments

Comments
 (0)