
Commit 08c906f

Authored Mar 3, 2021
Adjust ate inference to get the exact stderr when final stage is linear. (#418)
* Adjust the ATE inference to use the exact stderr when possible
1 parent 2cc9f62 commit 08c906f
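
Background for this change (an editorial sketch, not part of the commit): when the final stage is a linear model with coefficient covariance V, the mean effect over rows x_i is x_bar' beta_hat, whose standard error is exactly sqrt(x_bar' V x_bar). The previous summary instead reported sqrt(mean_i x_i' V x_i), which by Jensen's inequality is only a conservative upper bound. A minimal numpy illustration with made-up numbers:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 3))        # rows at which effects are predicted
    V = np.linalg.inv(X.T @ X)           # stand-in for the final model's coefficient covariance

    x_bar = X.mean(axis=0)
    exact = np.sqrt(x_bar @ V @ x_bar)   # stderr of the mean prediction
    row_vars = np.einsum('ij,jk,ik->i', X, V, X)
    conservative = np.sqrt(row_vars.mean())  # the old upper bound

    print(exact, conservative)           # exact is never larger than conservative
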

File tree: 6 files changed, +126 -39 lines
 

econml/dml/causal_forest.py

Lines changed: 2 additions & 2 deletions
@@ -79,7 +79,7 @@ def const_marginal_effect_inference(self, X):
         pred = pred.reshape((-1,) + self._d_y + self._d_t)
         pred_stderr = np.sqrt(np.diagonal(pred_var, axis1=2, axis2=3).reshape((-1,) + self._d_y + self._d_t))
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr, inf_type='effect')
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect')
 
     def effect_interval(self, X, *, T0, T1, alpha=0.1):
         return self.effect_inference(X, T0=T0, T1=T1).conf_int(alpha=alpha)
@@ -97,7 +97,7 @@ def effect_inference(self, X, *, T0, T1):
         pred = pred.reshape((-1,) + self._d_y)
         pred_stderr = np.sqrt(pred_var.reshape((-1,) + self._d_y))
         return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr, inf_type='effect')
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect')
 
 
 class CausalForestDML(_BaseDML):

econml/inference/_bootstrap.py

Lines changed: 1 addition & 1 deletion
@@ -264,7 +264,7 @@ def normal_inference(*args, **kwargs):
                 stderr = stderr(*args, **kwargs)
             return NormalInferenceResults(
                 d_t=d_t, d_y=d_y, pred=pred,
-                pred_stderr=stderr, inf_type=inf_type,
+                pred_stderr=stderr, mean_pred_stderr=None, inf_type=inf_type,
                 fname_transformer=fname_transformer,
                 feature_names=self._wrapped.cate_feature_names(),
                 output_names=self._wrapped.cate_output_names(),

econml/inference/_inference.py

Lines changed: 95 additions & 31 deletions
@@ -147,7 +147,7 @@ def const_marginal_effect_inference(self, X):
             warn("Final model doesn't have a `prediction_stderr` method, "
                  "only point estimates will be returned.")
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr, inf_type='effect',
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
                                       treatment_names=self._est.cate_treatment_names())
@@ -193,9 +193,10 @@ def effect_inference(self, X, *, T0, T1):
         e_pred = np.einsum(einsum_str, cme_pred, dT)
         e_stderr = np.einsum(einsum_str, cme_stderr, np.abs(dT)) if cme_stderr is not None else None
         d_y = self._d_y[0] if self._d_y else 1
+
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=d_y, pred=e_pred,
-                                      pred_stderr=e_stderr, inf_type='effect',
+                                      pred_stderr=e_stderr, mean_pred_stderr=None, inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
 
@@ -240,15 +241,38 @@ def effect_inference(self, X, *, T0, T1):
             X = np.ones((T0.shape[0], 1))
         elif self.featurizer is not None:
             X = self.featurizer.transform(X)
-        e_pred = self._predict(cross_product(X, T1 - T0))
-        e_stderr = self._prediction_stderr(cross_product(X, T1 - T0))
+        XT = cross_product(X, T1 - T0)
+        e_pred = self._predict(XT)
+        e_stderr = self._prediction_stderr(XT)
         d_y = self._d_y[0] if self._d_y else 1
+
+        mean_XT = XT.mean(axis=0, keepdims=True)
+        mean_pred_stderr = self._prediction_stderr(mean_XT)  # shape[0] will always be 1 here
+        # squeeze the first axis
+        mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0) if mean_pred_stderr is not None else None
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=d_y, pred=e_pred,
-                                      pred_stderr=e_stderr, inf_type='effect',
+                                      pred_stderr=e_stderr, mean_pred_stderr=mean_pred_stderr, inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
 
+    def const_marginal_effect_inference(self, X):
+        inf_res = super().const_marginal_effect_inference(X)
+
+        # set the mean_pred_stderr
+        if X is None:
+            X = np.ones((1, 1))
+        elif self.featurizer is not None:
+            X = self.featurizer.transform(X)
+        X_mean, T_mean = broadcast_unit_treatments(X.mean(axis=0).reshape(1, -1), self.d_t)
+        mean_XT = cross_product(X_mean, T_mean)
+        mean_pred_stderr = self._prediction_stderr(mean_XT)
+        if mean_pred_stderr is not None:
+            mean_pred_stderr = reshape_treatmentwise_effects(mean_pred_stderr,
+                                                             self._d_t, self._d_y)  # shape[0] will always be 1 here
+            inf_res.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
+        return inf_res
+
     def coef__interval(self, *, alpha=0.1):
         lo, hi = self.model_final.coef__interval(alpha)
         lo_int, hi_int = self.model_final.intercept__interval(alpha)
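
The key step above is evaluating `_prediction_stderr` at the single averaged row `mean_XT`: for a linear final model the mean of the predictions equals the prediction at the mean input, so the stderr of that one prediction is the exact stderr of the ATE. A rough stand-in for what a linear model's `prediction_stderr` computes (hypothetical helper, shown only to make the identity concrete):

    import numpy as np

    def prediction_stderr(XT, V):
        # hypothetical: per-row stderr of XT @ beta when Cov(beta) = V
        return np.sqrt(np.einsum('ij,jk,ik->i', XT, V, XT))

    rng = np.random.default_rng(1)
    XT = rng.normal(size=(200, 4))
    V = np.diag(rng.uniform(0.1, 1.0, size=4))

    mean_XT = XT.mean(axis=0, keepdims=True)    # same reduction as in the diff
    ate_stderr = prediction_stderr(mean_XT, V)  # shape (1,), squeezed as in the diff

Non-linear final stages have no such identity, which is why the other call sites in this commit pass mean_pred_stderr=None.
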
@@ -285,6 +309,7 @@ def coef__inference(self):
             fname_transformer = self._est.cate_feature_names
 
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=coef, pred_stderr=coef_stderr,
+                                      mean_pred_stderr=None,
                                       inf_type='coefficient', fname_transformer=fname_transformer,
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
@@ -323,6 +348,7 @@ def intercept__inference(self):
             intercept_stderr = None
 
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=intercept, pred_stderr=intercept_stderr,
+                                      mean_pred_stderr=None,
                                       inf_type='intercept',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
@@ -380,11 +406,7 @@ def fit(self, estimator, *args, **kwargs):
         self.fit_cate_intercept = estimator.fit_cate_intercept
 
     def const_marginal_effect_interval(self, X, *, alpha=0.1):
-        if (X is not None) and (self.featurizer is not None):
-            X = self.featurizer.transform(X)
-        preds = np.array([tuple(map(lambda x: x.reshape((-1,) + self._d_y), mdl.predict_interval(X, alpha=alpha)))
-                          for mdl in self.fitted_models_final])
-        return tuple(np.moveaxis(preds, [0, 1], [-1, 0]))  # send treatment to the end, pull bounds to the front
+        return self.const_marginal_effect_inference(X).conf_int(alpha=alpha)
 
     def const_marginal_effect_inference(self, X):
         if (X is not None) and (self.featurizer is not None):
@@ -401,22 +423,14 @@ def const_marginal_effect_inference(self, X):
                  "Only point estimates will be available.")
             pred_stderr = None
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr, inf_type='effect',
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None,
+                                      inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
                                       treatment_names=self._est.cate_treatment_names())
 
     def effect_interval(self, X, *, T0, T1, alpha=0.1):
-        X, T0, T1 = self._est._expand_treatments(X, T0, T1)
-        if np.any(np.any(T0 > 0, axis=1)):
-            raise AttributeError("Can only calculate intervals of effects with respect to baseline treatment!")
-        ind = inverse_onehot(T1)
-        lower, upper = self.const_marginal_effect_interval(X, alpha=alpha)
-        lower = np.concatenate([np.zeros(lower.shape[0:-1] + (1,)), lower], -1)
-        upper = np.concatenate([np.zeros(upper.shape[0:-1] + (1,)), upper], -1)
-        if X is None:  # Then const_marginal_effect_interval will return a single row
-            lower, upper = np.repeat(lower, T0.shape[0], axis=0), np.repeat(upper, T0.shape[0], axis=0)
-        return lower[np.arange(T0.shape[0]), ..., ind], upper[np.arange(T0.shape[0]), ..., ind]
+        return self.effect_inference(X, T0=T0, T1=T1).conf_int(alpha=alpha)
 
     def effect_inference(self, X, *, T0, T1):
         X, T0, T1 = self._est._expand_treatments(X, T0, T1)
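
Note the design choice in the two `*_interval` methods above: rather than recomputing intervals per fitted model, they now delegate to the matching `*_inference` result's `conf_int`, so intervals automatically pick up any stderr refinements. Schematically (hypothetical names, not the econml API):

    from scipy.stats import norm

    class FakeInference:
        def __init__(self, pred, stderr):
            self.pred, self.stderr = pred, stderr

        def conf_int(self, alpha=0.1):
            # normal-approximation interval, as NormalInferenceResults builds it
            z = norm.ppf(1 - alpha / 2)
            return self.pred - z * self.stderr, self.pred + z * self.stderr

    def effect_interval(inf, alpha=0.1):
        # the refactored methods collapse to this one-liner
        return inf.conf_int(alpha=alpha)
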
@@ -434,9 +448,10 @@ def effect_inference(self, X, *, T0, T1):
         pred_stderr = np.repeat(pred_stderr, T0.shape[0], axis=0) if pred_stderr is not None else None
         pred = pred[np.arange(T0.shape[0]), ..., ind]
         pred_stderr = pred_stderr[np.arange(T0.shape[0]), ..., ind] if pred_stderr is not None else None
+
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr,
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None,
                                       inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
@@ -449,6 +464,33 @@ class LinearModelFinalInferenceDiscrete(GenericModelFinalInferenceDiscrete):
     based on the corresponding methods of the underlying model_final estimator.
     """
 
+    def const_marginal_effect_inference(self, X):
+        res_inf = super().const_marginal_effect_inference(X)
+
+        # set the mean_pred_stderr
+        if (X is not None) and (self.featurizer is not None):
+            X = self.featurizer.transform(X)
+
+        if hasattr(self.fitted_models_final[0], 'prediction_stderr'):
+            mean_X = X.mean(axis=0).reshape(1, -1) if X is not None else None
+            mean_pred_stderr = np.moveaxis(np.array([mdl.prediction_stderr(mean_X).reshape((-1,) + self._d_y)
+                                                     for mdl in self.fitted_models_final]),
+                                           0, -1)  # shape[0] will always be 1 here
+            res_inf.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
+        return res_inf
+
+    def effect_inference(self, X, *, T0, T1):
+        res_inf = super().effect_inference(X, T0=T0, T1=T1)
+
+        # replace the mean_pred_stderr if T1 and T0 are a constant or a constant vector
+        _, _, T1 = self._est._expand_treatments(X, T0, T1)
+        ind = inverse_onehot(T1)
+        if len(set(ind)) == 1:
+            unique_ind = ind[0] - 1
+            mean_pred_stderr = self.const_marginal_effect_inference(X).mean_pred_stderr[..., unique_ind]
+            res_inf.mean_pred_stderr = mean_pred_stderr
+        return res_inf
+
     def coef__interval(self, T, *, alpha=0.1):
         _, T = self._est._expand_treatments(None, T)
         ind = inverse_onehot(T).item() - 1
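
With discrete treatments there is one fitted final model per treatment arm, so an exact mean stderr only makes sense when every row of T1 encodes the same arm; the `len(set(ind)) == 1` check guards that. A rough numpy stand-in for `inverse_onehot` (the real econml helper may differ in details) shows what `ind` holds:

    import numpy as np

    def inverse_onehot(T):
        # hypothetical stand-in: T one-hot encodes non-baseline arms,
        # an all-zero row means baseline; returns 0 for baseline, 1..k otherwise
        return (T @ np.arange(1, T.shape[1] + 1)).astype(int)

    T1 = np.array([[0, 1], [0, 1], [0, 1]])  # every unit assigned arm 2
    ind = inverse_onehot(T1)                 # array([2, 2, 2])
    assert len(set(ind)) == 1                # constant arm: exact stderr applies
    unique_ind = ind[0] - 1                  # index into the per-arm stderr axis
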
@@ -472,8 +514,10 @@ def coef__inference(self, T):
             fname_transformer = None
         if hasattr(self._est, 'cate_feature_names') and callable(self._est.cate_feature_names):
             fname_transformer = self._est.cate_feature_names
+
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=coef, pred_stderr=coef_stderr,
+                                      mean_pred_stderr=None,
                                       inf_type='coefficient', fname_transformer=fname_transformer,
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
@@ -500,7 +544,7 @@ def intercept__inference(self, T):
             intercept_stderr = None
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=self.fitted_models_final[ind].intercept_,
-                                      pred_stderr=intercept_stderr,
+                                      pred_stderr=intercept_stderr, mean_pred_stderr=None,
                                       inf_type='intercept',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
@@ -748,7 +792,6 @@ def summary_frame(self, alpha=0.1, value=0, decimals=3,
 
         elif self.inf_type == 'intercept':
             res.index = res.index.set_levels(['cate_intercept'], level="X")
-
         if self._d_t == 1:
             res.index = res.index.droplevel("T")
         if self.d_y == 1:
@@ -786,6 +829,7 @@ def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_n
         output_names = self.output_names if output_names is None else output_names
         if self.inf_type == 'effect':
             return PopulationSummaryResults(pred=self.point_estimate, pred_stderr=self.stderr,
+                                            mean_pred_stderr=None,
                                             d_t=self.d_t, d_y=self.d_y,
                                             alpha=alpha, value=value, decimals=decimals, tol=tol,
                                             output_names=output_names, treatment_names=treatment_names)
@@ -839,17 +883,22 @@ class NormalInferenceResults(InferenceResults):
         Note that when Y or T is a vector rather than a 2-dimensional array,
         the corresponding singleton dimensions should be collapsed
         (e.g. if both are vectors, then the input of this argument will also be a vector)
+    mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
+        The standard error of the mean point estimate; this is derived from the coefficient stderr when the
+        final stage is a linear model, and is None otherwise.
+        This is the exact standard error of the mean, which is not conservative.
     inf_type: string
         The type of inference result.
         It could be either 'effect', 'coefficient' or 'intercept'.
     fname_transformer: None or predefined function
         The transform function to get the corresponding feature names from featurizer
     """
 
-    def __init__(self, d_t, d_y, pred, pred_stderr, inf_type, fname_transformer=None,
+    def __init__(self, d_t, d_y, pred, pred_stderr, mean_pred_stderr, inf_type, fname_transformer=None,
                  feature_names=None, output_names=None, treatment_names=None):
         self.pred_stderr = np.copy(pred_stderr) if pred_stderr is not None and not np.isscalar(
             pred_stderr) else pred_stderr
+        self.mean_pred_stderr = mean_pred_stderr
         super().__init__(d_t, d_y, pred, inf_type, fname_transformer, feature_names, output_names, treatment_names)
 
     @property
@@ -915,11 +964,20 @@ def pvalue(self, value=0):
         """
         return norm.sf(np.abs(self.zstat(value)), loc=0, scale=1) * 2
 
+    def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
+        pop_summ = super().population_summary(alpha=alpha, value=value, decimals=decimals,
+                                              tol=tol, output_names=output_names, treatment_names=treatment_names)
+        pop_summ.mean_pred_stderr = self.mean_pred_stderr
+        return pop_summ
+    population_summary.__doc__ = InferenceResults.population_summary.__doc__
+
     def _expand_outputs(self, n_rows):
         assert shape(self.pred)[0] == shape(self.pred_stderr)[0] == 1
         pred = np.repeat(self.pred, n_rows, axis=0)
         pred_stderr = np.repeat(self.pred_stderr, n_rows, axis=0) if self.pred_stderr is not None else None
-        return NormalInferenceResults(self.d_t, self.d_y, pred, pred_stderr, self.inf_type,
+        return NormalInferenceResults(self.d_t, self.d_y, pred, pred_stderr,
+                                      self.mean_pred_stderr,
+                                      self.inf_type,
                                       self.fname_transformer, self.feature_names,
                                       self.output_names, self.treatment_names)
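
Taken together, `population_summary` now carries the exact stderr through to the summary object. A hypothetical end-to-end sketch (econml class names as of this release; data made up):

    import numpy as np
    from econml.dml import LinearDML

    rng = np.random.default_rng(2)
    X = rng.normal(size=(1000, 3))
    T = rng.binomial(1, 0.5, size=1000)
    y = 2.0 * T + X[:, 0] + rng.normal(size=1000)

    est = LinearDML(discrete_treatment=True)
    est.fit(y, T, X=X)

    # stderr_mean below now comes from the linear final stage's coefficient
    # covariance, and the "conservative upper bound" footnote is dropped
    print(est.effect_inference(X).population_summary())
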

@@ -1039,6 +1097,10 @@ class PopulationSummaryResults:
         Note that when Y or T is a vector rather than a 2-dimensional array,
         the corresponding singleton dimensions should be collapsed
         (e.g. if both are vectors, then the input of this argument will also be a vector)
+    mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
+        The standard error of the mean point estimate; this is derived from the coefficient stderr when the
+        final stage is a linear model, and is None otherwise.
+        This is the exact standard error of the mean, which is not conservative.
     alpha: optional float in [0, 1] (default=0.1)
         The overall level of confidence of the reported interval.
         The alpha/2, 1-alpha/2 confidence interval is reported.
@@ -1055,10 +1117,11 @@
 
     """
 
-    def __init__(self, pred, pred_stderr, d_t, d_y, alpha, value, decimals, tol,
+    def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha, value, decimals, tol,
                  output_names=None, treatment_names=None):
         self.pred = pred
         self.pred_stderr = pred_stderr
+        self.mean_pred_stderr = mean_pred_stderr
         self.d_t = d_t
         # For effect summaries, d_t is None, but the result arrays behave as if d_t=1
         self._d_t = d_t or 1
@@ -1106,7 +1169,9 @@ def stderr_mean(self):
         the corresponding singleton dimensions in the output will be collapsed
         (e.g. if both are vectors, then the output of this method will be a scalar)
         """
-        if self.pred_stderr is None:
+        if self.mean_pred_stderr is not None:
+            return self.mean_pred_stderr
+        elif self.pred_stderr is None:
             raise AttributeError("Only point estimates are available!")
         return np.sqrt(np.mean(self.pred_stderr**2, axis=0))
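
The precedence in `stderr_mean` is: use the exact value whenever the final stage supplied one, otherwise fall back to the historical conservative bound sqrt(mean(stderr^2)). A minimal mirror of that branch order (illustrative only):

    import numpy as np

    def stderr_mean(pred_stderr, mean_pred_stderr=None):
        # mirrors the branch order in the hunk above
        if mean_pred_stderr is not None:
            return mean_pred_stderr       # exact, from the linear final stage
        if pred_stderr is None:
            raise AttributeError("Only point estimates are available!")
        return np.sqrt(np.mean(pred_stderr ** 2, axis=0))  # conservative fallback

    per_row = np.array([0.5, 0.7, 0.9])
    print(stderr_mean(per_row))           # ~0.719, the conservative bound
    print(stderr_mean(per_row, 0.31))     # exact value wins when available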

@@ -1312,13 +1377,13 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name
                            self._format_res(self.pvalue(value=value), decimals),
                            self._format_res(self.conf_int_mean(alpha=alpha)[0], decimals),
                            self._format_res(self.conf_int_mean(alpha=alpha)[1], decimals)))
-
         if treatment_names is None:
             treatment_names = ['T' + str(i) for i in range(self._d_t)]
         if output_names is None:
             output_names = ['Y' + str(i) for i in range(self.d_y)]
 
         myheaders1 = ['mean_point', 'stderr_mean', 'zstat', 'pvalue', 'ci_mean_lower', 'ci_mean_upper']
+
         mystubs = self._get_stub_names(self.d_y, self._d_t, treatment_names, output_names)
         title1 = "Uncertainty of Mean Point Estimate"
 
@@ -1331,13 +1396,12 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name
 
         smry = Summary()
         smry.add_table(res1, myheaders1, mystubs, title1)
-        if self.pred_stderr is not None:
+        if self.pred_stderr is not None and self.mean_pred_stderr is None:
             text1 = "Note: The stderr_mean is a conservative upper bound."
             smry.add_extra_txt([text1])
         smry.add_table(res2, myheaders2, mystubs, title2)
 
         if self.pred_stderr is not None:
-
             # 3. Total Variance of Point Estimate
             res3 = np.hstack((self._format_res(self.stderr_point, self.decimals),
                               self._format_res(self.conf_int_point(alpha=alpha, tol=tol)[0],
(The remainder of the diff, covering the other three changed files, did not load.)
