Commit 0e490c5: Merge pull request #66 from dscolby/development (Development)
2 parents 7baf9a8 + 2df0032

28 files changed (+1060 −1761 lines)

Project.toml (+4 −4)
@@ -1,20 +1,20 @@
 name = "CausalELM"
 uuid = "26abab4e-b12e-45db-9809-c199ca6ddca8"
 authors = ["Darren Colby <[email protected]> and contributors"]
-version = "0.6"
+version = "0.7.0"
 
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [compat]
-LinearAlgebra = "1.7"
-Random = "1.7"
-julia = "1.7"
 Aqua = "0.8"
 DataFrames = "1.5"
 Documenter = "1.2"
+LinearAlgebra = "1.7"
+Random = "1.7"
 Test = "1.7"
+julia = "1.7"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

README.md (+18 −16)
@@ -41,11 +41,11 @@ series analysis, G-computation, and double machine learning; average treatment e
 treated (ATT) with G-computation; cumulative treatment effect with interrupted time series
 analysis; and the conditional average treatment effect (CATE) via S-learning, T-learning,
 X-learning, R-learning, and doubly robust estimation. Underlying all of these estimators are
-extreme learning machines, a simple neural network that uses randomized weights instead of
-using gradient descent. Once a model has been estimated, CausalELM can summarize the model,
-including computing p-values via randomization inference, and conduct sensitivity analysis
-to calidate the plausibility of modeling assumptions. Furthermore, all of this can be done
-in four lines of code.
+ensembles of extreme learning machines, a simple neural network that uses randomized weights
+and least squares optimization instead of gradient descent. Once a model has been estimated,
+CausalELM can summarize the model and conduct sensitivity analysis to validate the
+plausibility of modeling assumptions. Furthermore, all of this can be done in four lines of
+code.
 </p>
 
 <h2>Extreme Learning Machines and Causal Inference</h2>
@@ -73,37 +73,39 @@ to adjust the initial estimates. This approach has three advantages. First, it i
 efficient with high dimensional data than conventional methods. Metalearners take a similar
 approach to estimate the CATE. While all of these models are different, they have one thing
 in common: how well they perform depends on the underlying model they fit to the data. To
-that end, CausalELMs use extreme learning machines because they are simple yet flexible
-enough to be universal function approximators.
+that end, CausalELMs use bagged ensembles of extreme learning machines because they are
+simple yet flexible enough to be universal function approximators with lower variance than
+single extreme learning machines.
 </p>
 
 <h2>CausalELM Features</h2>
 <ul>
 <li>Estimate a causal effect, get a summary, and validate assumptions in just four lines of code</li>
-<li>All models automatically select the best number of neurons and L2 penalty</li>
+<li>Bagging improves performance and reduces variance without the need to tune a regularization parameter</li>
 <li>Enables using the same structs for regression and classification</li>
 <li>Includes 13 activation functions and allows user-defined activation functions</li>
 <li>Most inference and validation tests do not assume functional or distributional forms</li>
 <li>Implements the latest techniques from statistics, econometrics, and biostatistics</li>
-<li>Works out of the box with DataFrames or arrays</li>
+<li>Works out of the box with arrays or any data structure that implements the Tables.jl interface</li>
 <li>Codebase is high-quality, well tested, and regularly updated</li>
 </ul>
 
 <h2>What's New?</h2>
 <ul>
 <li>Now includes doubly robust estimator for CATE estimation</li>
-<li>Uses generalized cross validation with successive halving to find the best ridge penalty</li>
-<li>Double machine learning, R-learning, and doubly robust estimators support specifying confounders and covariates of interest separately</li>
-<li>Counterfactual consistency validation simulates outcomes that violate the assumption rather than the previous binning approach</li>
-<li>Standardized and improved docstrings and added doctests</li>
+<li>All estimators now implement bagging to improve predictive performance and reduce variance</li>
+<li>Counterfactual consistency validation simulates more realistic violations of the counterfactual consistency assumption</li>
+<li>Uses a simple heuristic to choose the number of neurons, which reduces training time and still works well in practice</li>
+<li>Probability clipping for classifier predictions and residuals is no longer necessary due to the bagging procedure</li>
 <li>CausalELM talk has been accepted to JuliaCon 2024!</li>
 </ul>
 
 <h2>What's Next?</h2>
 <p>
-Newer versions of CausalELM will hopefully support using GPUs and provide textual
-interpretations of the results of calling validate on a model that has been estimated.
-However, these priorities could also change depending on feedback received at JuliaCon.
+Newer versions of CausalELM will hopefully support using GPUs and provide interpretations of
+the results of calling validate on a model that has been estimated. In addition, some
+estimators will also support using instrumental variables. However, these priorities could
+also change depending on feedback received at JuliaCon.
 </p>
 
 <h2>Disclaimer</h2>

docs/src/api.md (+6 −21)
@@ -1,7 +1,7 @@
 # CausalELM
-Most of the methods and structs here are private, not exported, should not be called by the
-user, and are documented for the purpose of developing CausalELM or to facilitate
-understanding of the implementation.
+```@docs
+CausalELM.CausalELM
+```
 
 ## Types
 ```@docs
@@ -15,9 +15,8 @@ RLearner
 DoublyRobustLearner
 CausalELM.CausalEstimator
 CausalELM.Metalearner
-CausalELM.ExtremeLearningMachine
 CausalELM.ExtremeLearner
-CausalELM.RegularizedExtremeLearner
+CausalELM.ELMEnsemble
 CausalELM.Nonbinary
 CausalELM.Binary
 CausalELM.Count
@@ -41,28 +40,15 @@ elish
 fourier
 ```
 
-## Cross Validation
-```@docs
-CausalELM.generate_folds
-CausalELM.generate_temporal_folds
-CausalELM.validation_loss
-CausalELM.cross_validate
-CausalELM.best_size
-CausalELM.shuffle_data
-```
-
 ## Average Causal Effect Estimators
 ```@docs
 CausalELM.g_formula!
-CausalELM.causal_loss!
 CausalELM.predict_residuals
-CausalELM.make_folds
 CausalELM.moving_average
 ```
 
 ## Metalearners
 ```@docs
-CausalELM.causal_loss
 CausalELM.doubly_robust_formula!
 CausalELM.stage1!
 CausalELM.stage2!
@@ -94,7 +80,6 @@ CausalELM.e_value
 CausalELM.binarize
 CausalELM.risk_ratio
 CausalELM.positivity
-CausalELM.var_type
 ```
 
 ## Validation Metrics
@@ -114,17 +99,17 @@ CausalELM.fit!
 CausalELM.predict
 CausalELM.predict_counterfactual!
 CausalELM.placebo_test
-CausalELM.ridge_constant
 CausalELM.set_weights_biases
 ```
 
 ## Utility Functions
 ```@docs
+CausalELM.var_type
 CausalELM.mean
 CausalELM.var
 CausalELM.one_hot_encode
 CausalELM.clip_if_binary
 CausalELM.@model_config
 CausalELM.@standard_input_data
-CausalELM.@double_learner_input_data
+CausalELM.generate_folds
 ```
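
With ELMEnsemble replacing the removed RegularizedExtremeLearner in the documented types, the new docstrings can be read from the REPL; a small sketch (the docstring text itself is whatever ships with this commit):

```julia
using CausalELM

# Internal types are not exported, so qualify them with the module name
@doc CausalELM.ELMEnsemble
@doc CausalELM.generate_folds
```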

docs/src/contributing.md (+5 −5)
@@ -27,15 +27,15 @@ code follows the guidelines below.
 
 * Most new structs for estimating causal effects should have mostly the same fields. To
   reduce the burden of repeatedly defining all these fields, it is advisable to use the
-  model_config, standard_input_data, and double_learner_input_data macros to
-  programmatically generate fields for new structs. Doing so will ensure that with little
-  to no effort the new structs will work with the summarize and validate methods.
+  model_config and standard_input_data macros to programmatically generate fields for new
+  structs. Doing so will ensure that with little to no effort the new structs will work
+  with the summarize and validate methods.
 
 * There are no repeated code blocks. If there are repeated code blocks, then they should be
   consolidated into a separate function.
 
-* Methods should generally include types and be type stable. If there is a strong reason
-  to deviate from this point, there should be a comment in the code explaining why.
+* Internal methods can contain types and be parametric but public methods should be as
+  general as possible.
 
 * Minimize use of new constants and macros. If they must be included, the reason for their
   inclusion should be obvious or included in the docstring.
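
A hedged illustration of the revised typing guideline, with hypothetical function names that are not part of CausalELM: internal helpers stay typed and parametric, while the public wrapper accepts anything reasonable.

```julia
# Hypothetical example of the guideline above, not CausalELM code.
# Internal method: typed, parametric, and type stable.
function _weighted_mean(x::Vector{T}, w::Vector{T}) where {T<:AbstractFloat}
    return sum(x .* w) / sum(w)
end

# Public method: as general as possible; converts inputs before dispatching.
weighted_mean(x, w) = _weighted_mean(float.(collect(x)), float.(collect(w)))
```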

docs/src/guide/doublemachinelearning.md (+31 −53)
@@ -4,13 +4,8 @@ estimating causal effects when the dimensionality of the covariates is too high
 regression or the treatment or outcomes cannot be easily modeled parametrically. Double
 machine learning estimates models of the treatment assignment and outcome and then combines
 them in a final model. This is a semiparametric model in the sense that the first stage
-models can take on any functional form but the final stage model is linear.
-
-!!! note
-    If regularized is set to true then the ridge penalty will be estimated using generalized
-    cross validation where the maximum number of iterations is 2 * folds for the successive
-    halving procedure. However, if the penalty in one iteration is approximately the same as
-    the previous penalty, then the procedure will stop early.
+models can take on any functional form but the final stage model is a linear combination of
+the residuals from the first stage models.
 
 !!! note
     For more information see:
@@ -19,70 +14,53 @@ models can take on any functional form but the final stage model is linear.
     Whitney Newey, and James Robins. "Double/debiased machine learning for treatment and
     structural parameters." (2018): C1-C68.
 
-
 ## Step 1: Initialize a Model
-The DoubleMachineLearning constructor takes at least three arguments, an array of
-covariates, a treatment vector, and an outcome vector. This estimator supports binary, count,
-or continuous treatments and binary, count, continuous, or time to event outcomes. You can
-also specify confounders that you do not want to estimate the CATE for by passing a parameter
-to the W argument. Otherwise, the model assumes all possible confounders are contained in X.
+The DoubleMachineLearning constructor takes at least three arguments: covariates, treatment
+statuses, and outcomes, all of which may be either an array or any struct that implements
+the Tables.jl interface (e.g. DataFrames). This estimator supports binary, count, or
+continuous treatments and binary, count, continuous, or time to event outcomes.
 
 !!! note
-    Internally, the outcome and treatment models are treated as a regression since extreme
-    learning machines minimize the MSE. This means that predicted treatments and outcomes
-    under treatment and control groups could fall outside [0, 1], although this is not likely
-    in practice. To deal with this, predicted binary variables are automatically clipped to
-    [0.0000001, 0.9999999]. This also means that count outcomes will be predicted as
-    continuous variables.
+    Non-binary categorical outcomes are treated as continuous.
 
 !!! tip
-    You can also specify the following options: whether the treatment vector is categorical,
-    i.e. not continuous and containing more than two classes, whether to use L2
-    regularization, the activation function, the validation metric to use when searching for
-    the best number of neurons, the minimum and maximum number of neurons to consider, the
-    number of folds to use for cross validation, the number of iterations to perform cross
-    validation, and the number of neurons to use in the ELM used to learn the function from
-    number of neurons to validation loss. These arguments are specified with the following
-    keyword arguments: t\_cat, regularized, activation, validation\_metric, min\_neurons,
-    max\_neurons, folds, iterations, and approximator\_neurons.
+    You can also specify the number of folds to use for cross-fitting, the number of
+    extreme learning machines to incorporate in the ensemble, the number of features to
+    consider for each extreme learning machine, the activation function to use, the number
+    of observations to bootstrap in each extreme learning machine, and the number of neurons
+    in each extreme learning machine. These arguments are specified with the folds,
+    num_machines, num_features, activation, sample_size, and num_neurons keywords.
+
 ```julia
 # Create some data with a binary treatment
 X, T, Y, W = rand(100, 5), [rand()<0.4 for i in 1:100], rand(100), rand(100, 4)
 
-# We could also use DataFrames
+# We could also use DataFrames or any other package implementing the Tables.jl API
 # using DataFrames
 # X = DataFrame(x1=rand(100), x2=rand(100), x3=rand(100), x4=rand(100), x5=rand(100))
 # T, Y = DataFrame(t=[rand()<0.4 for i in 1:100]), DataFrame(y=rand(100))
-# W = DataFrame(w1=rand(100), w2=rand(100), w3=rand(100), w4=rand(100))
-
-# W is optional and means there are confounders that you are not interested in estimating
-# the CATE for
-dml = DoubleMachineLearning(X, T, Y, W=W)
+dml = DoubleMachineLearning(X, T, Y)
 ```
 
 ## Step 2: Estimate the Causal Effect
-To estimate the causal effect, we call estimatecausaleffect! on the model above.
+To estimate the causal effect, we call estimate_causal_effect! on the model above.
 ```julia
 # we could also estimate the ATT by passing quantity_of_interest="ATT"
 estimate_causal_effect!(dml)
 ```
 
 # Get a Summary
-We can get a summary that includes a p-value and standard error estimated via asymptotic
-randomization inference by passing our model to the summarize method.
-
-Calling the summarize method returns a dictionary with the estimator's task (regression or
-classification), the quantity of interest being estimated (ATE), whether the model uses an
-L2 penalty (always true for DML), the activation function used in the model's outcome
-predictors, whether the data is temporal (always false for DML), the validation metric used
-for cross validation to find the best number of neurons, the number of neurons used in the
-ELMs used by the estimator, the number of neurons used in the ELM used to learn a mapping
-from number of neurons to validation loss during cross validation, the causal effect,
-standard error, and p-value.
+We can get a summary of the model by passing the model to the summarize method.
+
+!!! note
+    To calculate the p-value and standard error for the treatment effect, you can set the
+    inference argument to true. However, p-values and standard errors are calculated via
+    randomization inference, which will take a long time. This can be sped up by launching
+    Julia with a higher number of threads.
+
 ```julia
 # Can also use the British spelling
 # summarise(dml)
-
 summarize(dml)
 ```
 
@@ -94,12 +72,12 @@ tests do not provide definitive evidence of a violation of these assumptions. To
 counterfactual consistency assumption, we simulate counterfactual outcomes that are
 different from the observed outcomes, estimate models with the simulated counterfactual
 outcomes, and take the averages. If the outcome is continuous, the noise for the simulated
-counterfactuals is drawn from N(0, dev) for each element in devs, otherwise the default is
-0.25, 0.5, 0.75, and 1.0 standard deviations from the mean outcome. For discrete variables,
-each outcome is replaced with a different value in the range of outcomes with probability ϵ
-for each ϵ in devs, otherwise the default is 0.025, 0.05, 0.075, 0.1. If the average
-estimate for a given level of violation differs greatly from the effect estimated on the
-actual data, then the model is very sensitive to violations of the counterfactual
+counterfactuals is drawn from N(0, dev) for each element in devs and each outcome,
+multiplied by the original outcome, and added to the original outcome. For discrete
+variables, each outcome is replaced with a different value in the range of outcomes with
+probability ϵ for each ϵ in devs, otherwise the default is 0.025, 0.05, 0.075, 0.1. If the
+average estimate for a given level of violation differs greatly from the effect estimated on
+the actual data, then the model is very sensitive to violations of the counterfactual
 consistency assumption for that level of violation. Next, this method tests the model's
 sensitivity to a violation of the exchangeability assumption by calculating the E-value,
 which is the minimum strength of association, on the risk ratio scale, that an unobserved
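
A short sketch of the validation step this guide describes, continuing from the dml model above; the devs keyword is an assumption based on the prose, so check the validate docstring for the exact name and defaults:

```julia
# Check the plausibility of the modeling assumptions for the estimated model.
validate(dml)                               # default levels of violation
validate(dml; devs=[0.25, 0.5, 0.75, 1.0])  # custom violation levels (assumed keyword)
```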

docs/src/guide/estimatorselection.md (+9 −11)
@@ -5,15 +5,13 @@ given dataset and causal question.
 
 | Model                            | Struct                | Causal Estimands                 | Supported Treatment Types         | Supported Outcome Types                          |
 |----------------------------------|-----------------------|----------------------------------|-----------------------------------|--------------------------------------------------|
-| Interrupted Time Series Analysis | InterruptedTimeSeries | ATE, Cumulative Treatment Effect | Binary                            | Continuous, Count[^2], Time to Event             |
-| G-computation                    | GComputation          | ATE, ATT, ITT                    | Binary                            | Binary[^1], Continuous, Time to Event, Count[^2] |
-| Double Machine Learning          | DoubleMachineLearning | ATE                              | Binary[^1], Count[^2], Continuous | Binary[^1], Count[^2], Continuous, Time to Event |
-| S-learning                       | SLearner              | CATE                             | Binary                            | Binary[^1], Continuous, Time to Event, Count[^2] |
-| T-learning                       | TLearner              | CATE                             | Binary                            | Binary[^1], Continuous, Count[^2], Time to Event |
-| X-learning                       | XLearner              | CATE                             | Binary[^1]                        | Binary[^1], Continuous, Count[^2], Time to Event |
-| R-learning                       | RLearner              | CATE                             | Binary[^1], Count[^2], Continuous | Binary[^1], Count[^2], Continuous, Time to Event |
-| Doubly Robust Estimation         | DoublyRobustLearner   | CATE                             | Binary                            | Binary[^1], Continuous, Count[^2], Time to Event |
+| Interrupted Time Series Analysis | InterruptedTimeSeries | ATE, Cumulative Treatment Effect | Binary                            | Continuous, Count[^1], Time to Event             |
+| G-computation                    | GComputation          | ATE, ATT, ITT                    | Binary                            | Binary, Continuous, Time to Event, Count[^1]     |
+| Double Machine Learning          | DoubleMachineLearning | ATE                              | Binary, Count[^1], Continuous     | Binary, Count[^1], Continuous, Time to Event     |
+| S-learning                       | SLearner              | CATE                             | Binary                            | Binary, Continuous, Time to Event, Count[^1]     |
+| T-learning                       | TLearner              | CATE                             | Binary                            | Binary, Continuous, Count[^1], Time to Event     |
+| X-learning                       | XLearner              | CATE                             | Binary                            | Binary, Continuous, Count[^1], Time to Event     |
+| R-learning                       | RLearner              | CATE                             | Binary, Count[^1], Continuous     | Binary, Count[^1], Continuous, Time to Event     |
+| Doubly Robust Estimation         | DoublyRobustLearner   | CATE                             | Binary                            | Binary, Continuous, Count[^1], Time to Event     |
 
-[^1]: Models that use propensity scores or predict binary treatment assignment may, on very rare occasions, return values outside of [0, 1]. In that case, values are clipped to be between 0.0000001 and 0.9999999.
-
-[^2]: Similar to other packages, predictions of count variables is treated as a continuous regression task.
+[^1]: Similar to other packages, predictions of count variables are treated as a continuous regression task.
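
Reading the table: for example, with a binary treatment and a continuous outcome, the ATT calls for GComputation. A sketch, assuming the quantity_of_interest keyword works as shown for double machine learning in the guide above:

```julia
using CausalELM

X, T, Y = rand(500, 4), [rand() < 0.5 for _ in 1:500], rand(500)

# G-computation supports ATE, ATT, and ITT per the table
gcomp = GComputation(X, T, Y; quantity_of_interest="ATT")  # keyword assumed
estimate_causal_effect!(gcomp)
```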
