Commit 41030b6

Added marginal effects to summaries
1 parent 31dcb21 commit 41030b6

7 files changed: +91 -26 lines changed
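
In practice, the change surfaces through the existing `summarize` API: estimators now carry a `marginal_effect` field that is filled in during estimation and reported alongside the other statistics. A minimal usage sketch, assuming the exported CausalELM workflow and made-up data:

```julia
using CausalELM

# Hypothetical data, purely for illustration
X, T, Y = rand(1000, 5), rand(0:1, 1000), rand(1000)

m = DoubleMachineLearning(X, T, Y)
estimate_causal_effect!(m)   # also populates m.marginal_effect
summarize(m)                 # the summary now includes a "Marginal Effect" entry
```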

docs/src/release_notes.md (+2 -1)

@@ -4,11 +4,12 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/
 ## Version [0.8.0](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.8.0) - 2024-10-31
 ### Added
 * Implemented randomization inference-based confidence intervals [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
+* Added marginal effects to model summaries [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
 ### Fixed
 * Removed unnecessary include and using statements
 * Slightly sped up the randomization inference implementation and clarified it in the docs [#77](https://github.com/dscolby/CausalELM.jl/issues/77)
 * Fixed the randomization inference index selection procedure for interrupted time series estimators
-* Inlined certain methods to slightly improve performance [#79](https://github.com/dscolby/CausalELM.jl/issues/79)
+* Inlined certain methods to slightly improve performance [#76](https://github.com/dscolby/CausalELM.jl/issues/76)

 ## Version [v0.7.0](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.7.0) - 2024-06-22
 ### Added

src/estimators.jl (+38 -15)

@@ -48,6 +48,7 @@ mutable struct InterruptedTimeSeries
     Y₀::Array{Float64}
     X₁::Array{Float64}
     Y₁::Array{Float64}
+    marginal_effect::Float64
     @model_config individual_effect
 end

@@ -77,6 +78,7 @@ function InterruptedTimeSeries(
         float(Y₀),
         X₁,
         float(Y₁),
+        NaN,
         "difference",
         true,
         task,
@@ -137,6 +139,7 @@ julia> m5 = GComputation(x_df, t_df, y_df)
 mutable struct GComputation <: CausalEstimator
     @standard_input_data
     @model_config average_effect
+    marginal_effect::Float64
     ensemble::ELMEnsemble

     function GComputation(
@@ -173,6 +176,7 @@ mutable struct GComputation <: CausalEstimator
             num_feats,
             num_neurons,
             NaN,
+            NaN,
         )
     end
 end
@@ -220,6 +224,7 @@ julia> m2 = DoubleMachineLearning(x_df, t_df, y_df)
 mutable struct DoubleMachineLearning <: CausalEstimator
     @standard_input_data
     @model_config average_effect
+    marginal_effect::Float64
     folds::Integer
 end

@@ -256,6 +261,7 @@ function DoubleMachineLearning(
         num_feats,
         num_neurons,
         NaN,
+        NaN,
         folds,
     )
 end
@@ -285,6 +291,7 @@ julia> estimate_causal_effect!(m1)

     fit!(learner)
     its.causal_effect = predict(learner, its.X₁) .- its.Y₁
+    its.marginal_effect = mean(its.causal_effect)

     return its.causal_effect
 end
@@ -309,7 +316,9 @@ julia> estimate_causal_effect!(m1)
 ```
 """
 @inline function estimate_causal_effect!(g::GComputation)
-    g.causal_effect = mean(g_formula!(g))
+    causal_effect, marginal_effect = g_formula!(g)
+    g.causal_effect, g.marginal_effect = mean(causal_effect), mean(marginal_effect)
+
     return g.causal_effect
 end

@@ -330,6 +339,7 @@ julia> g_formula!(m2)
 """
 @inline function g_formula!(g) # Keeping this separate for reuse with S-Learning
     covariates, y = hcat(g.X, g.T), g.Y
+    x₁, x₀ = hcat(g.X, ones(size(g.X, 1))), hcat(g.X, zeros(size(g.X, 1)))

     if g.quantity_of_interest ∈ ("ITT", "ATE", "CATE")
         Xₜ = hcat(covariates[:, 1:(end - 1)], ones(size(covariates, 1)))
@@ -350,10 +360,9 @@ julia> g_formula!(m2)
     )

     fit!(g.ensemble)
-
     yₜ, yᵤ = predict(g.ensemble, Xₜ), predict(g.ensemble, Xᵤ)

-    return vec(yₜ) - vec(yᵤ)
+    return vec(yₜ) - vec(yᵤ), predict(g.ensemble, x₁) - predict(g.ensemble, x₀)
 end

 """
@@ -374,27 +383,35 @@ julia> estimate_causal_effect!(m2)
 """
 @inline function estimate_causal_effect!(DML::DoubleMachineLearning)
     X, T, Y = generate_folds(DML.X, DML.T, DML.Y, DML.folds)
-    DML.causal_effect = 0
+    DML.causal_effect, DML.marginal_effect = 0, 0
+    Δ = var_type(DML.T) isa Binary ? 1.0 : 1.5e-8mean(DML.T)

     # Cross fitting by training on the main folds and predicting residuals on the auxillary
-    for fld in 1:(DML.folds)
-        X_train, X_test = reduce(vcat, X[1:end .!== fld]), X[fld]
-        Y_train, Y_test = reduce(vcat, Y[1:end .!== fld]), Y[fld]
-        T_train, T_test = reduce(vcat, T[1:end .!== fld]), T[fld]
-
-        Ỹ, T̃ = predict_residuals(DML, X_train, X_test, Y_train, Y_test, T_train, T_test)
+    for fold in 1:(DML.folds)
+        X_train, X_test = reduce(vcat, X[1:end .!== fold]), X[fold]
+        Y_train, Y_test = reduce(vcat, Y[1:end .!== fold]), Y[fold]
+        T_train, T_test = reduce(vcat, T[1:end .!== fold]), T[fold]
+        T_train₊ = var_type(DML.T) isa Binary ? T_train .* 0 : T_train .+ Δ
+
+        Ỹ, T̃, T̃₊ = predict_residuals(
+            DML, X_train, X_test, Y_train, Y_test, T_train, T_test, T_train₊
+        )

         DML.causal_effect += T̃\Ỹ
+        DML.marginal_effect += (T̃₊\Ỹ - DML.causal_effect) / Δ
     end
+
     DML.causal_effect /= DML.folds
+    DML.marginal_effect /= DML.folds

     return DML.causal_effect
 end

 """
-    predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test)
+    predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test, t_train₊)

-Predict treatment and outcome residuals for double machine learning or R-learning.
+Predict treatment, outcome, and marginal effect residuals for double machine learning or
+R-learning.

 # Notes
 This method should not be called directly.
@@ -406,7 +423,7 @@ julia> x_train, x_test = X[1:80, :], X[81:end, :]
 julia> y_train, y_test = Y[1:80], Y[81:end]
 julia> t_train, t_test = T[1:80], T[81:100]
 julia> m1 = DoubleMachineLearning(X, T, Y)
-julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
+julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test, zeros(100))
 ```
 """
 @inline function predict_residuals(
@@ -417,6 +434,7 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
     yₜₑ::Vector{Float64},
     tₜᵣ::Vector{Float64},
     tₜₑ::Vector{Float64},
+    tₜᵣ₊::Vector{Float64}
 )
     y = ELMEnsemble(
         xₜᵣ, yₜᵣ, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
@@ -426,12 +444,17 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
         xₜᵣ, tₜᵣ, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
     )

+    t₊ = ELMEnsemble(
+        xₜᵣ, tₜᵣ₊, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
+    )
+
     fit!(y)
     fit!(t)
+    fit!(t₊) # Estimate a model with T + a finite difference

-    yₚᵣ, tₚᵣ = predict(y, xₜₑ), predict(t, xₜₑ)
+    yₚᵣ, tₚᵣ, tₚᵣ₊ = predict(y, xₜₑ), predict(t, xₜₑ), predict(t₊, xₜₑ)

-    return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ
+    return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ, tₜₑ - tₚᵣ₊
 end

 """

src/inference.jl (+8 -4)

@@ -59,7 +59,8 @@ function summarize(mod; kwargs...)
         "Standard Error",
         "p-value",
         "Lower 2.5% CI",
-        "Upper 97.5% CI"
+        "Upper 97.5% CI",
+        "Marginal Effect"
     ]

     if haskey(kwargs, :inference) && kwargs[:inference] == true
@@ -82,7 +83,8 @@ function summarize(mod; kwargs...)
         stderr,
         p,
         lower_ci,
-        upper_ci
+        upper_ci,
+        mod.marginal_effect
     ]

     for (nicename, value) in zip(nicenames, values)
@@ -124,7 +126,8 @@ function summarize(its::InterruptedTimeSeries; kwargs...)
         "Standard Error",
         "p-value",
         "Lower 2.5% CI",
-        "Upper 97.5% CI"
+        "Upper 97.5% CI",
+        "Marginal Effect"
     ]

     values = [
@@ -140,7 +143,8 @@ function summarize(its::InterruptedTimeSeries; kwargs...)
         stderr,
         p,
         l,
-        u
+        u,
+        its.marginal_effect
     ]

     for (nicename, value) in zip(nicenames, values)
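
Both `summarize` methods build their output by zipping the `nicenames` and `values` vectors, so the new entry flows through without any further changes. Roughly, with abridged names and made-up numbers:

```julia
# Illustrative only: how the paired vectors become summary entries
nicenames = ["Causal Effect", "Standard Error", "Upper 97.5% CI", "Marginal Effect"]
values    = [0.25, 0.04, 0.33, 0.25]

summary_dict = Dict{Any,Any}()
for (nicename, value) in zip(nicenames, values)
    summary_dict[nicename] = value
end
```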

src/metalearners.jl (+27 -5)

@@ -45,6 +45,7 @@ julia> m4 = SLearner(x_df, t_df, y_df)
 mutable struct SLearner <: Metalearner
     @standard_input_data
     @model_config individual_effect
+    marginal_effect::Vector{Float64}
     ensemble::ELMEnsemble

     function SLearner(
@@ -76,6 +77,7 @@ mutable struct SLearner <: Metalearner
             num_feats,
             num_neurons,
             fill(NaN, size(T, 1)),
+            fill(NaN, size(T, 1)),
         )
     end
 end
@@ -123,6 +125,7 @@ julia> m3 = TLearner(x_df, t_df, y_df)
 mutable struct TLearner <: Metalearner
     @standard_input_data
     @model_config individual_effect
+    marginal_effect::Vector{Float64}
     μ₀::ELMEnsemble
     μ₁::ELMEnsemble

@@ -154,6 +157,7 @@ mutable struct TLearner <: Metalearner
             num_feats,
             num_neurons,
             fill(NaN, size(T, 1)),
+            fill(NaN, size(T, 1)),
         )
     end
 end
@@ -201,6 +205,7 @@ julia> m3 = XLearner(x_df, t_df, y_df)
 mutable struct XLearner <: Metalearner
     @standard_input_data
     @model_config individual_effect
+    marginal_effect::Vector{Float64}
     μ₀::ELMEnsemble
     μ₁::ELMEnsemble
     ps::Array{Float64}
@@ -233,6 +238,7 @@ mutable struct XLearner <: Metalearner
             num_feats,
             num_neurons,
             fill(NaN, size(T, 1)),
+            fill(NaN, size(T, 1)),
         )
     end
 end
@@ -278,6 +284,7 @@ julia> m2 = RLearner(x_df, t_df, y_df)
 mutable struct RLearner <: Metalearner
     @standard_input_data
     @model_config individual_effect
+    marginal_effect::Vector{Float64}
     folds::Integer
 end

@@ -315,6 +322,7 @@ function RLearner(
         num_feats,
         num_neurons,
         fill(NaN, size(T, 1)),
+        fill(NaN, size(T, 1)),
         folds,
     )
 end
@@ -363,6 +371,7 @@ julia> m3 = DoublyRobustLearner(X, T, Y, W=w)
 mutable struct DoublyRobustLearner <: Metalearner
     @standard_input_data
     @model_config individual_effect
+    marginal_effect::Vector{Float64}
     folds::Integer
 end

@@ -398,6 +407,7 @@ function DoublyRobustLearner(
         num_feats,
         num_neurons,
         fill(NaN, size(T, 1)),
+        fill(NaN, size(T, 1)),
         2,
     )
 end
@@ -421,7 +431,7 @@ julia> estimate_causal_effect!(m4)
 ```
 """
 @inline function estimate_causal_effect!(s::SLearner)
-    s.causal_effect = g_formula!(s)
+    s.causal_effect, s.marginal_effect = g_formula!(s)
     return s.causal_effect
 end

@@ -458,6 +468,7 @@ julia> estimate_causal_effect!(m5)
     fit!(t.μ₁)
     predictionsₜ, predictionsᵪ = predict(t.μ₁, t.X), predict(t.μ₀, t.X)
     t.causal_effect = @fastmath vec(predictionsₜ - predictionsᵪ)
+    t.marginal_effect = t.causal_effect

     return t.causal_effect
 end
@@ -488,6 +499,8 @@ julia> estimate_causal_effect!(m1)
         (x.ps .* predict(μχ₀, x.X)) .+ ((1 .- x.ps) .* predict(μχ₁, x.X))
     ))

+    x.marginal_effect = x.causal_effect # Works since T is binary
+
     return x.causal_effect
 end

@@ -510,26 +523,34 @@ julia> estimate_causal_effect!(m1)
 """
 @inline function estimate_causal_effect!(R::RLearner)
     X, T̃, Ỹ = generate_folds(R.X, R.T, R.Y, R.folds)
+    T̃₊, Δ = similar(T̃), var_type(R.T) isa Binary ? 1.0 : 1.5e-8mean(R.T)
     R.X, R.T, R.Y = reduce(vcat, X), reduce(vcat, T̃), reduce(vcat, Ỹ)

     # Get residuals from out-of-fold predictions
     for f in 1:(R.folds)
         X_train, X_test = reduce(vcat, X[1:end .!== f]), X[f]
         Y_train, Y_test = reduce(vcat, Ỹ[1:end .!== f]), Ỹ[f]
         T_train, T_test = reduce(vcat, T̃[1:end .!== f]), T̃[f]
-        Ỹ[f], T̃[f] = predict_residuals(R, X_train, X_test, Y_train, Y_test, T_train, T_test)
+        T_train₊ = var_type(R.T) isa Binary ? T_train .* 0 : T_train .+ Δ
+        Ỹ[f], T̃[f], T̃₊[f] = predict_residuals(
+            R, X_train, X_test, Y_train, Y_test, T_train, T_test, T_train₊
+        )
     end

     # Using target transformation and the weight trick to minimize the causal loss
     T̃², target = reduce(vcat, T̃).^2, reduce(vcat, Ỹ) ./ reduce(vcat, T̃)
     Xʷ, Yʷ = R.X .* T̃², target .* T̃²
-
-    # Fit a weighted residual-on-residual model
+    T̃²₊, target₊ = reduce(vcat, T̃₊).^2, reduce(vcat, Ỹ) ./ reduce(vcat, T̃₊)
     final_model = ELMEnsemble(
         Xʷ, Yʷ, R.sample_size, R.num_machines, R.num_feats, R.num_neurons, R.activation
     )
-    fit!(final_model)
+
+    # Using finite differences to calculate marginal effects
+    final_model₊ = deepcopy(final_model)
+    final_model₊.X, final_model₊.Y = R.X .* T̃²₊, target₊ .* T̃²₊
+    fit!(final_model); fit!(final_model₊)
     R.causal_effect = predict(final_model, R.X)
+    R.marginal_effect = (predict(final_model₊, final_model.X) - R.causal_effect) ./ Δ

     return R.causal_effect
 end
@@ -563,6 +584,7 @@ julia> estimate_causal_effect!(m1)

     causal_effect ./= 2
     DRE.causal_effect = causal_effect
+    DRE.marginal_effect = causal_effect

     return DRE.causal_effect
 end
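
For the T-Learner, X-Learner, and doubly robust learner the treatment is assumed binary (as the inline comment notes), so the finite-difference step is Δ = 1.0 and the marginal effect collapses to the individual treatment contrast; that is why `marginal_effect` is simply set equal to `causal_effect`. A toy illustration with a hypothetical outcome model `f`, not package code:

```julia
# With a binary treatment, a finite difference with step Δ = 1.0 is just the contrast
cate(f, x) = f(x, 1.0) - f(x, 0.0)
marginal(f, x; Δ = 1.0) = (f(x, 0.0 + Δ) - f(x, 0.0)) / Δ

f(x, t) = 2.0 + 0.5t + sum(x)    # any outcome model
x = rand(3)
cate(f, x) ≈ marginal(f, x)      # true
```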
