Skip to content

Commit 1169985

Browse files
committed
Implemented Tables.jl API and support for GPUs
1 parent d5cabb6 commit 1169985

13 files changed

+131
-64
lines changed

Manifest.toml

+33-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.11.1"
44
manifest_format = "2.0"
5-
project_hash = "ef638d9b7dd3411a6b5c86406ac77e48f19d8d42"
5+
project_hash = "48b0ecc3de09367019241b9866f1be8d1ab8f4cc"
66

77
[[deps.Artifacts]]
88
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -13,6 +13,21 @@ deps = ["Artifacts", "Libdl"]
1313
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
1414
version = "1.1.1+0"
1515

16+
[[deps.DataAPI]]
17+
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
18+
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
19+
version = "1.16.0"
20+
21+
[[deps.DataValueInterfaces]]
22+
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
23+
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
24+
version = "1.0.0"
25+
26+
[[deps.IteratorInterfaceExtensions]]
27+
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
28+
uuid = "82899510-4779-5014-852e-03e436cf321d"
29+
version = "1.0.0"
30+
1631
[[deps.Libdl]]
1732
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1833
version = "1.11.0"
@@ -27,6 +42,11 @@ deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
2742
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
2843
version = "0.3.27+1"
2944

45+
[[deps.OrderedCollections]]
46+
git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad"
47+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
48+
version = "1.7.0"
49+
3050
[[deps.Random]]
3151
deps = ["SHA"]
3252
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -36,6 +56,18 @@ version = "1.11.0"
3656
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
3757
version = "0.7.0"
3858

59+
[[deps.TableTraits]]
60+
deps = ["IteratorInterfaceExtensions"]
61+
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
62+
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
63+
version = "1.0.1"
64+
65+
[[deps.Tables]]
66+
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"]
67+
git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297"
68+
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
69+
version = "1.12.0"
70+
3971
[[deps.libblastrampoline_jll]]
4072
deps = ["Artifacts", "Libdl"]
4173
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ version = "0.8.0"
66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
910

1011
[compat]
1112
Aqua = "0.8"
1213
DataFrames = "1.5"
1314
Documenter = "1.2"
1415
LinearAlgebra = "1.8"
1516
Random = "1.8"
17+
Tables = "1.12.0"
1618
Test = "1.8"
1719
julia = "1.8"
1820

docs/src/api.md

