Skip to content

Commit 1169985

Browse files
committed
Implemented Tables.jl API and support for GPUs
1 parent d5cabb6 commit 1169985

13 files changed

+131
-64
lines changed

Manifest.toml

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.11.1"
44
manifest_format = "2.0"
5-
project_hash = "ef638d9b7dd3411a6b5c86406ac77e48f19d8d42"
5+
project_hash = "48b0ecc3de09367019241b9866f1be8d1ab8f4cc"
66

77
[[deps.Artifacts]]
88
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -13,6 +13,21 @@ deps = ["Artifacts", "Libdl"]
1313
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
1414
version = "1.1.1+0"
1515

16+
[[deps.DataAPI]]
17+
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
18+
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
19+
version = "1.16.0"
20+
21+
[[deps.DataValueInterfaces]]
22+
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
23+
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
24+
version = "1.0.0"
25+
26+
[[deps.IteratorInterfaceExtensions]]
27+
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
28+
uuid = "82899510-4779-5014-852e-03e436cf321d"
29+
version = "1.0.0"
30+
1631
[[deps.Libdl]]
1732
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1833
version = "1.11.0"
@@ -27,6 +42,11 @@ deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
2742
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
2843
version = "0.3.27+1"
2944

45+
[[deps.OrderedCollections]]
46+
git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad"
47+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
48+
version = "1.7.0"
49+
3050
[[deps.Random]]
3151
deps = ["SHA"]
3252
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -36,6 +56,18 @@ version = "1.11.0"
3656
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
3757
version = "0.7.0"
3858

59+
[[deps.TableTraits]]
60+
deps = ["IteratorInterfaceExtensions"]
61+
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
62+
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
63+
version = "1.0.1"
64+
65+
[[deps.Tables]]
66+
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"]
67+
git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297"
68+
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
69+
version = "1.12.0"
70+
3971
[[deps.libblastrampoline_jll]]
4072
deps = ["Artifacts", "Libdl"]
4173
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ version = "0.8.0"
66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
910

