Ol ctmle template #125


Draft: wants to merge 50 commits into base: main

Commits (50):
dfb1e75
add WIP collaborative TMLE
olivierlabayle Apr 21, 2025
5934385
add collaborative initial loop
olivierlabayle Apr 22, 2025
901202f
extract two methods
olivierlabayle Apr 23, 2025
ca5dde1
update some docstrings
olivierlabayle Apr 23, 2025
0ed758a
back to where I started but hopefully better
olivierlabayle Apr 23, 2025
b539333
algo more or less in place but not working due to cache management
olivierlabayle Apr 23, 2025
8be66fc
rename candidate nt and add some test
olivierlabayle Apr 24, 2025
b050cb0
isolate caching retrieval
olivierlabayle Apr 24, 2025
7e1099f
fix tests
olivierlabayle Apr 24, 2025
1269cd7
try fix docs
olivierlabayle Apr 24, 2025
1cfdb22
move some code into utils
olivierlabayle Apr 24, 2025
89180a3
possibly working ctmle, need to clean and test more
olivierlabayle Apr 25, 2025
1f1a787
add some more content and maybe fix docs
olivierlabayle Apr 25, 2025
514a2a8
replace MLJ with MLBase in examples
olivierlabayle Apr 26, 2025
3e91e77
update interaction simulation docs
olivierlabayle Apr 26, 2025
429165b
try fix docs again
olivierlabayle Apr 26, 2025
b6424f3
fix some imports
olivierlabayle Apr 26, 2025
818fe5e
fix super learning example
olivierlabayle Apr 27, 2025
21c17da
add test for factorial estimand
olivierlabayle Apr 29, 2025
8730d6f
update interaction example
olivierlabayle Apr 29, 2025
3dbdf93
rename estimator structures
olivierlabayle Apr 29, 2025
7711840
refactor some code
olivierlabayle Apr 29, 2025
0efb8ce
add more tests
olivierlabayle Apr 29, 2025
28529c5
update tests
olivierlabayle Apr 29, 2025
02ae207
extract collaborative loop
olivierlabayle Apr 29, 2025
9f098a0
add more tests
olivierlabayle Apr 29, 2025
2a53da6
fix call to collaborative estimator
olivierlabayle Apr 29, 2025
7d8dc9b
move some code
olivierlabayle Apr 29, 2025
81bee6d
add ctmle to DR list
olivierlabayle Apr 29, 2025
6781a51
make wTMLE the default
olivierlabayle Apr 29, 2025
3e1e385
fix some prints in composition test
olivierlabayle Apr 29, 2025
d2aef48
remove some logs
olivierlabayle Apr 29, 2025
45bb499
add DataFrames.jl dependency
olivierlabayle Apr 30, 2025
ba1e79c
add CausalStratifiedCV
olivierlabayle Apr 30, 2025
c9b66f0
make default resampling the treatment stratified one
olivierlabayle Apr 30, 2025
51fc04d
fix early stopping bug and change some logging behaviour
olivierlabayle May 1, 2025
9f08b03
replace repeat with fill
olivierlabayle May 2, 2025
81569c5
add test for selectcols and stop caching fluctuations
olivierlabayle May 5, 2025
a503504
remove unused key function
olivierlabayle May 5, 2025
1a371bf
remove column COLLABORATIVE_INTERCEPT, simplify some code and make Da…
olivierlabayle May 5, 2025
056845e
rename targeting estimators
olivierlabayle May 5, 2025
2f7ec3e
remove unused retrieve_models
olivierlabayle May 5, 2025
6cab65e
simplify some more code
olivierlabayle May 5, 2025
ebde29e
prepare ground for multithreading
olivierlabayle May 6, 2025
7b4d6de
replace sum_loss by mean_loss
olivierlabayle May 6, 2025
044d2a7
add multithreading acceleration
olivierlabayle May 6, 2025
5a6d5f7
make example use dataframe
olivierlabayle May 6, 2025
8c8ca35
rsion
olivierlabayle May 6, 2025
be8734c
fix missing acceleration kw
olivierlabayle May 6, 2025
3ce5d1b
fix comment
olivierlabayle May 8, 2025
10 changes: 8 additions & 2 deletions Project.toml

@@ -1,12 +1,14 @@
 name = "TMLE"
 uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
 authors = ["Olivier Labayle"]
-version = "0.18.1"
+version = "0.19.0"

 [deps]
 AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
+ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
@@ -21,6 +23,7 @@ Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
+StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
@@ -41,20 +44,23 @@ AutoHashEquals = "2.1.0"
 CategoricalArrays = "0.10"
 CausalTables = "1.2.1"
 Combinatorics = "1.0.2"
+ComputationalResources = "0.3.2"
+DataFrames = "1.7.0"
 DifferentiationInterface = "0.6.43"
 Distributions = "0.25"
 GLM = "1.8.2"
 Graphs = "1.8"
 HypothesisTests = "0.10, 0.11"
 JSON = "0.21.4"
 LogExpFunctions = "0.3"
-MLJBase = "1.0.1"
+MLJBase = "1"
 MLJGLMInterface = "0.3.4"
 MLJModels = "0.15, 0.16, 0.17"
 MetaGraphsNext = "0.7"
 Missings = "1.0"
 OrderedCollections = "1.6.3"
 SplitApplyCombine = "1.2.2"
+StatisticalMeasures = "0.2"
 TableOperations = "1.2"
 Tables = "1.6"
 YAML = "0.4.9"
8 changes: 6 additions & 2 deletions docs/Project.toml

@@ -10,14 +10,17 @@
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
 MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
 MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
 MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
 NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
 TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -30,8 +33,9 @@
 DataFrames = "1.5"
 Distributions = "0.25"
 Documenter = "1"
 Literate = "2.13"
-MLJ = "0.20.1"
+MLJBase = "1.8"
 MLJLinearModels = "0.10.0"
 MLJModels = "0.17"
 MLJXGBoostInterface = "0.3.8"
 NearestNeighborModels = "0.2"
 StableRNGs = "1.0"
3 changes: 3 additions & 0 deletions docs/make.jl

@@ -1,6 +1,9 @@
 using TMLE
 using Documenter
 using Literate
+using Logging
+
+Logging.disable_logging(Logging.Warn)

 DocMeta.setdocmeta!(TMLE, :DocTestSetup, :(using TMLE); recursive=true)
9 changes: 5 additions & 4 deletions docs/src/index.md

@@ -49,7 +49,7 @@ function tmle_estimates(data)
     Q_binary=MLJLinearModels.LogisticClassifier(),
     G=MLJLinearModels.LogisticClassifier()
 )
-Ψ̂ = TMLEE(models=models, weighted=true)
+Ψ̂ = Tmle(models=models, weighted=true)
 Ψ = ATE(;
     outcome=:Y,
     treatment_values=(T=(case=true, control = false),),
@@ -80,7 +80,7 @@ end
 function plot(β̂s_confounded, β̂s_unconfounded, tmles_confounded, tmles_unconfounded, β, ATE₀)
     fig = Figure(size=(1000, 800))
     ax = Axis(fig[1, 1], title="Distribution of Linear Model's and TMLE's Estimates", yticks=(1:2, ["Confounded", "Unconfounded"]))
-    labels = vcat(repeat(["Confounded"], length(β̂s_confounded)), repeat(["Unconfounded"], length(β̂s_unconfounded)))
+    labels = vcat(fill("Confounded", length(β̂s_confounded)), fill("Unconfounded", length(β̂s_unconfounded)))
     rainclouds!(ax, labels, vcat(β̂s_confounded, β̂s_unconfounded), orientation = :horizontal, color=(:blue, 0.5))
     rainclouds!(ax, labels, vcat(tmles_confounded, tmles_unconfounded), orientation = :horizontal, color=(:orange, 0.5))
     vlines!(ax, ATE₀, label="ATE", color=:green)
@@ -142,13 +142,14 @@ using Random
 using CategoricalArrays
 using MLJLinearModels
 using LogExpFunctions
+using DataFrames

 rng = StableRNG(123)
 n = 100
 W = rand(rng, Uniform(), n)
 T = rand(rng, Uniform(), n) .< logistic.(1 .- 2W)
 Y = 1 .+ 3T .- T.*W .+ rand(rng, Normal(0, 0.01), n)
-dataset = (Y=Y, T=categorical(T), W=W)
+dataset = DataFrame(Y=Y, T=categorical(T), W=W)
 nothing # hide
 ```
@@ -167,7 +168,7 @@ The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as
 ### 3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE)

 ```@example quick-start
-tmle = TMLEE()
+tmle = Tmle()
 result, _ = tmle(Ψ, dataset, verbosity=0);
 result
 ```
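Pieced together from the hunks above, the updated quick-start builds a `DataFrame` dataset and uses the renamed `Tmle` constructor. A condensed sketch (the confounder specification `(T=[:W],)` is an assumption filled in to match the simulation, since that part of the estimand is outside the visible diff):

```julia
using TMLE, DataFrames, CategoricalArrays, Distributions, StableRNGs, LogExpFunctions

# Simulate a confounded treatment/outcome pair, as in the quick-start.
rng = StableRNG(123)
n = 100
W = rand(rng, Uniform(), n)
T = rand(rng, Uniform(), n) .< logistic.(1 .- 2W)
Y = 1 .+ 3T .- T .* W .+ rand(rng, Normal(0, 0.01), n)
dataset = DataFrame(Y=Y, T=categorical(T), W=W)   # was a NamedTuple before this PR

Ψ = ATE(;
    outcome=:Y,
    treatment_values=(T=(case=true, control=false),),
    treatment_confounders=(T=[:W],),              # assumed spec for this sketch
)
tmle = Tmle()                                     # renamed from TMLEE()
result, _ = tmle(Ψ, dataset, verbosity=0)
```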
2 changes: 1 addition & 1 deletion docs/src/integrations.md

@@ -32,7 +32,7 @@ ct = rand(scm, 100)

 # Define a causal estimand and estimate it using TMLE
 Ψ = ATE(outcome = :Y, treatment_values = (A = (case = 1, control = 0),))
-estimator = TMLEE()
+estimator = Tmle()
 Ψ̂, cache = estimator(Ψ, ct; verbosity=0)
 ```
20 changes: 10 additions & 10 deletions docs/src/user_guide/estimation.md

@@ -15,7 +15,7 @@ using CategoricalArrays
 using TMLE
 using LogExpFunctions
 using MLJLinearModels
-using MLJ
+using MLJBase

 function make_dataset(;n=1000)
     rng = StableRNG(123)
@@ -53,8 +53,8 @@ scm = SCM([

 Once a statistical estimand has been defined, we can proceed with estimation. There are two semi-parametric efficient estimators in TMLE.jl:

-- The Targeted Maximum-Likelihood Estimator (`TMLEE`)
-- The One-Step Estimator (`OSE`)
+- The Targeted Maximum-Likelihood Estimator (`Tmle`)
+- The One-Step Estimator (`Ose`)

 While they have similar asymptotic properties, their finite sample performance may be different. They also have a very distinguishing feature, the TMLE is a plugin estimator, which means it respects the natural bounds of the estimand of interest. In contrast, the OSE may in theory report values outside these bounds. In practice, this is not often the case and the estimand of interest may not impose any restriction on its domain.

@@ -67,7 +67,7 @@ Drawing from the example dataset and `SCM` from the Walk Through section, we can
     treatment_confounders=(T₁=[:W₁₁, :W₁₂],),
     outcome_extra_covariates=[:C]
 )
-tmle = TMLEE()
+tmle = Tmle()
 result₁, cache = tmle(Ψ₁, dataset);
 result₁
 nothing # hide
@@ -87,7 +87,7 @@ Both the TMLE and OSE are asymptotically linear estimators, standard Z/T tests f
 tmle_test_result₁ = pvalue(OneSampleTTest(result₁))
 ```

-Let us now turn to the Average Treatment Effect of `T₂`, we will estimate it with a `OSE`:
+Let us now turn to the Average Treatment Effect of `T₂`, we will estimate it with a `Ose`:

 ```@example estimation
 Ψ₂ = ATE(
@@ -96,7 +96,7 @@
     treatment_confounders=(T₂=[:W₂₁, :W₂₂],),
     outcome_extra_covariates=[:C]
 )
-ose = OSE()
+ose = Ose()
 result₂, cache = ose(Ψ₂, dataset;cache=cache);
 result₂
 nothing # hide
@@ -121,7 +121,7 @@ models = default_models(
     Q_continuous = xgboost_regressor,
     G = xgboost_classifier
 )
-tmle_gboost = TMLEE(models=models)
+tmle_gboost = Tmle(models=models)
 ```

 The advantage of using `default_models` is that it will automatically prepend each model with a [ContinuousEncoder](https://alan-turing-institute.github.io/MLJ.jl/dev/transformers/#MLJModels.ContinuousEncoder) to make sure the correct types are passed to the downstream models.
@@ -144,7 +144,7 @@
     # Unspecified G defaults to Logistic Regression
 )

-tmle_custom = TMLEE(models=models)
+tmle_custom = Tmle(models=models)
 ```

 Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `Dict`.
@@ -154,13 +154,13 @@
 Canonical TMLE/OSE are essentially using the dataset twice, once for the estimation of the nuisance functions and once for the estimation of the parameter of interest. This means that there is a risk of over-fitting and residual bias ([see here](https://arxiv.org/abs/2203.06469) for some discussion). One way to address this limitation is to use a technique called sample-splitting / cross-validation. In order to activate the sample-splitting mode, simply provide a `MLJ.ResamplingStrategy` using the `resampling` keyword argument:

 ```@example estimation
-TMLEE(resampling=StratifiedCV());
+Tmle(resampling=StratifiedCV());
 ```

 or

 ```julia
-OSE(resampling=StratifiedCV(nfolds=3));
+Ose(resampling=StratifiedCV(nfolds=3));
 ```

 There are some practical considerations
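Putting the renamed pieces of this page together, custom nuisance models and sample-splitting can be combined in one estimator. A sketch under the new names, where `XGBoostRegressor()`/`XGBoostClassifier()` stand in for the `xgboost_regressor`/`xgboost_classifier` bindings defined earlier in the page:

```julia
using TMLE, MLJBase, MLJXGBoostInterface

# Gradient-boosted nuisance models for the continuous outcome (Q)
# and the propensity score (G); other variables keep the defaults.
models = default_models(
    Q_continuous = XGBoostRegressor(),
    G = XGBoostClassifier(),
)

# Targeted estimation with 3-fold stratified sample-splitting.
cv_tmle = Tmle(models=models, resampling=StratifiedCV(nfolds=3))
```

With `resampling` set, the nuisance functions are fit out-of-fold, which is what mitigates the over-fitting risk discussed above.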
4 changes: 2 additions & 2 deletions docs/src/walk_through.md

@@ -135,7 +135,7 @@
 Then each parameter can be estimated by building an estimator (which is simply a function) and evaluating it on data. For illustration, we will keep the models simple. We define a Targeted Maximum Likelihood Estimator:

 ```@example walk-through
-tmle = TMLEE()
+tmle = Tmle()
 ```

 Because we haven't identified the `cm` causal estimand yet, we need to provide the `scm` as well to the estimator:
@@ -148,7 +148,7 @@ result
 Statistical Estimands can be estimated without a ``SCM``, let's use the One-Step estimator:

 ```@example walk-through
-ose = OSE()
+ose = Ose()
 result, cache = ose(statistical_aie, dataset)
 result
 ```
4 changes: 2 additions & 2 deletions examples/double_robustness.jl

@@ -18,7 +18,7 @@ Y \sim \mathcal{N}(e^{1 - 10 \cdot T + W}, 1)
 =#

 using TMLE
-using MLJ
+using MLJBase
 using Distributions
 using StableRNGs
 using LogExpFunctions
@@ -161,7 +161,7 @@ function tmle_inference(data)
     :Y => with_encoder(LinearRegressor()),
     :Tcat => with_encoder(LinearBinaryClassifier())
 )
-tmle = TMLEE(models=models)
+tmle = Tmle(models=models)
 result, _ = tmle(Ψ, data; verbosity=0)
 lb, ub = confint(OneSampleTTest(result))
 return (TMLE.estimate(result), lb, ub)