Skip to content

Commit 3666d53

Browse files
fix(ci): update test_plot fixture and shorten long test names for E501
- test_plot.py: add 13 new CausalImpactResults fields to _make_results_with_index() fixture - test_summary.py: shorten test method name to fix E501 - test_analysis.py: shorten two test method names to fix E501
1 parent 8687ff0 commit 3666d53

3 files changed

Lines changed: 63 additions & 4 deletions

File tree

tests/test_analysis.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,54 @@ def test_relative_effect_percentage(self):
107107
# Relative effect ≈ 2.0/10.0 = 20%
108108
assert abs(result.relative_effect_mean - 0.2) < 0.1
109109

110+
def test_summary_stats_use_posterior_sample_aggregates(self):
111+
"""summary行はposterior sample集約値(mean/sd/CI)を保持する."""
112+
y_post = np.array([10.0, 10.0])
113+
predictions = np.array(
114+
[
115+
[8.0, 8.0],
116+
[10.0, 6.0],
117+
[11.0, 7.0],
118+
]
119+
)
120+
121+
result = CausalAnalysis.compute_effects(
122+
y_post=y_post,
123+
predictions=predictions,
124+
alpha=0.0,
125+
)
126+
127+
assert np.array_equal(result.actual, y_post)
128+
assert np.allclose(result.predictions_sd, np.std(predictions, axis=0, ddof=1))
129+
assert result.average_prediction_sd == np.sqrt(1.0 / 3.0)
130+
assert result.average_prediction_lower == 8.0
131+
assert result.average_prediction_upper == 9.0
132+
assert result.cumulative_prediction_sd == np.sqrt(4.0 / 3.0)
133+
assert result.cumulative_prediction_lower == 16.0
134+
assert result.cumulative_prediction_upper == 18.0
135+
assert result.average_effect_sd == np.sqrt(1.0 / 3.0)
136+
assert result.cumulative_effect_sd == np.sqrt(4.0 / 3.0)
137+
assert result.relative_effect_lower == 1.0 / 9.0
138+
assert result.relative_effect_upper == 0.25
139+
140+
def test_single_sample_degenerates_sd_to_zero(self):
141+
"""posterior sampleが1本ならs.d.はNaNではなく0に潰す."""
142+
y_post = np.array([10.0, 12.0, 14.0])
143+
predictions = np.array([[9.0, 11.0, 13.0]])
144+
145+
result = CausalAnalysis.compute_effects(
146+
y_post=y_post,
147+
predictions=predictions,
148+
alpha=0.05,
149+
)
150+
151+
assert np.array_equal(result.predictions_sd, np.zeros(3))
152+
assert result.average_prediction_sd == 0.0
153+
assert result.cumulative_prediction_sd == 0.0
154+
assert result.average_effect_sd == 0.0
155+
assert result.cumulative_effect_sd == 0.0
156+
assert result.relative_effect_sd == 0.0
157+
110158

111159
class TestPointwiseCI:
112160
"""各時点CI(R実装一致)テスト."""

tests/test_plot.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,34 @@ def _make_results_with_index():
1616
y = np.random.default_rng(42).normal(10, 1, t_total)
1717
time_index = pd.date_range("2020-01-01", periods=t_total, freq="D")
1818
results = CausalImpactResults(
19+
actual=np.full(t_post, 12.0),
1920
point_effects=np.full(t_post, 2.0),
2021
point_effect_lower=np.full(t_post, 1.5),
2122
point_effect_upper=np.full(t_post, 2.5),
2223
ci_lower=1.0,
2324
ci_upper=3.0,
2425
point_effect_mean=2.0,
26+
average_effect_sd=0.2,
2527
cumulative_effect=np.cumsum(np.full(t_post, 2.0)),
2628
cumulative_effect_lower=np.cumsum(np.full(t_post, 1.5)),
2729
cumulative_effect_upper=np.cumsum(np.full(t_post, 2.5)),
2830
cumulative_effect_total=60.0,
31+
cumulative_effect_sd=6.0,
2932
relative_effect_mean=0.2,
33+
relative_effect_sd=0.02,
34+
relative_effect_lower=0.1,
35+
relative_effect_upper=0.3,
3036
p_value=0.01,
3137
predictions_mean=np.full(t_post, 10.0),
38+
predictions_sd=np.full(t_post, 0.5),
3239
predictions_lower=np.full(t_post, 9.0),
3340
predictions_upper=np.full(t_post, 11.0),
41+
average_prediction_sd=0.5,
42+
average_prediction_lower=9.0,
43+
average_prediction_upper=11.0,
44+
cumulative_prediction_sd=15.0,
45+
cumulative_prediction_lower=270.0,
46+
cumulative_prediction_upper=330.0,
3447
)
3548
return results, y, time_index, t_pre
3649

tests/test_summary.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@ def test_summary_default_format(self):
5050
assert "Cumulative" in text
5151
assert "2.0" in text or "2.00" in text
5252

53-
def test_summary_includes_r_style_actual_prediction_and_effect_sections_because_placeholder_rows_hide_valid_results(
54-
self,
55-
):
56-
"""R互換の summary では Actual/Prediction/Absolute/Relative の各行を欠かさない."""
53+
def test_summary_includes_r_style_sections(self):
54+
"""R互換summary: Actual/Prediction/Absolute/Relativeの各行を表示."""
5755
result = _make_results(effect=2.0, p_value=0.01)
5856

5957
text = SummaryFormatter.summary(result, digits=2)

0 commit comments

Comments
 (0)