1011
[compat]
1112
Aqua = "0.8"
1213
DataFrames = "1.5"
1314
Documenter = "1.2"
1415
LinearAlgebra = "1.8"
1516
Random = "1.8"
17+
Tables = "1.12.0"
1618
Test = "1.8"
1719
julia = "1.8"
1820

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,5 @@ CausalELM.clip_if_binary
112112
CausalELM.@model_config
113113
CausalELM.@standard_input_data
114114
CausalELM.generate_folds
115+
CausalELM.convert_if_table
115116
```

docs/src/guide/doublemachinelearning.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ the residuals from the first stage models.
1616

1717
## Step 1: Initialize a Model
1818
The DoubleMachineLearning constructor takes at least three arguments—covariates,
19-
treatment statuses, and outcomes, all of which may be either an array or any struct that
20-
implements the Tables.jl interface (e.g. DataFrames). This estimator supports binary, count,
21-
or continuous treatments and binary, count, continuous, or time to event outcomes.
19+
treatment statuses, and outcomes, all of which may be either an AbstractArray or any struct
20+
that implements the Tables.jl interface (e.g. DataFrames). This estimator supports binary,
21+
count, or continuous treatments and binary, count, continuous, or time to event outcomes.
2222

2323
!!! note
2424
Non-binary categorical outcomes are treated as continuous.

docs/src/guide/gcomputation.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ steps for using G-computation in CausalELM are below.
1616

1717
## Step 1: Initialize a Model
1818
The GComputation constructor takes at least three arguments: covariates, treatment statuses,
19-
outcomes, all of which can be either an array or any data structure that implements the
20-
Tables.jl interface (e.g. DataFrames). This implementation supports binary treatments and
21-
binary, continuous, time to event, and count outcome variables.
19+
and outcomes, all of which can be either an AbstractArray or any data structure that implements
20+
the Tables.jl interface (e.g. DataFrames). This implementation supports binary treatments
21+
and binary, continuous, time to event, and count outcome variables.
2222

2323
!!! note
2424
Non-binary categorical outcomes are treated as continuous.

docs/src/guide/its.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Estimating an interrupted time series design in CausalELM consists of three step
3131
## Step 1: Initialize an interrupted time series estimator
3232
The InterruptedTimeSeries constructor takes at least four arguments: pre-event covariates,
3333
pre-event outcomes, post-event covariates, and post-event outcomes, all of which can be
34-
either an array or any data structure that implements the Tables.jl interface (e.g.
34+
either an AbstractArray or any data structure that implements the Tables.jl interface (e.g.
3535
DataFrames). The interrupted time series estimator assumes outcomes are either continuous,
3636
count, or time to event variables.
3737

docs/src/guide/metalearners.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ continuous outcomes.
3030
Kennedy, Edward H. "Towards optimal doubly robust estimation of heterogeneous causal
3131
effects." Electronic Journal of Statistics 17, no. 2 (2023): 3008-3049.
3232

33-
# Initialize a Metalearner
33+
# Step 1: Initialize a Metalearner
3434
S-learners, T-learners, X-learners, R-learners, and doubly robust estimators all take at
3535
least three arguments—covariates, treatment statuses, and outcomes, all of which can be
36-
either an array or any struct that implements the Tables.jl interface (e.g. DataFrames). S,
37-
T, X, and doubly robust learners support binary treatment variables and binary, continuous,
38-
count, or time to event outcomes. The R-learning estimator supports binary, continuous, or
39-
count treatment variables and binary, continuous, count, or time to event outcomes.
36+
either an AbstractArray or any struct that implements the Tables.jl interface (e.g.
37+
DataFrames). S, T, X, and doubly robust learners support binary treatment variables and
38+
binary, continuous, count, or time to event outcomes. The R-learning estimator supports
39+
binary, continuous, or count treatment variables and binary, continuous, count, or time to
40+
event outcomes.
4041

4142
!!! note
4243
Non-binary categorical outcomes are treated as continuous.
@@ -64,7 +65,7 @@ r_learner = RLearner(X, Y, T)
6465
dr_learner = DoublyRobustLearner(X, T, Y)
6566
```
6667

67-
# Estimate the CATE
68+
# Step 2: Estimate the CATE
6869
We can estimate the CATE for all the models by passing them to estimate_causal_effect!.
6970
```julia
7071
estimate_causal_effect!(s_learner)
@@ -74,7 +75,7 @@ estimate_causal_effect!(r_learner)
7475
estimate_causal_effect!(dr_learner)
7576
```
7677

77-
# Get a Summary
78+
# Step 3: Get a Summary
7879
We can get a summary of the model by passing the model to the summarize method.
7980

8081
!!! note

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ these libraries are:
7575
econometrics, and biostatistics.
7676

