Skip to content

Commit 9f71013

Browse files
merge: resolve conflicts with origin/main
- summary.py: adopt origin/main's R-compatible summary format - test_analysis.py: adopt origin/main's docstring removal
2 parents b36495c + 266ca0c commit 9f71013

5 files changed

Lines changed: 264 additions & 119 deletions

File tree

python/causal_impact/analysis.py

Lines changed: 106 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,34 @@
1111
class CausalImpactResults:
1212
"""Results of causal impact analysis."""
1313

14-
point_effects: np.ndarray # (T_post,) mean effect per time point
15-
point_effect_lower: np.ndarray # (T_post,) lower CI per time point
16-
point_effect_upper: np.ndarray # (T_post,) upper CI per time point
17-
ci_lower: float # lower CI bound on average effect
18-
ci_upper: float # upper CI bound on average effect
19-
point_effect_mean: float # mean of point effects across time
20-
cumulative_effect: np.ndarray # (T_post,) cumulative point effects
21-
cumulative_effect_lower: np.ndarray # (T_post,) lower cumulative CI
22-
cumulative_effect_upper: np.ndarray # (T_post,) upper cumulative CI
23-
cumulative_effect_total: float # total cumulative effect
24-
relative_effect_mean: float # relative effect (effect / predicted)
25-
p_value: float # Bayesian one-sided tail probability
26-
predictions_mean: np.ndarray # (T_post,) mean counterfactual
27-
predictions_lower: np.ndarray # (T_post,) lower CI counterfactual
28-
predictions_upper: np.ndarray # (T_post,) upper CI counterfactual
14+
actual: np.ndarray
15+
point_effects: np.ndarray
16+
point_effect_lower: np.ndarray
17+
point_effect_upper: np.ndarray
18+
ci_lower: float
19+
ci_upper: float
20+
point_effect_mean: float
21+
average_effect_sd: float
22+
cumulative_effect: np.ndarray
23+
cumulative_effect_lower: np.ndarray
24+
cumulative_effect_upper: np.ndarray
25+
cumulative_effect_total: float
26+
cumulative_effect_sd: float
27+
relative_effect_mean: float
28+
relative_effect_sd: float
29+
relative_effect_lower: float
30+
relative_effect_upper: float
31+
p_value: float
32+
predictions_mean: np.ndarray
33+
predictions_sd: np.ndarray
34+
predictions_lower: np.ndarray
35+
predictions_upper: np.ndarray
36+
average_prediction_sd: float
37+
average_prediction_lower: float
38+
average_prediction_upper: float
39+
cumulative_prediction_sd: float
40+
cumulative_prediction_lower: float
41+
cumulative_prediction_upper: float
2942

3043