+1
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,5 @@ CausalELM.clip_if_binary
112112
CausalELM.@model_config
113113
CausalELM.@standard_input_data
114114
CausalELM.generate_folds
115+
CausalELM.convert_if_table
115116
```

docs/src/guide/doublemachinelearning.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ the residuals from the first stage models.
1616

1717
## Step 1: Initialize a Model
1818
The DoubleMachineLearning constructor takes at least three arguments—covariates, a
19-
treatment statuses, and outcomes, all of which may be either an array or any struct that
20-
implements the Tables.jl interface (e.g. DataFrames). This estimator supports binary, count,
21-
or continuous treatments and binary, count, continuous, or time to event outcomes.
19+
treatment statuses, and outcomes, all of which may be either an AbstractArray or any struct
20+
that implements the Tables.jl interface (e.g. DataFrames). This estimator supports binary,
21+
count, or continuous treatments and binary, count, continuous, or time to event outcomes.
2222

2323
!!! note
2424
Non-binary categorical outcomes are treated as continuous.

docs/src/guide/gcomputation.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ steps for using G-computation in CausalELM are below.
1616

1717
## Step 1: Initialize a Model
1818
The GComputation constructor takes at least three arguments: covariates, treatment statuses,
19-
outcomes, all of which can be either an array or any data structure that implements the
20-
Tables.jl interface (e.g. DataFrames). This implementation supports binary treatments and
21-
binary, continuous, time to event, and count outcome variables.
19+
outcomes, all of which can be either an AbstractArray or any data structure that implements
20+
the Tables.jl interface (e.g. DataFrames). This implementation supports binary treatments
21+
and binary, continuous, time to event, and count outcome variables.
2222

2323
!!! note
2424
Non-binary categorical outcomes are treated as continuous.

docs/src/guide/its.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Estimating an interrupted time series design in CausalELM consists of three step
3131
## Step 1: Initialize an interrupted time series estimator
3232
The InterruptedTimeSeries constructor takes at least four agruments: pre-event covariates,
3333
pre-event outcomes, post-event covariates, and post-event outcomes, all of which can be
34-
either an array or any data structure that implements the Tables.jl interface (e.g.
34+
either an AbstractArray or any data structure that implements the Tables.jl interface (e.g.
3535
DataFrames). The interrupted time series estimator assumes outcomes are either continuous,
3636
count, or time to event variables.
3737

docs/src/guide/metalearners.md

+8-7
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ continuous outcomes.
3030
Kennedy, Edward H. "Towards optimal doubly robust estimation of heterogeneous causal
3131
effects." Electronic Journal of Statistics 17, no. 2 (2023): 3008-3049.
3232

33-
# Initialize a Metalearner
33+
# Step 1: Initialize a Metalearner
3434
S-learners, T-learners, X-learners, R-learners, and doubly robust estimators all take at
3535
least three arguments—covariates, treatment statuses, and outcomes, all of which can be
36-
either an array or any struct that implements the Tables.jl interface (e.g. DataFrames). S,
37-
T, X, and doubly robust learners support binary treatment variables and binary, continuous,
38-
count, or time to event outcomes. The R-learning estimator supports binary, continuous, or
39-
count treatment variables and binary, continuous, count, or time to event outcomes.
36+
either an AbstractArray or any struct that implements the Tables.jl interface (e.g.
37+
DataFrames). S, T, X, and doubly robust learners support binary treatment variables and
38+
binary, continuous, count, or time to event outcomes. The R-learning estimator supports
39+
binary, continuous, or count treatment variables and binary, continuous, count, or time to
40+
event outcomes.
4041

4142
!!! note
4243
Non-binary categorical outcomes are treated as continuous.
@@ -64,7 +65,7 @@ r_learner = RLearner(X, Y, T)
6465
dr_learner = DoublyRobustLearner(X, T, Y)
6566
```
6667

67-
# Estimate the CATE
68+
# Step 2: Estimate the CATE
6869
We can estimate the CATE for all the models by passing them to estimate_causal_effect!.
6970
```julia
7071
estimate_causal_effect!(s_learner)
@@ -74,7 +75,7 @@ estimate_causal_effect!(r_learner)
7475
estimate_causal_effect!(dr_lwarner)
7576
```
7677

77-
# Get a Summary
78+
# Step 3: Get a Summary
7879
We can get a summary of the model by pasing the model to the summarize method.
7980

8081
!!!note

docs/src/index.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ these libraries are:
7575
econometrics, and biostatistics.
7676

