Skip to content

Commit 7198361

Browse files
committed
Really fixed confidence intervals
1 parent 376a4e0 commit 7198361

File tree

2 files changed

+44
-25
lines changed

2 files changed

+44
-25
lines changed

src/inference.jl

+27-18
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ function quantities_of_interest(mod, n)
258258
null_dist = generate_null_distribution(mod, n)
259259
avg_effect = mod isa Metalearner ? mean(mod.causal_effect) : mod.causal_effect
260260
pvalue, stderr = p_value_and_std_err(null_dist, avg_effect)
261-
lb, ub = confidence_interval(null_dist)
261+
lb, ub = confidence_interval(null_dist, avg_effect)
262262

263263
return pvalue, stderr, lb, ub
264264
end
@@ -268,13 +268,13 @@ function quantities_of_interest(mod::InterruptedTimeSeries, n, mean_effect)
268268
metric = ifelse(mean_effect, mean, sum)
269269
effect = metric(mod.causal_effect)
270270
pvalue, stderr = p_value_and_std_err(null_dist, effect)
271-
lb, ub = confidence_interval(null_dist)
271+
lb, ub = confidence_interval(null_dist, effect)
272272

273273
return pvalue, stderr, lb, ub
274274
end
275275

276276
"""
277-
confidence_interval(null_dist)
277+
confidence_interval(null_dist, effect)
278278
279279
Compute 95% confidence intervals via randomization inference.
280280
@@ -289,24 +289,33 @@ julia> x, t, y = rand(100, 5), [rand()<0.4 for i in 1:100], rand(1:100, 100, 1)
289289
julia> g_computer = GComputation(x, t, y)
290290
julia> estimate_causal_effect!(g_computer)
291291
julia> null_dist = CausalELM.generate_null_distribution(g_computer, 1000)
292-
julia> confidence_interval(null_dist)
292+
julia> confidence_interval(null_dist, g_computer.causal_effect)
293293
(-0.45147664642089147, 0.45147664642089147)
294294
```
295295
"""
296-
function confidence_interval(null_dist)
297-
sorted_null_dist, n = sort(null_dist), length(null_dist)
298-
low_idx, high_idx = 0.025 * (n - 1), 0.975 * (n - 1)
299-
300-
lb = if isinteger(low_idx)
301-
sorted_null_dist[Int(low_idx)]
302-
else
303-
mean(sorted_null_dist[floor(Int, low_idx):ceil(Int, low_idx)])
304-
end
305-
306-
ub = if isinteger(high_idx)
307-
sorted_null_dist[Int(high_idx)]
308-
else
309-
mean(sorted_null_dist[floor(Int, high_idx):ceil(Int, high_idx)])
296+
function confidence_interval(null_dist, effect)
297+
# Search grid centered on the effect, spanning ±2× the largest absolute null value with 4 points per null draw, so it should contain both bounds
298+
max_magnitude_val = maximum(abs.(null_dist))
299+
grid = range(
300+
start=effect - 2max_magnitude_val,
301+
stop=effect + 2max_magnitude_val,
302+
length=4length(null_dist)
303+
)
304+
lb, ub = Inf, -Inf
305+
low_idx, high_idx = 1, length(grid)
306+
307+
# Walk inward from both ends of the grid: the first value on each side with p > 0.05 becomes that bound, and we stop once both are found
308+
while (isinf(lb) || isinf(ub)) && (low_idx < high_idx)
309+
left_p_val, _ = p_value_and_std_err(null_dist, grid[low_idx])
310+
right_p_val, _ = p_value_and_std_err(null_dist, grid[high_idx])
311+
312+
lb = left_p_val > 0.05 && isinf(lb) ? grid[low_idx] : lb
313+
ub = right_p_val > 0.05 && isinf(ub) ? grid[high_idx] : ub
314+
315+
(isinf(lb) == false && isinf(ub) == false) && break
316+
317+
low_idx += 1
318+
high_idx -= 1
310319
end
311320

312321
return lb, ub

test/test_inference.jl

+17-7
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ g_computer = GComputation(x, t, y)
99
estimate_causal_effect!(g_computer)
1010
g_inference = CausalELM.generate_null_distribution(g_computer, 1000)
1111
p1, stderr1 = CausalELM.p_value_and_std_err(g_inference, CausalELM.mean(g_inference))
12-
lb1, ub1 = CausalELM.confidence_interval(g_inference)
12+
lb1, ub1 = CausalELM.confidence_interval(g_inference, g_computer.causal_effect)
1313
p11, stderr11, lb11, ub11 = CausalELM.quantities_of_interest(g_computer, 100)
1414
summary1 = summarize(g_computer, n=100, inference=true)
1515