7777
### Installation
78-
CausalELM requires Julia version 1.7 or greater and can be installed from the REPL as shown
78+
CausalELM requires Julia version 1.8 or greater and can be installed from the REPL as shown
7979
below.
8080
```julia
8181
using Pkg

docs/src/release_notes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/
55
### Added
66
* Implemented randomization inference-based confidence intervals [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
77
* Added marginal effects to model summaries [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
8+
* CausalELM models now support any AbstractArray data type, including support for using GPUs with CuArrays or similar structures for Mac, Intel, and AMD hardware [#37](https://github.com/dscolby/CausalELM.jl/issues/37)
89
### Fixed
910
* Removed unnecessary include and using statements
1011
* Slightly sped up the randomization inference implementation and clarified it in the docs [#77](https://github.com/dscolby/CausalELM.jl/issues/77)
1112
* Fixed the randomization inference index selection procedure for interrupted time series estimators
1213
* Inlined certain methods to slightly improve performance [#76](https://github.com/dscolby/CausalELM.jl/issues/76)
14+
* CausalELM models now support any data structure that implements the Tables.jl API, not just DataFrames
1315

1416
## Version [v0.7.0](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.7.0) - 2024-06-22
1517
### Added

src/estimators.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@ abstract type CausalEstimator end
77
Initialize an interrupted time series estimator.
88
99
# Arguments
10-
- `X₀::Any`: array or DataFrame of covariates from the pre-treatment period.
11-
- `Y₁::Any`: array or DataFrame of outcomes from the pre-treatment period.
12-
- `X₁::Any`: array or DataFrame of covariates from the post-treatment period.
13-
- `Y₁::Any`: array or DataFrame of outcomes from the post-treatment period.
10+
- `X₀::Any`: AbstractArray or Tables.jl API compliant data structure of covariates from the
11+
pre-treatment period.
12+
- `Y₀::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes from the
13+
pre-treatment period.
14+
- `X₁::Any`: AbstractArray or Tables.jl API compliant data structure of covariates from the
15+
post-treatment period.
16+
- `Y₁::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes from the
17+
post-treatment period.
1418
1519
# Keywords
1620
- `activation::Function=swish`: activation function to use.
@@ -44,10 +48,10 @@ julia> m3 = InterruptedTimeSeries(x₀_df, y₀_df, x₁_df, y₁_df)
4448
```
4549
"""
4650
mutable struct InterruptedTimeSeries
47-
X₀::Array{Float64}
48-
Y₀::Array{Float64}
49-
X₁::Array{Float64}
50-
Y₁::Array{Float64}
51+
X₀::AbstractArray{<: Real}
52+
Y₀::AbstractArray{<: Real}
53+
X₁::AbstractArray{<: Real}
54+
Y₁::AbstractArray{<: Real}
5155
marginal_effect::Float64
5256
@model_config individual_effect
5357
end
@@ -65,7 +69,7 @@ function InterruptedTimeSeries(
6569
autoregression::Bool=true,
6670
)
6771
# Convert to arrays
68-
X₀, X₁, Y₀, Y₁ = Matrix{Float64}(X₀), Matrix{Float64}(X₁), Y₀[:, 1], Y₁[:, 1]
72+
X₀, X₁, Y₀, Y₁ = convert_if_table.((X₀, X₁, Y₀, Y₁))
6973

7074
# Add autoregressive term
7175
X₀ = ifelse(autoregression == true, reduce(hcat, (X₀, moving_average(Y₀))), X₀)
@@ -97,9 +101,9 @@ end
97101
Initialize a G-Computation estimator.
98102
99103
# Arguments
100-
- `X::Any`: array or DataFrame of covariates.
101-
- `T::Any`: vector or DataFrame of treatment statuses.
102-
- `Y::Any`: array or DataFrame of outcomes.
104+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
105+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
106+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
103107
104108
# Keywords
105109
- `quantity_of_interest::String`: ATE for average treatment effect or ATT for average
@@ -159,7 +163,7 @@ mutable struct GComputation <: CausalEstimator
159163
end
160164

161165
# Convert to arrays
162-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
166+
X, T, Y = convert_if_table.((X, T, Y))
163167

164168
task = var_type(Y) isa Binary ? "classification" : "regression"
165169

@@ -187,9 +191,10 @@ end
187191
Initialize a double machine learning estimator with cross fitting.
188192
189193
# Arguments
190-
- `X::Any`: array or DataFrame of covariates of interest.
191-
- `T::Any`: vector or DataFrame of treatment statuses.
192-
- `Y::Any`: array or DataFrame of outcomes.
194+
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates of
195+
interest.
196+
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
197+
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
193198
194199
# Keywords
195200
- `activation::Function=swish`: activation function to use.
@@ -240,7 +245,7 @@ function DoubleMachineLearning(
240245
folds::Integer=5,
241246
)
242247
# Convert to arrays
243-
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
248+
X, T, Y = convert_if_table.((X, T, Y))
244249

245250
# Shuffle data with random indices
246251
indices = shuffle(1:length(Y))

0 commit comments

Comments
 (0)