7777
### Installation
78-
CausalELM requires Julia version 1.7 or greater and can be installed from the REPL as shown
78+
CausalELM requires Julia version 1.8 or greater and can be installed from the REPL as shown
7979
below.
8080
```julia
8181
using Pkg

docs/src/release_notes.md

+2
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/
55
### Added
66
* Implemented randomization inference-based confidence intervals [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
77
* Added marginal effects to model summaries [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
8+
* CausalELM models now support any AbstractArray data type, including support for using GPUs with CuArrays or similar structures for Mac, Intel, and AMD hardware[#37](https://github.com/dscolby/CausalELM.jl/issues/37)
89
### Fixed
910
* Removed unnecessary include and using statements
1011
* Slightly sped up the randomization inference implementation and clarified it in the docs [#77](https://github.com/dscolby/CausalELM.jl/issues/77)
1112
* Fixed the randomization inference index selection procedure for interrupted time series estimators
1213
* Inlined certain methods to slightly improve performance [#76](https://github.com/dscolby/CausalELM.jl/issues/76)
14+
* CausalELM models now support any data structure that implements the Tables.jl API, not just DataFrames
1315

1416
## Version [v0.7.0](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.7.0) - 2024-06-22
1517
### Added

src/estimators.jl

+22-17
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@ abstract type CausalEstimator end
77
Initialize an interrupted time series estimator.
88
99
# Arguments
10-
- `X₀::Any`: array or DataFrame of covariates from the pre-treatment period.
11-
- `Y₁::Any`: array or DataFrame of outcomes from the pre-treatment period.
12-
- `X₁::Any`: array or DataFrame of covariates from the post-treatment period.
13-
- `Y₁::Any`: array or DataFrame of outcomes from the post-treatment period.
10+
- `X₀::Any`: AbstractArray or Tables.jl API compliant data structure of covariates from the
11+
pre-treatment period.
12+
- `Y₁::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes from the
13+
pre-treatment period.
14+
- `X₁::Any`: AbstractArray or Tables.jl API compliant data structure of covariates from the
15+
post-treatment period.
16+
- `Y₁::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes from the
17+
post-treatment period.
1418
1519
# Keywords
1620
- `activation::Function=swish`: activation function to use.
@@ -44,10 +48,10 @@ julia> m3 = InterruptedTimeSeries(x₀_df, y₀_df, x₁_df, y₁_df)
4448
```
4549
"""
4650
mutable struct InterruptedTimeSeries
47-
X₀::Array{Float64}
48-
Y₀::Array{Float64}
49-
X₁::Array{Float64}
50-
Y₁::Array{Float64}
51+
X₀::AbstractArray{<: Real}
52+
Y₀::AbstractArray{<: Real}
53+
X₁::AbstractArray{<: Real}
54+
Y₁::AbstractArray{<: Real}
5155
marginal_effect::Float64
5256
@model_config individual_effect
5357
end
@@ -65,7 +69,7 @@ function InterruptedTimeSeries(
6569
autoregression::Bool=true,
6670
)
6771
# Convert to arrays
68-
X₀, X₁, Y₀, Y₁ = Matrix{Float64}(X₀), Matrix{Float64}(X₁), Y₀[:, 1], Y₁[:, 1]
72+
X₀, X₁, Y₀, Y₁ = convert_if_table.((X₀, X₁, Y₀, Y₁))
6973

7074
# Add autoregressive term
7175
X₀ = ifelse(autoregression == true, reduce(hcat, (X₀, moving_average(Y₀))), X₀)
@@ -97,9 +101,9 @@ end
97101
Initialize a G-Computation estimator.
98102
99103
# Arguments
100-
- `X::Any`: array or DataFrame of covariates.
101-
- `T::Any`: vector or DataFrame of treatment statuses.
102-
- `Y::Any`: array or DataFrame of outcomes.
104+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
105+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
106+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
103107
104108
# Keywords
105109
- `quantity_of_interest::String`: ATE for average treatment effect or ATT for average
@@ -159,7 +163,7 @@ mutable struct GComputation <: CausalEstimator
159163
end
160164

161165
# Convert to arrays
162-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
166+
X, T, Y = convert_if_table.((X, T, Y))
163167

164168
task = var_type(Y) isa Binary ? "classification" : "regression"
165169

@@ -187,9 +191,10 @@ end
187191
Initialize a double machine learning estimator with cross fitting.
188192
189193
# Arguments
190-
- `X::Any`: array or DataFrame of covariates of interest.
191-
- `T::Any`: vector or DataFrame of treatment statuses.
192-
- `Y::Any`: array or DataFrame of outcomes.
194+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates of
195+
interest.
196+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
197+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
193198
194199
# Keywords
195200
- `activation::Function=swish`: activation function to use.
@@ -240,7 +245,7 @@ function DoubleMachineLearning(
240245
folds::Integer=5,
241246
)
242247
# Convert to arrays
243-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
248+
X, T, Y = convert_if_table.((X, T, Y))
244249

245250
# Shuffle data with random indices
246251
indices = shuffle(1:length(Y))

src/metalearners.jl

+22-20
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ abstract type Metalearner end
77
Initialize a S-Learner.
88
99
# Arguments
10-
- `X::Any`: an array or DataFrame of covariates.
11-
- `T::Any`: an vector or DataFrame of treatment statuses.
12-
- `Y::Any`: an array or DataFrame of outcomes.
10+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
11+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
12+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
1313
1414
# Keywords
1515
- `activation::Function=swish`: the activation function to use.
@@ -60,7 +60,7 @@ mutable struct SLearner <: Metalearner
6060
)
6161

6262
# Convert to arrays
63-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
63+
X, T, Y = convert_if_table.((X, T, Y))
6464

6565
task = var_type(Y) isa Binary ? "classification" : "regression"
6666

@@ -88,9 +88,9 @@ end
8888
Initialize a T-Learner.
8989
9090
# Arguments
91-
- `X::Any`: an array or DataFrame of covariates.
92-
- `T::Any`: an vector or DataFrame of treatment statuses.
93-
- `Y::Any`: an array or DataFrame of outcomes.
91+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
92+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
93+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
9494
9595
# Keywords
9696
- `activation::Function=swish`: the activation function to use.
@@ -140,7 +140,7 @@ mutable struct TLearner <: Metalearner
140140
num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)),
141141
)
142142
# Convert to arrays
143-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
143+
X, T, Y = convert_if_table.((X, T, Y))
144144

145145
task = var_type(Y) isa Binary ? "classification" : "regression"
146146

@@ -168,9 +168,9 @@ end
168168
Initialize an X-Learner.
169169
170170
# Arguments
171-
- `X::Any`: an array or DataFrame of covariates.
172-
- `T::Any`: an vector or DataFrame of treatment statuses.
173-
- `Y::Any`: an array or DataFrame of outcomes.
171+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
172+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
173+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
174174
175175
# Keywords
176176
- `activation::Function=swish`: the activation function to use.
@@ -221,7 +221,7 @@ mutable struct XLearner <: Metalearner
221221
num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)),
222222
)
223223
# Convert to arrays
224-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
224+
X, T, Y = convert_if_table.((X, T, Y))
225225

226226
task = var_type(Y) isa Binary ? "classification" : "regression"
227227

@@ -249,9 +249,10 @@ end
249249
Initialize an R-Learner.
250250
251251
# Arguments
252-
- `X::Any`: an array or DataFrame of covariates of interest.
253-
- `T::Any`: an vector or DataFrame of treatment statuses.
254-
- `Y::Any`: an array or DataFrame of outcomes.
252+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates of
253+
interest.
254+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
255+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
255256
256257
# Keywords
257258
- `activation::Function=swish`: the activation function to use.
@@ -301,7 +302,7 @@ function RLearner(
301302
)
302303

303304
# Convert to arrays
304-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
305+
X, T, Y = convert_if_table.((X, T, Y))
305306

306307
# Shuffle data with random indices
307308
indices = shuffle(1:length(Y))
@@ -333,9 +334,10 @@ end
333334
Initialize a doubly robust CATE estimator.
334335
335336
# Arguments
336-
- `X::Any`: an array or DataFrame of covariates of interest.
337-
- `T::Any`: an vector or DataFrame of treatment statuses.
338-
- `Y::Any`: an array or DataFrame of outcomes.
337+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates of
338+
interest.
339+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
340+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
339341
340342
# Keywords
341343
- `activation::Function=swish`: the activation function to use.
@@ -386,7 +388,7 @@ function DoublyRobustLearner(
386388
num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)),
387389
)
388390
# Convert to arrays
389-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
391+
X, T, Y = convert_if_table.((X, T, Y))
390392

391393
# Shuffle data with random indices
392394
indices = shuffle(1:length(Y))

0 commit comments

Comments
 (0)