1616
dm = DoubleMachineLearning(x, t, y)
1717
estimate_causal_effect!(dm)
1818
dm_inference = CausalELM.generate_null_distribution(dm, 1000)
1919
p2, stderr2 = CausalELM.p_value_and_std_err(dm_inference, CausalELM.mean(dm_inference))
20-
lb2, ub2 = CausalELM.confidence_interval(dm_inference)
20+
lb2, ub2 = CausalELM.confidence_interval(dm_inference, dm.causal_effect)
2121
summary2 = summarize(dm, n=100)
2222

2323
# With a continuous treatment variable
@@ -27,7 +27,9 @@ dm_continuous_inference = CausalELM.generate_null_distribution(dm_continuous, 10
2727
p3, stderr3 = CausalELM.p_value_and_std_err(
2828
dm_continuous_inference, CausalELM.mean(dm_continuous_inference)
2929
)
30-
lb3, ub3 = CausalELM.confidence_interval(dm_continuous_inference)
30+
lb3, ub3 = CausalELM.confidence_interval(
31+
dm_continuous_inference, dm_continuous.causal_effect
32+
)
3133
summary3 = summarize(dm_continuous, n=100)
3234

3335
x₀, y₀, x₁, y₁ = rand(1:100, 100, 5), rand(100), rand(10, 5), rand(10)
@@ -39,7 +41,9 @@ summary4_inference = summarize(its, n=100, inference=true)
3941
# Null distributions for the mean and cumulative changes
4042
its_inference1 = CausalELM.generate_null_distribution(its, 1000, true)
4143
its_inference2 = CausalELM.generate_null_distribution(its, 10, false)
42-
lb4, ub4 = CausalELM.confidence_interval(its_inference1)
44+
lb4, ub4 = CausalELM.confidence_interval(
45+
its_inference1, CausalELM.mean(its.causal_effect)
46+
)
4347
p4, stderr4 = CausalELM.p_value_and_std_err(its_inference1, CausalELM.mean(its_inference1))
4448
p44, stderr44, lb44, ub44 = CausalELM.quantities_of_interest(its, 100, true)
4549

@@ -50,7 +54,9 @@ summary5 = summarize(slearner, n=100)
5054
tlearner = TLearner(x, t, y)
5155
estimate_causal_effect!(tlearner)
5256
tlearner_inference = CausalELM.generate_null_distribution(tlearner, 1000)
53-
lb6, ub6 = CausalELM.confidence_interval(tlearner_inference)
57+
lb6, ub6 = CausalELM.confidence_interval(
58+
tlearner_inference, CausalELM.mean(tlearner.causal_effect)
59+
)
5460
p6, stderr6 = CausalELM.p_value_and_std_err(
5561
tlearner_inference, CausalELM.mean(tlearner_inference)
5662
)
@@ -60,7 +66,9 @@ summary6 = summarize(tlearner, n=100)
6066
xlearner = XLearner(x, t, y)
6167
estimate_causal_effect!(xlearner)
6268
xlearner_inference = CausalELM.generate_null_distribution(xlearner, 1000)
63-
lb7, ub7 = CausalELM.confidence_interval(xlearner_inference)
69+
lb7, ub7 = CausalELM.confidence_interval(
70+
xlearner_inference, CausalELM.mean(xlearner.causal_effect)
71+
)
6472
p7, stderr7 = CausalELM.p_value_and_std_err(
6573
xlearner_inference, CausalELM.mean(xlearner_inference)
6674
)
@@ -74,7 +82,9 @@ summary9 = summarize(rlearner, n=100)
7482
dr_learner = DoublyRobustLearner(x, t, y)
7583
estimate_causal_effect!(dr_learner)
7684
dr_learner_inference = CausalELM.generate_null_distribution(dr_learner, 1000)
77-
lb8, ub8 = CausalELM.confidence_interval(dr_learner_inference)
85+
lb8, ub8 = CausalELM.confidence_interval(
86+
dr_learner_inference, CausalELM.mean(dr_learner.causal_effect)
87+
)
7888
p8, stderr8 = CausalELM.p_value_and_std_err(
7989
dr_learner_inference, CausalELM.mean(dr_learner_inference)
8090
)

0 commit comments

Comments
 (0)