
Commit bef9d63
Merge pull request #82 from dscolby/development
v0.8.0
2 parents: 1897aba + 550fded

22 files changed: +662 −498 lines

.github/workflows/CI.yml (+2)

```diff
@@ -3,6 +3,7 @@ on:
   push:
     branches:
       - main
+      - development
     tags: '*'
   pull_request:
     branches:
@@ -25,6 +26,7 @@ jobs:
         - '1.8'
         - '1.9'
         - '1.10'
+        - '1.11'
         - 'nightly'
       os:
         - ubuntu-latest
```

.gitignore (+1)

```diff
@@ -1 +1,2 @@
 /docs/build
+/.vscode
```

Manifest.toml (+43 −10)

```diff
@@ -1,41 +1,74 @@
 # This file is machine-generated - editing it directly is not advised
 
-julia_version = "1.8.5"
+julia_version = "1.11.1"
 manifest_format = "2.0"
-project_hash = "18a38d2a3c0a24ffa847859ade56a5a957640011"
+project_hash = "48b0ecc3de09367019241b9866f1be8d1ab8f4cc"
 
 [[deps.Artifacts]]
 uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+version = "1.11.0"
 
 [[deps.CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
 uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
-version = "1.0.1+0"
+version = "1.1.1+0"
+
+[[deps.DataAPI]]
+git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
+uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
+version = "1.16.0"
+
+[[deps.DataValueInterfaces]]
+git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
+uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
+version = "1.0.0"
+
+[[deps.IteratorInterfaceExtensions]]
+git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
+uuid = "82899510-4779-5014-852e-03e436cf321d"
+version = "1.0.0"
 
 [[deps.Libdl]]
 uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+version = "1.11.0"
 
 [[deps.LinearAlgebra]]
-deps = ["Libdl", "libblastrampoline_jll"]
+deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
 uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+version = "1.11.0"
 
 [[deps.OpenBLAS_jll]]
 deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
 uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
-version = "0.3.20+0"
+version = "0.3.27+1"
+
+[[deps.OrderedCollections]]
+git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad"
+uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+version = "1.7.0"
 
 [[deps.Random]]
-deps = ["SHA", "Serialization"]
+deps = ["SHA"]
 uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+version = "1.11.0"
 
 [[deps.SHA]]
 uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 version = "0.7.0"
 
-[[deps.Serialization]]
-uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+[[deps.TableTraits]]
+deps = ["IteratorInterfaceExtensions"]
+git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
+uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
+version = "1.0.1"
+
+[[deps.Tables]]
+deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"]
+git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297"
+uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+version = "1.12.0"
 
 [[deps.libblastrampoline_jll]]
-deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
+deps = ["Artifacts", "Libdl"]
 uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
-version = "5.1.1+0"
+version = "5.11.0+0"
```

Project.toml (+3 −1)

```diff
@@ -1,18 +1,20 @@
 name = "CausalELM"
 uuid = "26abab4e-b12e-45db-9809-c199ca6ddca8"
 authors = ["Darren Colby <[email protected]> and contributors"]
-version = "0.7.0"
+version = "0.8.0"
 
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
 Aqua = "0.8"
 DataFrames = "1.5"
 Documenter = "1.2"
 LinearAlgebra = "1.7"
 Random = "1.7"
+Tables = "1.12.0"
 Test = "1.7"
 julia = "1.7"
```

README.md (+39 −48)

```diff
@@ -34,51 +34,39 @@
 </p>
 
 <p>
-CausalELM enables estimation of causal effects in settings where a randomized control trial
-or traditional statistical models would be infeasible or unacceptable. It enables estimation
-of the average treatment effect (ATE)/intent to treat effect (ITE) with interrupted time
-series analysis, G-computation, and double machine learning; average treatment effect on the
-treated (ATT) with G-computation; cumulative treatment effect with interrupted time series
-analysis; and the conditional average treatment effect (CATE) via S-learning, T-learning,
-X-learning, R-learning, and doubly robust estimation. Underlying all of these estimators are
-ensembles of extreme learning machines, a simple neural network that uses randomized weights
-and least squares optimization instead of gradient descent. Once a model has been estimated,
-CausalELM can summarize the model and conduct sensitivity analysis to validate the
-plausibility of modeling assumptions. Furthermore, all of this can be done in four lines of
-code.
+CausalELM provides easy-to-use implementations of modern causal inference methods. While
+CausalELM implements a variety of estimators, they all have one thing in common: the use of
+machine learning models to flexibly estimate causal effects. This is where the ELM in
+CausalELM comes from: the machine learning model underlying all the estimators is an
+extreme learning machine (ELM). ELMs are simple neural networks that use randomized weights
+and offer a good tradeoff between learning non-linear dependencies and simplicity.
+Furthermore, CausalELM implements bagged ensembles of ELMs to reduce the variance resulting
+from randomized weights.
 </p>
 
-<h2>Extreme Learning Machines and Causal Inference</h2>
+<h2>Estimators</h2>
 <p>
-In some cases we would like to know the causal effect of some intervention but we do not
-have the counterfactual, making conventional methods of statistical analysis infeasible.
-However, it may still be possible to get an unbiased estimate of the causal effect (ATE,
-ATE, or ITT) by predicting the counterfactual and comparing it to the observed outcomes.
-This is the approach CausalELM takes to conduct interrupted time series analysis,
-G-Computation, double machine learning, and metalearning via S-Learners, T-Learners,
-X-Learners, R-learners, and doubly robust estimation. In interrupted time series analysis,
-we want to estimate the effect of some intervention on the outcome of a single unit that we
-observe during multiple time periods. For example, we might want to know how the
-announcement of a merger affected the price of Stock A. To do this, we need to know what the
-price of stock A would have been if the merger had not been announced, which we can predict
-with machine learning methods. Then, we can compare this predicted counterfactual to the
-observed price data to estimate the effect of the merger announcement. In another case, we
-might want to know the effect of medicine X on disease Y but the administration of X was not
-random and it might have also been administered at mulitiple time periods, which would
-produce biased estimates. To overcome this, G-computation models the observed data, uses the
-model to predict the outcomes if all patients recieved the treatment, and compares it to the
-predictions of the outcomes if none of the patients recieved the treatment. Double machine
-learning (DML) takes a similar approach but also models the treatment mechanism and uses it
-to adjust the initial estimates. This approach has three advantages. First, it is more
-efficient with high dimensional data than conventional methods. Metalearners take a similar
-approach to estimate the CATE. While all of these models are different, they have one thing
-in common: how well they perform depends on the underlying model they fit to the data. To
-that end, CausalELMs use bagged ensembles of extreme learning machines because they are
-simple yet flexible enough to be universal function approximators with lower varaince than
-single extreme learning machines.
+CausalELM implements estimators for aggregate (e.g. the average treatment effect (ATE)) and
+individualized (e.g. the conditional average treatment effect (CATE)) quantities of interest.
 </p>
 
-<h2>CausalELM Features</h2>
+<h3>Estimators for Aggregate Effects</h3>
+<ul>
+<li>Interrupted Time Series Estimator</li>
+<li>G-computation</li>
+<li>Double Machine Learning</li>
+</ul>
+
+<h3>Individualized Treatment Effect (CATE) Estimators</h3>
+<ul>
+<li>S-learner</li>
+<li>T-learner</li>
+<li>X-learner</li>
+<li>R-learner</li>
+<li>Doubly Robust Estimator</li>
+</ul>
+
+<h2>Features</h2>
 <ul>
 <li>Estimate a causal effect, get a summary, and validate assumptions in just four lines of code</li>
 <li>Bagging improves performance and reduces variance without the need to tune a regularization parameter</li>
@@ -87,25 +75,28 @@ single extreme learning machines.
 <li>Most inference and validation tests do not assume functional or distributional forms</li>
 <li>Implements the latest techniques from statistics, econometrics, and biostatistics</li>
 <li>Works out of the box with arrays or any data structure that implements the Tables.jl interface</li>
+<li>Works out of the box with AbstractArrays or any data structure that implements the Tables.jl interface</li>
+<li>Works with CuArrays, ROCArrays, and any other GPU-specific arrays that are AbstractArrays</li>
+<li>CausalELM is lightweight: its only dependency is Tables.jl</li>
 <li>Codebase is high-quality, well tested, and regularly updated</li>
 </ul>
 
 <h2>What's New?</h2>
 <ul>
-<li>Now includes doubly robust estimator for CATE estimation</li>
-<li>All estimators now implement bagging to reduce predictive performance and reduce variance</li>
-<li>Counterfactual consistency validation simulates more realistic violations of the counterfactual consistency assumption</li>
+<li>See the JuliaCon 2024 CausalELM demonstration <a href="https://www.youtube.com/watch?v=hh_cyj8feu8&t=26s">here</a></li>
+<li>Includes support for GPU-specific arrays and data structures that implement the Tables.jl API</li>
+<li>Only performs randomization inference when the inference argument is set to true in summarize methods</li>
+<li>Summaries support calculating marginal effects and confidence intervals</li>
+<li>Randomization inference now uses multithreading</li>
+<li>Refactored code to be easier to extend and understand</li>
 <li>Uses a simple heuristic to choose the number of neurons, which reduces training time and still works well in practice</li>
 <li>Probability clipping for classifier predictions and residuals is no longer necessary due to the bagging procedure</li>
-<li>CausalELM talk has been accepted to JuliaCon 2024!</li>
 </ul>
 
 <h2>What's Next?</h2>
 <p>
-Newer versions of CausalELM will hopefully support using GPUs and provide interpretations of
-the results of calling validate on a model that has been estimated. In addition, some
-estimators will also support using instrumental variables. However, these priorities could
-also change depending on feedback recieved at JuliaCon.
+Efforts for the next version of CausalELM will focus on providing interpretations for the
+results of calling validate, as well as fixing any bugs and eliciting feedback.
 </p>
 
 <h2>Disclaimer</h2>
```
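The README's "four lines of code" claim can be sketched directly from the guide snippets elsewhere in this commit. A minimal sketch, using the GComputation workflow shown in docs/src/guide/gcomputation.md; the random data is purely illustrative:

```julia
# Minimal sketch of the four-line workflow: estimate, summarize, validate.
# Data is random and purely illustrative.
using CausalELM

# Covariates, binary treatment statuses, and outcomes
X, T, Y = rand(1000, 5), [rand() < 0.4 for i in 1:1000], rand(1000)

g_computer = GComputation(X, T, Y)   # initialize a model
estimate_causal_effect!(g_computer)  # estimate the causal effect
summarize(g_computer)                # get a summary of the model
validate(g_computer)                 # check modeling assumptions (per the README)
```

Any Tables.jl-compatible structure (e.g. a DataFrame) could stand in for the covariate matrix, per the feature list above.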

docs/src/api.md (+1)

````diff
@@ -112,4 +112,5 @@ CausalELM.clip_if_binary
 CausalELM.@model_config
 CausalELM.@standard_input_data
 CausalELM.generate_folds
+CausalELM.convert_if_table
 ```
````

docs/src/guide/doublemachinelearning.md (+10 −9)

````diff
@@ -16,9 +16,9 @@ the residuals from the first stage models.
 
 ## Step 1: Initialize a Model
 The DoubleMachineLearning constructor takes at least three arguments—covariates, a
-treatment statuses, and outcomes, all of which may be either an array or any struct that
-implements the Tables.jl interface (e.g. DataFrames). This estimator supports binary, count,
-or continuous treatments and binary, count, continuous, or time to event outcomes.
+vector of treatment statuses, and outcomes, all of which may be an AbstractArray or any
+struct that implements the Tables.jl interface (e.g. DataFrames). This estimator supports
+binary, count, or continuous treatments and binary, count, continuous, or time to event outcomes.
 
 !!! note
     Non-binary categorical outcomes are treated as continuous.
@@ -28,8 +28,8 @@ or continuous treatments and binary, count, continuous, or time to event outcome
 extreme learning machines to incorporate in the ensemble, the number of features to
 consider for each extreme learning machine, the activation function to use, the number
 of observations to bootstrap in each extreme learning machine, and the number of neurons
-in each extreme learning machine. These arguments are specified with the `folds`,
-`num_machines`, `num_features`, `activation`, `sample_size`, and `num_neurons` keywords.
+in each extreme learning machine. These arguments are specified with the folds,
+num\_machines, num\_features, activation, sample\_size, and num\_neurons keywords.
 
 ```julia
 # Create some data with a binary treatment
@@ -53,10 +53,11 @@ estimate_causal_effect!(dml)
 We can get a summary of the model by passing the model to the summarize method.
 
 !!! note
-    To calculate the p-value and standard error for the treatment effect, you can set the
-    inference argument to false. However, p-values and standard errors are calculated via
-    randomization inference, which will take a long time. But can be sped up by launching
-    Julia with a higher number of threads.
+    To calculate the p-value, standard error, and confidence interval for the treatment
+    effect, you can set the inference keyword to true. However, these values are calculated
+    via randomization inference, which will take a long time. This can be greatly sped up by
+    launching Julia with more threads and lowering the number of iterations using the n
+    keyword (at the expense of accuracy).
 
 ```julia
 # Can also use the British spelling
````
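The double machine learning workflow this guide describes can be sketched end to end. The keyword values below are illustrative, and passing inference and n to summarize is an assumption based on the keywords the note above names, not a confirmed signature:

```julia
# Sketch of the DoubleMachineLearning workflow; keyword values are
# illustrative, not recommendations.
using CausalELM

X, T, Y = rand(1000, 5), [rand() < 0.4 for i in 1:1000], rand(1000)

# num_machines and num_neurons are among the keywords the guide lists
dml = DoubleMachineLearning(X, T, Y; num_machines=50, num_neurons=10)
estimate_causal_effect!(dml)

# Randomization inference is opt-in; lowering n (assumed keyword per the
# note above) trades accuracy for speed
summarize(dml; inference=true, n=100)
```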

docs/src/guide/estimatorselection.md (+1 −1)

```diff
@@ -6,7 +6,7 @@ given dataset and causal question.
 | Model | Struct | Causal Estimands | Supported Treatment Types | Supported Outcome Types |
 |----------------------------------|-----------------------|----------------------------------|---------------------------|------------------------------------------|
 | Interrupted Time Series Analysis | InterruptedTimeSeries | ATE, Cumulative Treatment Effect | Binary | Continuous, Count[^1], Time to Event |
-| G-computation | GComputation | ATE, ATT, ITT | Binary | Binary,Continuous, Time to Event, Count[^1] |
+| G-computation | GComputation | ATE, ATT, ITT | Binary | Binary, Continuous, Time to Event, Count[^1] |
 | Double Machine Learning | DoubleMachineLearning | ATE | Binary, Count[^1], Continuous | Binary, Count[^1], Continuous, Time to Event |
 | S-learning | SLearner | CATE | Binary | Binary, Continuous, Time to Event, Count[^1] |
 | T-learning | TLearner | CATE | Binary | Binary, Continuous, Count[^1], Time to Event |
```

docs/src/guide/gcomputation.md (+10 −10)

````diff
@@ -16,9 +16,9 @@ steps for using G-computation in CausalELM are below.
 
 ## Step 1: Initialize a Model
 The GComputation constructor takes at least three arguments: covariates, treatment statuses,
-outcomes, all of which can be either an array or any data structure that implements the
-Tables.jl interface (e.g. DataFrames). This implementation supports binary treatments and
-binary, continuous, time to event, and count outcome variables.
+and outcomes, all of which can be either an AbstractArray or any data structure that
+implements the Tables.jl interface (e.g. DataFrames). This implementation supports binary
+treatments and binary, continuous, time to event, and count outcome variables.
 
 !!! note
     Non-binary categorical outcomes are treated as continuous.
@@ -29,9 +29,8 @@ binary, continuous, time to event, and count outcome variables.
 number of features to consider for each extreme learning machine, the number of
 bootstrapped observations to include in each extreme learning machine, and the number of
 neurons to use during estimation. These options are specified with the following keyword
-arguments: `quantity_of_interest`, `activation`, `temporal`, `num_machines`, `num_feats`,
-`sample_size`, and `num_neurons`.
-
+arguments: quantity\_of\_interest, activation, temporal, num\_machines, num\_feats,
+sample\_size, and num\_neurons.
 ```julia
 # Create some data with a binary treatment
 X, T, Y = rand(1000, 5), [rand()<0.4 for i in 1:1000], rand(1000)
@@ -54,10 +53,11 @@ estimate_causal_effect!(g_computer)
 We can get a summary of the model by passing the model to the summarize method.
 
 !!! note
-    To calculate the p-value and standard error for the treatment effect, you can set the
-    inference argument to false. However, p-values and standard errors are calculated via
-    randomization inference, which will take a long time. But can be sped up by launching
-    Julia with a higher number of threads.
+    To calculate the p-value, standard error, and confidence interval for the treatment
+    effect, you can set the inference keyword to true. However, these values are calculated
+    via randomization inference, which will take a long time. This can be greatly sped up by
+    launching Julia with more threads and lowering the number of iterations using the n
+    keyword (at the expense of accuracy).
 
 ```julia
 summarize(g_computer)
````

docs/src/guide/its.md (+8 −7)

````diff
@@ -31,7 +31,7 @@ Estimating an interrupted time series design in CausalELM consists of three step
 ## Step 1: Initialize an interrupted time series estimator
 The InterruptedTimeSeries constructor takes at least four arguments: pre-event covariates,
 pre-event outcomes, post-event covariates, and post-event outcomes, all of which can be
-either an array or any data structure that implements the Tables.jl interface (e.g.
+either an AbstractArray or any data structure that implements the Tables.jl interface (e.g.
 DataFrames). The interrupted time series estimator assumes outcomes are either continuous,
 count, or time to event variables.
 
@@ -43,8 +43,8 @@ count, or time to event variables.
 machines to use, the number of features to consider for each extreme learning machine,
 the number of bootstrapped observations to include in each extreme learning machine, and
 the number of neurons to use during estimation. These options are specified with the
-following keyword arguments: `activation`, `num_machines`, `num_feats`, `sample_size`,
-and `num_neurons`.
+following keyword arguments: activation, num\_machines, num\_feats, sample\_size, and
+num\_neurons.
 
 ```julia
 # Generate some data to use
@@ -69,10 +69,11 @@ estimate_causal_effect!(its)
 We can get a summary of the model by passing the model to the summarize method.
 
 !!! note
-    To calculate the p-value and standard error for the treatment effect, you can set the
-    inference argument to false. However, p-values and standard errors are calculated via
-    randomization inference, which will take a long time. But can be sped up by launching
-    Julia with a higher number of threads.
+    To calculate the p-value, standard error, and confidence interval for the treatment
+    effect, you can set the inference keyword to true. However, these values are calculated
+    via randomization inference, which will take a long time. This can be greatly sped up by
+    launching Julia with more threads and lowering the number of iterations using the n
+    keyword (at the expense of accuracy).
 
 ```julia
 summarize(its)
````
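The interrupted time series workflow above can be sketched with the four positional arguments the guide names; the variable names and random data are illustrative only:

```julia
# Sketch of the InterruptedTimeSeries workflow: pre-event covariates and
# outcomes, then post-event covariates and outcomes. Data is random and
# purely illustrative.
using CausalELM

X0, Y0 = rand(100, 5), rand(100)  # pre-event covariates and outcomes
X1, Y1 = rand(100, 5), rand(100)  # post-event covariates and outcomes

its = InterruptedTimeSeries(X0, Y0, X1, Y1)
estimate_causal_effect!(its)
summarize(its)
```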
