Skip to content

Commit 8cd568b

Browse files
committed
use clone, joblib, warning on degenerate samples, n_bootstraps=200
1 parent f258cdb commit 8cd568b

1 file changed

Lines changed: 27 additions & 8 deletions

File tree

causalml/inference/meta/tlearner.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from packaging import version
55
from scipy.stats import norm
66
import sklearn
7+
from sklearn.base import clone
78
from sklearn.exceptions import ConvergenceWarning
89
from sklearn.neural_network import MLPRegressor
10+
from joblib import Parallel, delayed
911

1012
if version.parse(sklearn.__version__) >= version.parse("0.22.0"):
1113
from sklearn.utils._testing import ignore_warnings
@@ -77,9 +79,10 @@ def fit(
7779
y,
7880
p=None,
7981
store_bootstraps=False,
80-
n_bootstraps=1000,
82+
n_bootstraps=200,
8183
bootstrap_size=10000,
8284
random_state=None,
85+
n_jobs=1,
8386
):
8487
"""Fit the inference model
8588
@@ -91,7 +94,12 @@ def fit(
9194
store_bootstraps (bool, optional): if True, trains a bootstrap ensemble
9295
during fit and stores it in self.bootstrap_models_ for post-fit CI
9396
estimation via predict(return_ci=True). Default: False.
94-
n_bootstraps (int, optional): number of bootstrap iterations. Default: 1000.
97+
n_bootstraps (int, optional): number of bootstrap iterations. Default: 200.
98+
Note: storing N bootstraps of a GBM-based learner with k treatment
99+
groups holds 2*N*k model objects in memory. Monitor RAM for large N
100+
or heavy base learners.
101+
n_jobs (int, optional): number of parallel jobs for bootstrap fitting.
102+
-1 uses all available cores. Default: 1.
95103
bootstrap_size (int, optional): number of samples per bootstrap. Default: 10000.
96104
random_state (int, optional): random seed for reproducible bootstrap sampling.
97105
"""
@@ -118,25 +126,36 @@ def fit(
118126
logger.info(
119127
"Storing bootstrap ensemble ({} iterations)".format(n_bootstraps)
120128
)
121-
self.bootstrap_models_ = []
122-
for i in tqdm(range(n_bootstraps)):
123-
idxs = rng.choice(np.arange(X.shape[0]), size=bootstrap_size)
129+
seeds = rng.randint(0, np.iinfo(np.int32).max, size=n_bootstraps)
130+
131+
def _fit_one_bootstrap(seed):
132+
local_rng = np.random.RandomState(seed)
133+
idxs = local_rng.choice(np.arange(X.shape[0]), size=bootstrap_size)
124134
X_b, treatment_b, y_b = X[idxs], treatment[idxs], y[idxs]
125-
models_c_b = {group: deepcopy(self.model_c) for group in self.t_groups}
126-
models_t_b = {group: deepcopy(self.model_t) for group in self.t_groups}
135+
models_c_b = {group: clone(self.model_c) for group in self.t_groups}
136+
models_t_b = {group: clone(self.model_t) for group in self.t_groups}
127137
for group in self.t_groups:
128138
mask = (treatment_b == group) | (treatment_b == self.control_name)
129139
treatment_filt = treatment_b[mask]
130140
X_filt = X_b[mask]
131141
y_filt = y_b[mask]
132142
w = (treatment_filt == group).astype(int)
133143
if w.sum() == 0 or (w == 0).sum() == 0:
144+
logger.warning(
145+
"Bootstrap sample has no treated or no control units "
146+
"for group {}. Falling back to global model — "
147+
"CI may be underestimated.".format(group)
148+
)
134149
models_c_b[group] = self.models_c[group]
135150
models_t_b[group] = self.models_t[group]
136151
continue
137152
models_c_b[group].fit(X_filt[w == 0], y_filt[w == 0])
138153
models_t_b[group].fit(X_filt[w == 1], y_filt[w == 1])
139-
self.bootstrap_models_.append((models_c_b, models_t_b))
154+
return models_c_b, models_t_b
155+
156+
self.bootstrap_models_ = Parallel(n_jobs=n_jobs)(
157+
delayed(_fit_one_bootstrap)(s) for s in tqdm(seeds)
158+
)
140159
else:
141160
self.bootstrap_models_ = None
142161

0 commit comments

Comments (0)