3144
class CausalAnalysis:
@@ -42,28 +55,19 @@ def compute_effects(
4255

4356
n_samples = predictions.shape[0]
4457

45-
# Effect per sample per time point: observed - counterfactual
46-
# predictions shape: (n_samples, t_post)
47-
effects = y_post[np.newaxis, :] - predictions # (n_samples, t_post)
58+
effects = y_post[np.newaxis, :] - predictions
59+
avg_effects = effects.mean(axis=1)
60+
point_effects = effects.mean(axis=0)
4861

49-
# Average effect across time for each sample
50-
avg_effects = effects.mean(axis=1) # (n_samples,)
51-
52-
# Point effects: mean across samples at each time point
53-
point_effects = effects.mean(axis=0) # (t_post,)
54-
55-
# Summary-table CI on average effect uses sample-average quantiles.
5662
lower_q = alpha / 2
5763
upper_q = 1 - alpha / 2
5864
point_effect_lower = np.percentile(effects, 100 * lower_q, axis=0)
5965
point_effect_upper = np.percentile(effects, 100 * upper_q, axis=0)
6066
ci_lower = float(np.percentile(avg_effects, 100 * lower_q))
6167
ci_upper = float(np.percentile(avg_effects, 100 * upper_q))
6268

63-
# Mean effect
6469
point_effect_mean = float(avg_effects.mean())
6570

66-
# Cumulative effect
6771
cumulative_effect = np.cumsum(point_effects)
6872
cum_effects_samples = np.cumsum(effects, axis=1)
6973
cumulative_effect_lower = np.percentile(
@@ -78,40 +82,107 @@ def compute_effects(
7882
)
7983
cumulative_effect_total = float(cumulative_effect[-1])
8084

81-
# Relative effect
82-
pred_mean_total = predictions.mean()
83-
if abs(pred_mean_total) > 1e-10:
84-
relative_effect_mean = point_effect_mean / pred_mean_total
85+
actual = y_post.copy()
86+
87+
if n_samples == 1:
88+
predictions_sd_arr = np.zeros(predictions.shape[1])
89+
else:
90+
predictions_sd_arr = np.std(predictions, axis=0, ddof=1)
91+
92+
avg_pred_per_sample = predictions.mean(axis=1)
93+
cum_pred_per_sample = predictions.sum(axis=1)
94+
95+
if n_samples == 1:
96+
average_prediction_sd = 0.0
97+
cumulative_prediction_sd = 0.0
8598
else:
86-
relative_effect_mean = 0.0
99+
average_prediction_sd = float(np.std(avg_pred_per_sample, ddof=1))
100+
cumulative_prediction_sd = float(np.std(cum_pred_per_sample, ddof=1))
101+
102+
average_prediction_lower = float(
103+
np.percentile(avg_pred_per_sample, 100 * lower_q)
104+
)
105+
average_prediction_upper = float(
106+
np.percentile(avg_pred_per_sample, 100 * upper_q)
107+
)
108+
cumulative_prediction_lower = float(
109+
np.percentile(cum_pred_per_sample, 100 * lower_q)
110+
)
111+
cumulative_prediction_upper = float(
112+
np.percentile(cum_pred_per_sample, 100 * upper_q)
113+
)
114+
115+
cum_effects_per_sample = effects.sum(axis=1)
116+
117+
if n_samples == 1:
118+
average_effect_sd = 0.0
119+
cumulative_effect_sd = 0.0
120+
else:
121+
average_effect_sd = float(np.std(avg_effects, ddof=1))
122+
cumulative_effect_sd = float(np.std(cum_effects_per_sample, ddof=1))
123+
124+
avg_pred_per_sample_safe = np.where(
125+
np.abs(avg_pred_per_sample) > 1e-10,
126+
avg_pred_per_sample,
127+
np.nan,
128+
)
129+
rel_effects_per_sample = np.where(
130+
np.abs(avg_pred_per_sample) > 1e-10,
131+
avg_effects / avg_pred_per_sample_safe,
132+
0.0,
133+
)
134+
135+
relative_effect_mean = float(rel_effects_per_sample.mean())
136+
137+
if n_samples == 1:
138+
relative_effect_sd = 0.0
139+
else:
140+
relative_effect_sd = float(np.std(rel_effects_per_sample, ddof=1))
141+
142+
relative_effect_lower = float(
143+
np.percentile(rel_effects_per_sample, 100 * lower_q)
144+
)
145+
relative_effect_upper = float(
146+
np.percentile(rel_effects_per_sample, 100 * upper_q)
147+
)
87148

88-
# p-value: proportion of samples where average effect has opposite sign
89149
if point_effect_mean >= 0:
90150
p_value = float(np.mean(avg_effects < 0))
91151
else:
92152
p_value = float(np.mean(avg_effects > 0))
93-
# Ensure minimum p-value of 1/n_samples
94153
p_value = max(p_value, 1.0 / n_samples)
95154

96-
# Counterfactual prediction summaries
97155
predictions_mean = predictions.mean(axis=0)
98156
predictions_lower = np.percentile(predictions, 100 * lower_q, axis=0)
99157
predictions_upper = np.percentile(predictions, 100 * upper_q, axis=0)
100158

101159
return CausalImpactResults(
160+
actual=actual,
102161
point_effects=point_effects,
103162
point_effect_lower=point_effect_lower,
104163
point_effect_upper=point_effect_upper,
105164
ci_lower=ci_lower,
106165
ci_upper=ci_upper,
107166
point_effect_mean=point_effect_mean,
167+
average_effect_sd=average_effect_sd,
108168
cumulative_effect=cumulative_effect,
109169
cumulative_effect_lower=cumulative_effect_lower,
110170
cumulative_effect_upper=cumulative_effect_upper,
111171
cumulative_effect_total=cumulative_effect_total,
172+
cumulative_effect_sd=cumulative_effect_sd,
112173
relative_effect_mean=relative_effect_mean,
174+
relative_effect_sd=relative_effect_sd,
175+
relative_effect_lower=relative_effect_lower,
176+
relative_effect_upper=relative_effect_upper,
113177
p_value=p_value,
114178
predictions_mean=predictions_mean,
179+
predictions_sd=predictions_sd_arr,
115180
predictions_lower=predictions_lower,
116181
predictions_upper=predictions_upper,
182+
average_prediction_sd=average_prediction_sd,
183+
average_prediction_lower=average_prediction_lower,
184+
average_prediction_upper=average_prediction_upper,
185+
cumulative_prediction_sd=cumulative_prediction_sd,
186+
cumulative_prediction_lower=cumulative_prediction_lower,
187+
cumulative_prediction_upper=cumulative_prediction_upper,
117188
)

python/causal_impact/summary.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,44 +6,83 @@
66

77

88
class SummaryFormatter:
9-
"""Format CausalImpact results as text summary or natural language report."""
9+
"""Format CausalImpact results as text summary or report."""
1010

1111
@staticmethod
1212
def summary(results: CausalImpactResults, digits: int = 2) -> str:
1313
fmt = f".{digits}f"
1414

15-
avg_effect = format(results.point_effect_mean, fmt)
16-
avg_ci = f"[{format(results.ci_lower, fmt)}, {format(results.ci_upper, fmt)}]"
17-
cum_effect = format(results.cumulative_effect_total, fmt)
18-
cum_ci = (
15+
avg_actual = format(results.actual.mean(), fmt)
16+
cum_actual = format(results.actual.sum(), fmt)
17+
18+
avg_pred = format(results.predictions_mean.mean(), fmt)
19+
avg_pred_sd = format(results.average_prediction_sd, fmt)
20+
cum_pred = format(results.predictions_mean.sum(), fmt)
21+
cum_pred_sd = format(results.cumulative_prediction_sd, fmt)
22+
23+
avg_pred_ci = (
24+
f"[{format(results.average_prediction_lower, fmt)}, "
25+
f"{format(results.average_prediction_upper, fmt)}]"
26+
)
27+
cum_pred_ci = (
28+
f"[{format(results.cumulative_prediction_lower, fmt)}, "
29+
f"{format(results.cumulative_prediction_upper, fmt)}]"
30+
)
31+
32+
avg_eff = format(results.point_effect_mean, fmt)
33+
avg_eff_sd = format(results.average_effect_sd, fmt)
34+
cum_eff = format(results.cumulative_effect_total, fmt)
35+
cum_eff_sd = format(results.cumulative_effect_sd, fmt)
36+
37+
avg_eff_ci = (
38+
f"[{format(results.ci_lower, fmt)}, {format(results.ci_upper, fmt)}]"
39+
)
40+
cum_eff_ci = (
1941
f"[{format(results.cumulative_effect_lower[-1], fmt)}, "
2042
f"{format(results.cumulative_effect_upper[-1], fmt)}]"
2143
)
22-
rel_effect = format(results.relative_effect_mean * 100, fmt)
44+
45+
rel_m = format(results.relative_effect_mean * 100, fmt)
46+
rel_sd = format(results.relative_effect_sd * 100, fmt)
47+
rel_lo = format(results.relative_effect_lower * 100, fmt)
48+
rel_hi = format(results.relative_effect_upper * 100, fmt)
49+
2350
p_val = format(results.p_value, f".{max(digits, 3)}f")
51+
prob = format((1 - results.p_value) * 100, fmt)
52+
53+
pred_row = (
54+
f"Prediction (s.d.) "
55+
f"{avg_pred} ({avg_pred_sd}) "
56+
f"{cum_pred} ({cum_pred_sd})"
57+
)
58+
eff_row = (
59+
f"Absolute effect (s.d.) "
60+
f"{avg_eff} ({avg_eff_sd}) "
61+
f"{cum_eff} ({cum_eff_sd})"
62+
)
63+
rel_row = f"Relative effect (s.d.) {rel_m}% ({rel_sd}%) {rel_m}% ({rel_sd}%)"
64+
rel_ci_row = (
65+
f"95% CI [{rel_lo}%, {rel_hi}%] [{rel_lo}%, {rel_hi}%]"
66+
)
2467

2568
lines = [
2669
"Posterior inference {CausalImpact}",
2770
"",
2871
" Average Cumulative",
29-
"Actual - -",
30-
"Prediction (s.d.) - -",
31-
f"95% CI {avg_ci} {cum_ci}",
72+
f"Actual {avg_actual} {cum_actual}",
73+
pred_row,
74+
f"95% CI {avg_pred_ci} {cum_pred_ci}",
75+
"",
76+
eff_row,
77+
f"95% CI {avg_eff_ci} {cum_eff_ci}",
3278
"",
33-
f"Absolute effect (mean) {avg_effect} {cum_effect}",
34-
f"Relative effect {rel_effect}%",
79+
rel_row,
80+
rel_ci_row,
3581
"",
3682
f"Posterior tail-area probability p: {p_val}",
83+
f"Posterior prob. of a causal effect: {prob}%",
3784
]
3885

39-
if results.p_value < 0.05:
40-
lines.append(
41-
"Posterior prob. of a causal effect: "
42-
f"{format((1 - results.p_value) * 100, fmt)}%"
43-
)
44-
else:
45-
lines.append("The effect is not statistically significant.")
46-
4786
return "\n".join(lines)
4887

4988
@staticmethod

0 commit comments

Comments
 (0)