Skip to content

Commit 98fb3f7

Browse files
committed
Implemented marginal effects
1 parent 41030b6 commit 98fb3f7

File tree

4 files changed

+105
-168
lines changed

4 files changed

+105
-168
lines changed

src/estimators.jl

+22-39
Original file line numberDiff line numberDiff line change
@@ -338,25 +338,19 @@ julia> g_formula!(m2)
338338
```
339339
"""
340340
@inline function g_formula!(g) # Keeping this separate for reuse with S-Learning
341-
covariates, y = hcat(g.X, g.T), g.Y
341+
vars, y = hcat(g.X, g.T), g.Y
342342
x₁, x₀ = hcat(g.X, ones(size(g.X, 1))), hcat(g.X, zeros(size(g.X, 1)))
343343

344344
if g.quantity_of_interest ("ITT", "ATE", "CATE")
345-
Xₜ = hcat(covariates[:, 1:(end - 1)], ones(size(covariates, 1)))
346-
Xᵤ = hcat(covariates[:, 1:(end - 1)], zeros(size(covariates, 1)))
345+
Xₜ = hcat(vars[:, 1:(end - 1)], ones(size(vars, 1)))
346+
Xᵤ = hcat(vars[:, 1:(end - 1)], zeros(size(vars, 1)))
347347
else
348-
Xₜ = hcat(covariates[g.T .== 1, 1:(end - 1)], ones(size(g.T[g.T .== 1], 1)))
349-
Xᵤ = hcat(covariates[g.T .== 1, 1:(end - 1)], zeros(size(g.T[g.T .== 1], 1)))
348+
Xₜ = hcat(vars[g.T .== 1, 1:(end - 1)], ones(size(g.T[g.T .== 1], 1)))
349+
Xᵤ = hcat(vars[g.T .== 1, 1:(end - 1)], zeros(size(g.T[g.T .== 1], 1)))
350350
end
351351

352352
g.ensemble = ELMEnsemble(
353-
covariates,
354-
y,
355-
g.sample_size,
356-
g.num_machines,
357-
g.num_feats,
358-
g.num_neurons,
359-
g.activation
353+
vars, y, g.sample_size, g.num_machines, g.num_feats, g.num_neurons, g.activation
360354
)
361355

362356
fit!(g.ensemble)
@@ -383,22 +377,20 @@ julia> estimate_causal_effect!(m2)
383377
"""
384378
@inline function estimate_causal_effect!(DML::DoubleMachineLearning)
385379
X, T, Y = generate_folds(DML.X, DML.T, DML.Y, DML.folds)
386-
DML.causal_effect, DML.marginal_effect = 0, 0
387-
Δ = var_type(DML.T) isa Binary ? 1.0 : 1.5e-8mean(DML.T)
380+
DML.causal_effect, DML.marginal_effect, Δ = 0, 0, 1.5e-8mean(DML.T)
388381

389382
# Cross fitting by training on the main folds and predicting residuals on the auxillary
390383
for fold in 1:(DML.folds)
391384
X_train, X_test = reduce(vcat, X[1:end .!== fold]), X[fold]
392385
Y_train, Y_test = reduce(vcat, Y[1:end .!== fold]), Y[fold]
393386
T_train, T_test = reduce(vcat, T[1:end .!== fold]), T[fold]
394-
T_train₊ = var_type(DML.T) isa Binary ? T_train .* 0 : T_train .+ Δ
395387

396-
Ỹ, T̃, T̃₊ = predict_residuals(
397-
DML, X_train, X_test, Y_train, Y_test, T_train, T_test, T_train₊
388+
Ỹ, T̃, T̃₊, T̃₋ = predict_residuals(
389+
DML, X_train, X_test, Y_train, Y_test, T_train, T_test, Δ
398390
)
399391

400392
DML.causal_effect +=\
401-
DML.marginal_effect += (T̃₊\- DML.causal_effect) / Δ
393+
DML.marginal_effect += (T̃₊\- T̃₋\) / 2Δ
402394
end
403395

404396
DML.causal_effect /= DML.folds
@@ -408,7 +400,7 @@ julia> estimate_causal_effect!(m2)
408400
end
409401

410402
"""
411-
predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test, t_train₊)
403+
predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test, Δ)
412404
413405
Predict treatment, outcome, and marginal effect residuals for double machine learning or
414406
R-learning.
@@ -423,7 +415,7 @@ julia> x_train, x_test = X[1:80, :], X[81:end, :]
423415
julia> y_train, y_test = Y[1:80], Y[81:end]
424416
julia> t_train, t_test = T[1:80], T[81:100]
425417
julia> m1 = DoubleMachineLearning(X, T, Y)
426-
julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test, zeros(100))
418+
julia> predict_residuals(m1, x_tr, x_te, y_tr, y_te, t_tr, t_te, zeros(100), 1e-5)
427419
```
428420
"""
429421
@inline function predict_residuals(
@@ -433,28 +425,19 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test,
433425
yₜᵣ::Vector{Float64},
434426
yₜₑ::Vector{Float64},
435427
tₜᵣ::Vector{Float64},
436-
tₜₑ::Vector{Float64},
437-
tₜᵣ₊::Vector{Float64}
428+
tₜₑ::Vector{Float64},
429+
Δ::Float64
438430
)
439-
y = ELMEnsemble(
440-
xₜᵣ, yₜᵣ, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
441-
)
442-
443-
t = ELMEnsemble(
444-
xₜᵣ, tₜᵣ, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
445-
)
446-
447-
t₊ = ELMEnsemble(
448-
xₜᵣ, tₜᵣ₊, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
449-
)
450-
451-
fit!(y)
452-
fit!(t)
453-
fit!(t₊) # Estimate a model with T + a finite difference
431+
args = D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
432+
y = ELMEnsemble(xₜᵣ, yₜᵣ, args...)
433+
t = ELMEnsemble(xₜᵣ, tₜᵣ, args...)
454434

455-
yₚᵣ, tₚᵣ, tₚᵣ₊ = predict(y, xₜₑ), predict(t, xₜₑ), predict(t₊, xₜₑ)
435+
fit!(y); fit!(t)
436+
yₚᵣ, tₚᵣ = predict(y, xₜₑ), predict(t, xₜₑ)
437+
tₜₑ₊ = var_type(tₜₑ) isa Binary ? ones(size(tₜₑ)) : tₜₑ .+ Δ
438+
tₜₑ₋ = var_type(tₜₑ) isa Binary ? ones(size(tₜₑ)) : tₜₑ .- Δ
456439

457-
return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ, tₜₑ - tₚᵣ
440+
return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ, tₜₑ - tₚᵣ, tₜₑ₋ - tₚᵣ
458441
end
459442

460443
"""

0 commit comments

Comments
 (0)