@@ -147,7 +147,7 @@ def const_marginal_effect_inference(self, X):
             warn("Final model doesn't have a `prediction_stderr` method, "
                  "only point estimates will be returned.")
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr, inf_type='effect',
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
                                       treatment_names=self._est.cate_treatment_names())
@@ -193,9 +193,10 @@ def effect_inference(self, X, *, T0, T1):
         e_pred = np.einsum(einsum_str, cme_pred, dT)
         e_stderr = np.einsum(einsum_str, cme_stderr, np.abs(dT)) if cme_stderr is not None else None
         d_y = self._d_y[0] if self._d_y else 1
+
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=d_y, pred=e_pred,
-                                      pred_stderr=e_stderr, inf_type='effect',
+                                      pred_stderr=e_stderr, mean_pred_stderr=None, inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
 
@@ -240,15 +241,38 @@ def effect_inference(self, X, *, T0, T1):
             X = np.ones((T0.shape[0], 1))
         elif self.featurizer is not None:
             X = self.featurizer.transform(X)
-        e_pred = self._predict(cross_product(X, T1 - T0))
-        e_stderr = self._prediction_stderr(cross_product(X, T1 - T0))
+        XT = cross_product(X, T1 - T0)
+        e_pred = self._predict(XT)
+        e_stderr = self._prediction_stderr(XT)
         d_y = self._d_y[0] if self._d_y else 1
+
+        mean_XT = XT.mean(axis=0, keepdims=True)
+        mean_pred_stderr = self._prediction_stderr(mean_XT)  # shape[0] will always be 1 here
+        # squeeze the first axis
+        mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0) if mean_pred_stderr is not None else None
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=d_y, pred=e_pred,
-                                      pred_stderr=e_stderr, inf_type='effect',
+                                      pred_stderr=e_stderr, mean_pred_stderr=mean_pred_stderr, inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
 
+    def const_marginal_effect_inference(self, X):
+        inf_res = super().const_marginal_effect_inference(X)
+
+        # set the mean_pred_stderr
+        if X is None:
+            X = np.ones((1, 1))
+        elif self.featurizer is not None:
+            X = self.featurizer.transform(X)
+        X_mean, T_mean = broadcast_unit_treatments(X.mean(axis=0).reshape(1, -1), self.d_t)
+        mean_XT = cross_product(X_mean, T_mean)
+        mean_pred_stderr = self._prediction_stderr(mean_XT)
+        if mean_pred_stderr is not None:
+            mean_pred_stderr = reshape_treatmentwise_effects(mean_pred_stderr,
+                                                             self._d_t, self._d_y)  # shape[0] will always be 1 here
+            inf_res.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
+        return inf_res
+
     def coef__interval(self, *, alpha=0.1):
         lo, hi = self.model_final.coef__interval(alpha)
         lo_int, hi_int = self.model_final.intercept__interval(alpha)
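Why predicting at the averaged features gives the exact standard error of the mean: for a linear final model the average of the per-unit predictions equals the prediction at the mean input (the mean over i of x_i'beta-hat is x-bar'beta-hat), so `_prediction_stderr(mean_XT)` returns sqrt(x-bar' Cov(beta-hat) x-bar) with no conservative slack. A minimal self-contained sketch on toy OLS data (not econml code; names like `cov_beta` are illustrative) confirming the exact value never exceeds the sqrt-of-mean-variances fallback:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=200)

XtX_inv = np.linalg.inv(X.T @ X)
beta = XtX_inv @ (X.T @ y)
resid = y - X @ beta
sigma2 = resid @ resid / (X.shape[0] - X.shape[1])
cov_beta = sigma2 * XtX_inv                                # Cov(beta_hat) under OLS

per_point_var = np.einsum('ij,jk,ik->i', X, cov_beta, X)   # Var(x_i' beta_hat) for each row
conservative = np.sqrt(per_point_var.mean())               # sqrt(mean(stderr_i ** 2))

x_bar = X.mean(axis=0)
exact = np.sqrt(x_bar @ cov_beta @ x_bar)                  # stderr of the mean prediction

assert exact <= conservative  # Cauchy-Schwarz: the fallback is an upper bound
```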
@@ -285,6 +309,7 @@ def coef__inference(self):
             fname_transformer = self._est.cate_feature_names
 
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=coef, pred_stderr=coef_stderr,
+                                      mean_pred_stderr=None,
                                       inf_type='coefficient', fname_transformer=fname_transformer,
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
@@ -323,6 +348,7 @@ def intercept__inference(self):
             intercept_stderr = None
 
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=intercept, pred_stderr=intercept_stderr,
+                                      mean_pred_stderr=None,
                                       inf_type='intercept',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
@@ -380,11 +406,7 @@ def fit(self, estimator, *args, **kwargs):
         self.fit_cate_intercept = estimator.fit_cate_intercept
 
     def const_marginal_effect_interval(self, X, *, alpha=0.1):
-        if (X is not None) and (self.featurizer is not None):
-            X = self.featurizer.transform(X)
-        preds = np.array([tuple(map(lambda x: x.reshape((-1,) + self._d_y), mdl.predict_interval(X, alpha=alpha)))
-                          for mdl in self.fitted_models_final])
-        return tuple(np.moveaxis(preds, [0, 1], [-1, 0]))  # send treatment to the end, pull bounds to the front
+        return self.const_marginal_effect_inference(X).conf_int(alpha=alpha)
 
     def const_marginal_effect_inference(self, X):
         if (X is not None) and (self.featurizer is not None):
@@ -401,22 +423,14 @@ def const_marginal_effect_inference(self, X):
                  "Only point estimates will be available.")
             pred_stderr = None
         return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr, inf_type='effect',
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None,
+                                      inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names(),
                                       treatment_names=self._est.cate_treatment_names())
 
     def effect_interval(self, X, *, T0, T1, alpha=0.1):
-        X, T0, T1 = self._est._expand_treatments(X, T0, T1)
-        if np.any(np.any(T0 > 0, axis=1)):
-            raise AttributeError("Can only calculate intervals of effects with respect to baseline treatment!")
-        ind = inverse_onehot(T1)
-        lower, upper = self.const_marginal_effect_interval(X, alpha=alpha)
-        lower = np.concatenate([np.zeros(lower.shape[0:-1] + (1,)), lower], -1)
-        upper = np.concatenate([np.zeros(upper.shape[0:-1] + (1,)), upper], -1)
-        if X is None:  # Then const_marginal_effect_interval will return a single row
-            lower, upper = np.repeat(lower, T0.shape[0], axis=0), np.repeat(upper, T0.shape[0], axis=0)
-        return lower[np.arange(T0.shape[0]), ..., ind], upper[np.arange(T0.shape[0]), ..., ind]
+        return self.effect_inference(X, T0=T0, T1=T1).conf_int(alpha=alpha)
 
     def effect_inference(self, X, *, T0, T1):
         X, T0, T1 = self._est._expand_treatments(X, T0, T1)
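Both hand-rolled interval methods above are replaced by delegation to the corresponding inference object. Under the normal approximation those objects encode, `conf_int(alpha)` amounts to pred ± z_{1-alpha/2} · stderr elementwise; a hedged sketch of that computation (a standalone helper assumed to mirror what `NormalInferenceResults.conf_int` does, not the method itself):

```python
import numpy as np
from scipy.stats import norm

def normal_conf_int(pred, pred_stderr, alpha=0.1):
    # Elementwise two-sided (1 - alpha) interval of a normally distributed estimate.
    z = norm.ppf(1 - alpha / 2)
    return pred - z * pred_stderr, pred + z * pred_stderr

lo, hi = normal_conf_int(np.array([0.5, 0.7]), np.array([0.1, 0.2]))  # 90% interval
```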
@@ -434,9 +448,10 @@ def effect_inference(self, X, *, T0, T1):
         pred_stderr = np.repeat(pred_stderr, T0.shape[0], axis=0) if pred_stderr is not None else None
         pred = pred[np.arange(T0.shape[0]), ..., ind]
         pred_stderr = pred_stderr[np.arange(T0.shape[0]), ..., ind] if pred_stderr is not None else None
+
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=pred,
-                                      pred_stderr=pred_stderr,
+                                      pred_stderr=pred_stderr, mean_pred_stderr=None,
                                       inf_type='effect',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
@@ -449,6 +464,33 @@ class LinearModelFinalInferenceDiscrete(GenericModelFinalInferenceDiscrete):
     based on the corresponding methods of the underlying model_final estimator.
     """
 
+    def const_marginal_effect_inference(self, X):
+        res_inf = super().const_marginal_effect_inference(X)
+
+        # set the mean_pred_stderr
+        if (X is not None) and (self.featurizer is not None):
+            X = self.featurizer.transform(X)
+
+        if hasattr(self.fitted_models_final[0], 'prediction_stderr'):
+            mean_X = X.mean(axis=0).reshape(1, -1) if X is not None else None
+            mean_pred_stderr = np.moveaxis(np.array([mdl.prediction_stderr(mean_X).reshape((-1,) + self._d_y)
+                                                     for mdl in self.fitted_models_final]),
+                                           0, -1)  # shape[0] will always be 1 here
+            res_inf.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
+        return res_inf
+
+    def effect_inference(self, X, *, T0, T1):
+        res_inf = super().effect_inference(X, T0=T0, T1=T1)
+
+        # replace the mean_pred_stderr if T0 and T1 assign the same constant treatment to every row
+        _, _, T1 = self._est._expand_treatments(X, T0, T1)
+        ind = inverse_onehot(T1)
+        if len(set(ind)) == 1:
+            unique_ind = ind[0] - 1
+            mean_pred_stderr = self.const_marginal_effect_inference(X).mean_pred_stderr[..., unique_ind]
+            res_inf.mean_pred_stderr = mean_pred_stderr
+        return res_inf
+
     def coef__interval(self, T, *, alpha=0.1):
         _, T = self._est._expand_treatments(None, T)
         ind = inverse_onehot(T).item() - 1
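The discrete case fits one final model per non-baseline treatment, and each `prediction_stderr(mean_X)` call returns a `(1, d_y)` array for the single mean-feature row; the `np.array`/`np.moveaxis`/`np.squeeze` chain above only rearranges those into the `(d_y, d_t)` layout the results object expects. A shape-only sketch with hypothetical sizes:

```python
import numpy as np

d_y, d_t = 2, 3  # hypothetical output and treatment dimensions
rng = np.random.default_rng(1)
per_model = [np.abs(rng.normal(size=(1, d_y))) for _ in range(d_t)]  # one (1, d_y) stderr per model

stacked = np.array(per_model)                 # (d_t, 1, d_y)
moved = np.moveaxis(stacked, 0, -1)           # (1, d_y, d_t): treatment axis sent to the end
mean_pred_stderr = np.squeeze(moved, axis=0)  # (d_y, d_t)
assert mean_pred_stderr.shape == (d_y, d_t)
```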
@@ -472,8 +514,10 @@ def coef__inference(self, T):
         fname_transformer = None
         if hasattr(self._est, 'cate_feature_names') and callable(self._est.cate_feature_names):
             fname_transformer = self._est.cate_feature_names
+
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=coef, pred_stderr=coef_stderr,
+                                      mean_pred_stderr=None,
                                       inf_type='coefficient', fname_transformer=fname_transformer,
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
@@ -500,7 +544,7 @@ def intercept__inference(self, T):
             intercept_stderr = None
         # d_t=None here since we measure the effect across all Ts
         return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=self.fitted_models_final[ind].intercept_,
-                                      pred_stderr=intercept_stderr,
+                                      pred_stderr=intercept_stderr, mean_pred_stderr=None,
                                       inf_type='intercept',
                                       feature_names=self._est.cate_feature_names(),
                                       output_names=self._est.cate_output_names())
@@ -748,7 +792,6 @@ def summary_frame(self, alpha=0.1, value=0, decimals=3,
 
         elif self.inf_type == 'intercept':
             res.index = res.index.set_levels(['cate_intercept'], level="X")
-
         if self._d_t == 1:
             res.index = res.index.droplevel("T")
         if self.d_y == 1:
@@ -786,6 +829,7 @@ def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_n
         output_names = self.output_names if output_names is None else output_names
         if self.inf_type == 'effect':
             return PopulationSummaryResults(pred=self.point_estimate, pred_stderr=self.stderr,
+                                            mean_pred_stderr=None,
                                             d_t=self.d_t, d_y=self.d_y,
                                             alpha=alpha, value=value, decimals=decimals, tol=tol,
                                             output_names=output_names, treatment_names=treatment_names)
@@ -839,17 +883,22 @@ class NormalInferenceResults(InferenceResults):
         Note that when Y or T is a vector rather than a 2-dimensional array,
         the corresponding singleton dimensions should be collapsed
         (e.g. if both are vectors, then the input of this argument will also be a vector)
+    mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
+        The standard error of the mean point estimate; this is derived from the coefficient stderr
+        when the final stage is a linear model, and is None otherwise.
+        This is the exact standard error of the mean, which is not conservative.
     inf_type: string
         The type of inference result.
         It could be either 'effect', 'coefficient' or 'intercept'.
     fname_transformer: None or predefined function
         The transform function to get the corresponding feature names from featurizer
     """
 
-    def __init__(self, d_t, d_y, pred, pred_stderr, inf_type, fname_transformer=None,
+    def __init__(self, d_t, d_y, pred, pred_stderr, mean_pred_stderr, inf_type, fname_transformer=None,
                  feature_names=None, output_names=None, treatment_names=None):
         self.pred_stderr = np.copy(pred_stderr) if pred_stderr is not None and not np.isscalar(
             pred_stderr) else pred_stderr
+        self.mean_pred_stderr = mean_pred_stderr
         super().__init__(d_t, d_y, pred, inf_type, fname_transformer, feature_names, output_names, treatment_names)
 
     @property
@@ -915,11 +964,20 @@ def pvalue(self, value=0):
         """
         return norm.sf(np.abs(self.zstat(value)), loc=0, scale=1) * 2
 
+    def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
+        pop_summ = super().population_summary(alpha=alpha, value=value, decimals=decimals,
+                                              tol=tol, output_names=output_names, treatment_names=treatment_names)
+        pop_summ.mean_pred_stderr = self.mean_pred_stderr
+        return pop_summ
+    population_summary.__doc__ = InferenceResults.population_summary.__doc__
+
     def _expand_outputs(self, n_rows):
         assert shape(self.pred)[0] == shape(self.pred_stderr)[0] == 1
         pred = np.repeat(self.pred, n_rows, axis=0)
         pred_stderr = np.repeat(self.pred_stderr, n_rows, axis=0) if self.pred_stderr is not None else None
-        return NormalInferenceResults(self.d_t, self.d_y, pred, pred_stderr, self.inf_type,
+        return NormalInferenceResults(self.d_t, self.d_y, pred, pred_stderr,
+                                      self.mean_pred_stderr,
+                                      self.inf_type,
                                       self.fname_transformer, self.feature_names,
                                       self.output_names, self.treatment_names)
 
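Taken together with the `stderr_mean` change further below, this override is what lets an exact mean stderr flow through to `population_summary()`. A hedged end-to-end sketch with toy numbers (`NormalInferenceResults` is econml-internal; constructor behavior beyond what this diff shows is assumed):

```python
import numpy as np

# 0.05 stands in for the exact stderr a linear final model derives from its
# coefficient covariance; the per-unit stderrs alone would give sqrt(mean(0.2 ** 2)) = 0.2.
res = NormalInferenceResults(d_t=1, d_y=1,
                             pred=np.array([0.5, 0.7, 0.9]),
                             pred_stderr=np.array([0.2, 0.2, 0.2]),
                             mean_pred_stderr=0.05, inf_type='effect')
summ = res.population_summary()
assert summ.stderr_mean == 0.05  # exact value wins over the conservative fallback
```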
@@ -1039,6 +1097,10 @@ class PopulationSummaryResults:
         Note that when Y or T is a vector rather than a 2-dimensional array,
         the corresponding singleton dimensions should be collapsed
         (e.g. if both are vectors, then the input of this argument will also be a vector)
+    mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
+        The standard error of the mean point estimate; this is derived from the coefficient stderr
+        when the final stage is a linear model, and is None otherwise.
+        This is the exact standard error of the mean, which is not conservative.
     alpha: optional float in [0, 1] (default=0.1)
         The overall level of confidence of the reported interval.
         The alpha/2, 1-alpha/2 confidence interval is reported.
@@ -1055,10 +1117,11 @@ class PopulationSummaryResults:
 
     """
 
-    def __init__(self, pred, pred_stderr, d_t, d_y, alpha, value, decimals, tol,
+    def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha, value, decimals, tol,
                  output_names=None, treatment_names=None):
         self.pred = pred
         self.pred_stderr = pred_stderr
+        self.mean_pred_stderr = mean_pred_stderr
         self.d_t = d_t
         # For effect summaries, d_t is None, but the result arrays behave as if d_t=1
         self._d_t = d_t or 1
@@ -1106,7 +1169,9 @@ def stderr_mean(self):
         the corresponding singleton dimensions in the output will be collapsed
         (e.g. if both are vectors, then the output of this method will be a scalar)
         """
-        if self.pred_stderr is None:
+        if self.mean_pred_stderr is not None:
+            return self.mean_pred_stderr
+        elif self.pred_stderr is None:
             raise AttributeError("Only point estimates are available!")
         return np.sqrt(np.mean(self.pred_stderr ** 2, axis=0))
 
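Why the fallback on the last line is a conservative upper bound: writing sigma_i for the i-th entry of `pred_stderr`, two applications of Cauchy-Schwarz give

```latex
\operatorname{Var}\Bigl(\tfrac{1}{n}\sum_i \hat y_i\Bigr)
  = \tfrac{1}{n^2}\sum_{i,j}\operatorname{Cov}(\hat y_i, \hat y_j)
  \le \tfrac{1}{n^2}\Bigl(\sum_i \sigma_i\Bigr)^2
  \le \tfrac{1}{n}\sum_i \sigma_i^2,
```

so `sqrt(mean(pred_stderr ** 2))` can only overstate the true stderr of the mean; `mean_pred_stderr`, when a linear final model supplies it, is the exact left-hand side and is therefore preferred.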
@@ -1312,13 +1377,13 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name
                           self._format_res(self.pvalue(value=value), decimals),
                           self._format_res(self.conf_int_mean(alpha=alpha)[0], decimals),
                           self._format_res(self.conf_int_mean(alpha=alpha)[1], decimals)))
-
         if treatment_names is None:
             treatment_names = ['T' + str(i) for i in range(self._d_t)]
         if output_names is None:
             output_names = ['Y' + str(i) for i in range(self.d_y)]
 
         myheaders1 = ['mean_point', 'stderr_mean', 'zstat', 'pvalue', 'ci_mean_lower', 'ci_mean_upper']
+
         mystubs = self._get_stub_names(self.d_y, self._d_t, treatment_names, output_names)
         title1 = "Uncertainty of Mean Point Estimate"
 
@@ -1331,13 +1396,12 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name
 
         smry = Summary()
         smry.add_table(res1, myheaders1, mystubs, title1)
-        if self.pred_stderr is not None:
+        if self.pred_stderr is not None and self.mean_pred_stderr is None:
             text1 = "Note: The stderr_mean is a conservative upper bound."
             smry.add_extra_txt([text1])
         smry.add_table(res2, myheaders2, mystubs, title2)
 
         if self.pred_stderr is not None:
-
             # 3. Total Variance of Point Estimate
             res3 = np.hstack((self._format_res(self.stderr_point, self.decimals),
                               self._format_res(self.conf_int_point(alpha=alpha, tol=tol)[0],