@@ -48,6 +48,7 @@ mutable struct InterruptedTimeSeries
48
48
Y₀:: Array{Float64}
49
49
X₁:: Array{Float64}
50
50
Y₁:: Array{Float64}
51
+ marginal_effect:: Float64
51
52
@model_config individual_effect
52
53
end
53
54
@@ -77,6 +78,7 @@ function InterruptedTimeSeries(
77
78
float (Y₀),
78
79
X₁,
79
80
float (Y₁),
81
+ NaN ,
80
82
" difference" ,
81
83
true ,
82
84
task,
@@ -137,6 +139,7 @@ julia> m5 = GComputation(x_df, t_df, y_df)
137
139
mutable struct GComputation <: CausalEstimator
138
140
@standard_input_data
139
141
@model_config average_effect
142
+ marginal_effect:: Float64
140
143
ensemble:: ELMEnsemble
141
144
142
145
function GComputation (
@@ -173,6 +176,7 @@ mutable struct GComputation <: CausalEstimator
173
176
num_feats,
174
177
num_neurons,
175
178
NaN ,
179
+ NaN ,
176
180
)
177
181
end
178
182
end
@@ -220,6 +224,7 @@ julia> m2 = DoubleMachineLearning(x_df, t_df, y_df)
220
224
mutable struct DoubleMachineLearning <: CausalEstimator
221
225
@standard_input_data
222
226
@model_config average_effect
227
+ marginal_effect:: Float64
223
228
folds:: Integer
224
229
end
225
230
@@ -256,6 +261,7 @@ function DoubleMachineLearning(
256
261
num_feats,
257
262
num_neurons,
258
263
NaN ,
264
+ NaN ,
259
265
folds,
260
266
)
261
267
end
@@ -285,6 +291,7 @@ julia> estimate_causal_effect!(m1)
285
291
286
292
fit! (learner)
287
293
its. causal_effect = predict (learner, its. X₁) .- its. Y₁
294
+ its. marginal_effect = mean (its. causal_effect)
288
295
289
296
return its. causal_effect
290
297
end
@@ -309,7 +316,9 @@ julia> estimate_causal_effect!(m1)
309
316
```
310
317
"""
311
318
@inline function estimate_causal_effect! (g:: GComputation )
312
- g. causal_effect = mean (g_formula! (g))
319
+ causal_effect, marginal_effect = g_formula! (g)
320
+ g. causal_effect, g. marginal_effect = mean (causal_effect), mean (marginal_effect)
321
+
313
322
return g. causal_effect
314
323
end
315
324
@@ -330,6 +339,7 @@ julia> g_formula!(m2)
330
339
"""
331
340
@inline function g_formula! (g) # Keeping this separate for reuse with S-Learning
332
341
covariates, y = hcat (g. X, g. T), g. Y
342
+ x₁, x₀ = hcat (g. X, ones (size (g. X, 1 ))), hcat (g. X, zeros (size (g. X, 1 )))
333
343
334
344
if g. quantity_of_interest ∈ (" ITT" , " ATE" , " CATE" )
335
345
Xₜ = hcat (covariates[:, 1 : (end - 1 )], ones (size (covariates, 1 )))
@@ -350,10 +360,9 @@ julia> g_formula!(m2)
350
360
)
351
361
352
362
fit! (g. ensemble)
353
-
354
363
yₜ, yᵤ = predict (g. ensemble, Xₜ), predict (g. ensemble, Xᵤ)
355
364
356
- return vec (yₜ) - vec (yᵤ)
365
+ return vec (yₜ) - vec (yᵤ), predict (g . ensemble, x₁) - predict (g . ensemble, x₀)
357
366
end
358
367
359
368
"""
@@ -374,27 +383,35 @@ julia> estimate_causal_effect!(m2)
374
383
"""
375
384
@inline function estimate_causal_effect! (DML:: DoubleMachineLearning )
376
385
X, T, Y = generate_folds (DML. X, DML. T, DML. Y, DML. folds)
377
- DML. causal_effect = 0
386
+ DML. causal_effect, DML. marginal_effect = 0 , 0
387
+ Δ = var_type (DML. T) isa Binary ? 1.0 : 1.5e-8 mean (DML. T)
378
388
379
389
# Cross fitting by training on the main folds and predicting residuals on the auxillary
380
- for fld in 1 : (DML. folds)
381
- X_train, X_test = reduce (vcat, X[1 : end .!= = fld]), X[fld]
382
- Y_train, Y_test = reduce (vcat, Y[1 : end .!= = fld]), Y[fld]
383
- T_train, T_test = reduce (vcat, T[1 : end .!= = fld]), T[fld]
384
-
385
- Ỹ, T̃ = predict_residuals (DML, X_train, X_test, Y_train, Y_test, T_train, T_test)
390
+ for fold in 1 : (DML. folds)
391
+ X_train, X_test = reduce (vcat, X[1 : end .!= = fold]), X[fold]
392
+ Y_train, Y_test = reduce (vcat, Y[1 : end .!= = fold]), Y[fold]
393
+ T_train, T_test = reduce (vcat, T[1 : end .!= = fold]), T[fold]
394
+ T_train₊ = var_type (DML. T) isa Binary ? T_train .* 0 : T_train .+ Δ
395
+
396
+ Ỹ, T̃, T̃₊ = predict_residuals (
397
+ DML, X_train, X_test, Y_train, Y_test, T_train, T_test, T_train₊
398
+ )
386
399
387
400
DML. causal_effect += T̃\ Ỹ
401
+ DML. marginal_effect += (T̃₊\ Ỹ - DML. causal_effect) / Δ
388
402
end
403
+
389
404
DML. causal_effect /= DML. folds
405
+ DML. marginal_effect /= DML. folds
390
406
391
407
return DML. causal_effect
392
408
end
393
409
394
410
"""
395
- predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test)
411
+ predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test, t_train₊ )
396
412
397
- Predict treatment and outcome residuals for double machine learning or R-learning.
413
+ Predict treatment, outcome, and marginal effect residuals for double machine learning or
414
+ R-learning.
398
415
399
416
# Notes
400
417
This method should not be called directly.
@@ -406,7 +423,7 @@ julia> x_train, x_test = X[1:80, :], X[81:end, :]
406
423
julia> y_train, y_test = Y[1:80], Y[81:end]
407
424
julia> t_train, t_test = T[1:80], T[81:100]
408
425
julia> m1 = DoubleMachineLearning(X, T, Y)
409
- julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
426
+ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test, zeros(100) )
410
427
```
411
428
"""
412
429
@inline function predict_residuals (
@@ -417,6 +434,7 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
417
434
yₜₑ:: Vector{Float64} ,
418
435
tₜᵣ:: Vector{Float64} ,
419
436
tₜₑ:: Vector{Float64} ,
437
+ tₜᵣ₊:: Vector{Float64}
420
438
)
421
439
y = ELMEnsemble (
422
440
xₜᵣ, yₜᵣ, D. sample_size, D. num_machines, D. num_feats, D. num_neurons, D. activation
@@ -426,12 +444,17 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
426
444
xₜᵣ, tₜᵣ, D. sample_size, D. num_machines, D. num_feats, D. num_neurons, D. activation
427
445
)
428
446
447
+ t₊ = ELMEnsemble (
448
+ xₜᵣ, tₜᵣ₊, D. sample_size, D. num_machines, D. num_feats, D. num_neurons, D. activation
449
+ )
450
+
429
451
fit! (y)
430
452
fit! (t)
453
+ fit! (t₊) # Estimate a model with T + a finite difference
431
454
432
- yₚᵣ, tₚᵣ = predict (y, xₜₑ), predict (t, xₜₑ)
455
+ yₚᵣ, tₚᵣ, tₚᵣ₊ = predict (y, xₜₑ), predict (t, xₜₑ), predict (t₊ , xₜₑ)
433
456
434
- return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ
457
+ return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ, tₜₑ - tₚᵣ₊
435
458
end
436
459
437
460
"""
0 commit comments