Skip to content

Commit afc75b7

Browse files
committed
- Adds benchmarking script in comments and shares latest numbers and plots
- Adds test for meta-learner consistency and key attributes
1 parent af7c852 commit afc75b7

4 files changed

Lines changed: 428 additions & 30 deletions

File tree

causalml/inference/meta/drlearner.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ def predict(
239239

240240
# models_mu_c is fold-specific but not group-specific; predict once and reuse.
241241
yhat_c = np.r_[[model.predict(X) for model in self.models_mu_c]].mean(axis=0)
242+
# Shared-reference dict preserves the public yhat_cs[group] API cheaply.
243+
yhat_cs = {group: yhat_c for group in self.t_groups}
242244

243245
for i, group in enumerate(self.t_groups):
244246
models_tau = self.models_tau[group]
@@ -264,7 +266,7 @@ def predict(
264266
if not return_components:
265267
return te
266268
else:
267-
return te, yhat_c, yhat_ts
269+
return te, yhat_cs, yhat_ts
268270

269271
def fit_predict(
270272
self,
@@ -394,11 +396,11 @@ def estimate_ate(
394396
The mean and confidence interval (LB, UB) of the ATE estimate.
395397
"""
396398
if pretrain:
397-
te, yhat_c, yhat_ts = self.predict(
399+
te, yhat_cs, yhat_ts = self.predict(
398400
X, treatment, y, p, return_components=True
399401
)
400402
else:
401-
te, yhat_c, yhat_ts = self.fit_predict(
403+
te, yhat_cs, yhat_ts = self.fit_predict(
402404
X, treatment, y, p, return_components=True, seed=seed
403405
)
404406
X, treatment, y = convert_pd_to_np(X, treatment, y)
@@ -427,17 +429,17 @@ def estimate_ate(
427429
w = (treatment_filt == group).astype(int)
428430
prob_treatment = float(sum(w)) / w.shape[0]
429431

430-
yhat_c_g = yhat_c[mask]
432+
yhat_c = yhat_cs[group][mask]
431433
yhat_t = yhat_ts[group][mask]
432434
y_filt = y[mask]
433435

434436
# SE formula is based on the lower bound formula (7) from Imbens, Guido W., and Jeffrey M. Wooldridge. 2009.
435437
# "Recent Developments in the Econometrics of Program Evaluation." Journal of Economic Literature
436438
se = np.sqrt(
437439
(
438-
(y_filt[w == 0] - yhat_c_g[w == 0]).var() / (1 - prob_treatment)
440+
(y_filt[w == 0] - yhat_c[w == 0]).var() / (1 - prob_treatment)
439441
+ (y_filt[w == 1] - yhat_t[w == 1]).var() / prob_treatment
440-
+ (yhat_t - yhat_c_g).var()
442+
+ (yhat_t - yhat_c).var()
441443
)
442444
/ y_filt.shape[0]
443445
)
@@ -600,6 +602,7 @@ def predict(
600602
yhat_c = np.r_[
601603
[model.predict_proba(X)[:, 1] for model in self.models_mu_c]
602604
].mean(axis=0)
605+
yhat_cs = {group: yhat_c for group in self.t_groups}
603606

604607
for i, group in enumerate(self.t_groups):
605608
models_tau = self.models_tau[group]
@@ -625,7 +628,7 @@ def predict(
625628
if not return_components:
626629
return te
627630
else:
628-
return te, yhat_c, yhat_ts
631+
return te, yhat_cs, yhat_ts
629632

630633

631634
class XGBDRRegressor(BaseDRRegressor):

causalml/inference/meta/tlearner.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def __init__(
5555
else:
5656
self.model_c = control_learner
5757

58+
# Preserve the unfitted template so repeated fit() calls always start fresh.
59+
self._model_c_template = self.model_c
60+
5861
if treatment_learner is None:
5962
self.model_t = deepcopy(learner)
6063
else:
@@ -85,10 +88,13 @@ def fit(self, X, treatment, y, p=None):
8588
self.models_t = {group: deepcopy(self.model_t) for group in self.t_groups}
8689

8790
# model_c is trained on the control group, which is identical for every
88-
# treatment group, so fit it once and store as a single model (not a dict).
91+
# treatment group, so fit it once. Deepcopy from the unfitted template so
92+
# re-calling fit() always starts from a clean state (safe with warm_start).
8993
control_mask = treatment == self.control_name
90-
self.model_c = deepcopy(self.model_c)
94+
self.model_c = deepcopy(self._model_c_template)
9195
self.model_c.fit(X[control_mask], y[control_mask])
96+
# Expose as a shared-reference dict to preserve the public models_c API.
97+
self.models_c = {group: self.model_c for group in self.t_groups}
9298

9399
for group in self.t_groups:
94100
treatment_mask = treatment == group
@@ -112,6 +118,9 @@ def predict(
112118
yhat_ts = {}
113119

114120
yhat_c = self.model_c.predict(X)
121+
# Build a shared-reference dict so return_components callers keep the
122+
# yhat_cs[group] indexing API without duplicating the underlying array.
123+
yhat_cs = {group: yhat_c for group in self.t_groups}
115124

116125
for group in self.t_groups:
117126
yhat_ts[group] = self.models_t[group].predict(X)
@@ -136,7 +145,7 @@ def predict(
136145
if not return_components:
137146
return te
138147
else:
139-
return te, yhat_c, yhat_ts
148+
return te, yhat_cs, yhat_ts
140149

141150
def fit_predict(
142151
self,
@@ -195,6 +204,7 @@ def fit_predict(
195204
self.t_groups = t_groups_global
196205
self._classes = _classes_global
197206
self.model_c = deepcopy(model_c_global)
207+
self.models_c = {group: self.model_c for group in self.t_groups}
198208
self.models_t = deepcopy(models_t_global)
199209

200210
return (te, te_lower, te_upper)
@@ -225,9 +235,9 @@ def estimate_ate(
225235
"""
226236
X, treatment, y = convert_pd_to_np(X, treatment, y)
227237
if pretrain:
228-
te, yhat_c, yhat_ts = self.predict(X, treatment, y, return_components=True)
238+
te, yhat_cs, yhat_ts = self.predict(X, treatment, y, return_components=True)
229239
else:
230-
te, yhat_c, yhat_ts = self.fit_predict(
240+
te, yhat_cs, yhat_ts = self.fit_predict(
231241
X, treatment, y, return_components=True
232242
)
233243

@@ -244,14 +254,14 @@ def estimate_ate(
244254
w = (treatment_filt == group).astype(int)
245255
prob_treatment = float(sum(w)) / w.shape[0]
246256

247-
yhat_c_g = yhat_c[mask]
257+
yhat_c = yhat_cs[group][mask]
248258
yhat_t = yhat_ts[group][mask]
249259

250260
se = np.sqrt(
251261
(
252-
(y_filt[w == 0] - yhat_c_g[w == 0]).var() / (1 - prob_treatment)
262+
(y_filt[w == 0] - yhat_c[w == 0]).var() / (1 - prob_treatment)
253263
+ (y_filt[w == 1] - yhat_t[w == 1]).var() / prob_treatment
254-
+ (yhat_t - yhat_c_g).var()
264+
+ (yhat_t - yhat_c).var()
255265
)
256266
/ y_filt.shape[0]
257267
)
@@ -289,6 +299,7 @@ def estimate_ate(
289299
self.t_groups = t_groups_global
290300
self._classes = _classes_global
291301
self.model_c = deepcopy(model_c_global)
302+
self.models_c = {group: self.model_c for group in self.t_groups}
292303
self.models_t = deepcopy(models_t_global)
293304

294305
return ate, ate_lower, ate_upper
@@ -371,6 +382,7 @@ def predict(
371382
yhat_ts = {}
372383

373384
yhat_c = self.model_c.predict_proba(X)[:, 1]
385+
yhat_cs = {group: yhat_c for group in self.t_groups}
374386

375387
for group in self.t_groups:
376388
yhat_ts[group] = self.models_t[group].predict_proba(X)[:, 1]
@@ -395,7 +407,7 @@ def predict(
395407
if not return_components:
396408
return te
397409
else:
398-
return te, yhat_c, yhat_ts
410+
return te, yhat_cs, yhat_ts
399411

400412

401413
class XGBTRegressor(BaseTRegressor):

causalml/inference/meta/xlearner.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def __init__(
5656
else:
5757
self.model_mu_c = control_outcome_learner
5858

59+
# Preserve the unfitted template so repeated fit() calls always start fresh.
60+
self._model_mu_c_template = self.model_mu_c
61+
5962
if treatment_outcome_learner is None:
6063
self.model_mu_t = deepcopy(learner)
6164
else:
@@ -125,14 +128,18 @@ def fit(self, X, treatment, y, p=None):
125128
self.vars_t = {}
126129

127130
# model_mu_c is trained on control data, which is the same for every treatment
128-
# group, so fit it once and store as a single model (not a per-group dict).
131+
# group. Deepcopy from the unfitted template so re-calling fit() starts fresh.
129132
control_mask = treatment == self.control_name
130-
self.model_mu_c = deepcopy(self.model_mu_c)
133+
self.model_mu_c = deepcopy(self._model_mu_c_template)
131134
self.model_mu_c.fit(X[control_mask], y[control_mask])
135+
# Expose as a shared-reference dict to preserve the public models_mu_c API.
136+
self.models_mu_c = {group: self.model_mu_c for group in self.t_groups}
132137

133-
# var_c depends only on model_mu_c and control data, both constant across groups.
138+
# var_c depends only on model_mu_c and the control data, both of which are constant across groups.
134139
y_control_pred = self.model_mu_c.predict(X[control_mask])
135-
var_c = (y[control_mask] - y_control_pred).var()
140+
self.var_c = (y[control_mask] - y_control_pred).var()
141+
# Keep vars_c dict for backward compatibility with existing callers.
142+
self.vars_c = {group: self.var_c for group in self.t_groups}
136143

137144
for group in self.t_groups:
138145
treatment_mask = treatment == group
@@ -141,9 +148,9 @@ def fit(self, X, treatment, y, p=None):
141148

142149
self.models_mu_t[group].fit(X_treat, y_treat)
143150

144-
self.vars_c[group] = var_c
145-
var_t = (y_treat - self.models_mu_t[group].predict(X_treat)).var()
146-
self.vars_t[group] = var_t
151+
self.vars_t[group] = (
152+
y_treat - self.models_mu_t[group].predict(X_treat)
153+
).var()
147154

148155
# Train treatment effect models using cross-group imputation
149156
d_c = self.models_mu_t[group].predict(X[control_mask]) - y[control_mask]
@@ -289,6 +296,7 @@ def fit_predict(
289296
self.t_groups = t_groups_global
290297
self._classes = _classes_global
291298
self.model_mu_c = deepcopy(model_mu_c_global)
299+
self.models_mu_c = {group: self.model_mu_c for group in self.t_groups}
292300
self.models_mu_t = deepcopy(models_mu_t_global)
293301
self.models_tau_c = deepcopy(models_tau_c_global)
294302
self.models_tau_t = deepcopy(models_tau_t_global)
@@ -367,7 +375,7 @@ def estimate_ate(
367375
se = np.sqrt(
368376
(
369377
self.vars_t[group] / prob_treatment
370-
+ self.vars_c[group] / (1 - prob_treatment)
378+
+ self.var_c / (1 - prob_treatment)
371379
+ (p_filt * dhat_c + (1 - p_filt) * dhat_t).var()
372380
)
373381
/ w.shape[0]
@@ -408,6 +416,7 @@ def estimate_ate(
408416
self.t_groups = t_groups_global
409417
self._classes = _classes_global
410418
self.model_mu_c = deepcopy(model_mu_c_global)
419+
self.models_mu_c = {group: self.model_mu_c for group in self.t_groups}
411420
self.models_mu_t = deepcopy(models_mu_t_global)
412421
self.models_tau_c = deepcopy(models_tau_c_global)
413422
self.models_tau_t = deepcopy(models_tau_t_global)
@@ -546,12 +555,14 @@ def fit(self, X, treatment, y, p=None):
546555
# model_mu_c is trained on control data, which is the same for every treatment
547556
# group, so fit it once from the unfitted template and share that single model across groups.
548557
control_mask = treatment == self.control_name
549-
self.model_mu_c = deepcopy(self.model_mu_c)
558+
self.model_mu_c = deepcopy(self._model_mu_c_template)
550559
self.model_mu_c.fit(X[control_mask], y[control_mask])
560+
self.models_mu_c = {group: self.model_mu_c for group in self.t_groups}
551561

552-
# var_c depends only on model_mu_c and control data, both constant across groups.
562+
# var_c depends only on model_mu_c and the control data, both of which are constant across groups.
553563
y_control_pred = self.model_mu_c.predict_proba(X[control_mask])[:, 1]
554-
var_c = (y[control_mask] - y_control_pred).var()
564+
self.var_c = (y[control_mask] - y_control_pred).var()
565+
self.vars_c = {group: self.var_c for group in self.t_groups}
555566

556567
for group in self.t_groups:
557568
treatment_mask = treatment == group
@@ -560,11 +571,9 @@ def fit(self, X, treatment, y, p=None):
560571

561572
self.models_mu_t[group].fit(X_treat, y_treat)
562573

563-
self.vars_c[group] = var_c
564-
var_t = (
574+
self.vars_t[group] = (
565575
y_treat - self.models_mu_t[group].predict_proba(X_treat)[:, 1]
566576
).var()
567-
self.vars_t[group] = var_t
568577

569578
# Train treatment effect models using cross-group imputation
570579
d_c = (

0 commit comments

Comments
 (0)