## Problem
All meta-learners (`BaseSRegressor`, `BaseTRegressor`, `BaseXRegressor`, `BaseRRegressor`, `BaseDRRegressor`) support confidence intervals via `fit_predict(..., return_ci=True)`, which runs bootstrap resampling during training. However, the `predict()` method, used after a model is already fitted, has no `return_ci` parameter in any of these classes.
This makes post-fit uncertainty quantification impossible. In real deployments, you train a model once and score new data repeatedly. There is currently no way to attach confidence intervals to downstream predictions without re-running the full bootstrap on every new batch, which defeats the purpose of fitting once.
This was partially noted in #67, which was closed after bootstrap CIs were added to `fit_predict()`. The `predict()` gap was never addressed.
## Proposed Solution
Add `return_ci`, `n_bootstraps`, and `bootstrap_size` to the `predict()` signature of all meta-learners, using a stored-bootstrap-ensemble approach.
During `fit(..., store_bootstraps=True)` (a new optional flag, default `False`), train `n_bootstraps` versions of the model on bootstrap samples and store them in `self.bootstrap_models_`. `predict(X, return_ci=True)` then scores all stored models and returns the percentile CI; no retraining is required.
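The mechanics can be sketched generically. Everything below is hypothetical illustration, not existing causalml code: `fit_bootstrap_ensemble` and `predict_with_ci` stand in for the proposed internals, and `_ToyLearner` is a dummy with the meta-learner `fit`/`predict` interface.

```python
import numpy as np

class _ToyLearner:
    """Hypothetical stand-in with the meta-learner fit/predict interface."""
    def fit(self, X, treatment, y):
        # Difference-in-means "effect", just so predict() has something to return.
        self.tau_ = y[treatment == 1].mean() - y[treatment == 0].mean()

    def predict(self, X):
        return np.full(len(X), self.tau_)

def fit_bootstrap_ensemble(make_learner, X, treatment, y, n_bootstraps=100, seed=0):
    """The work fit() would do once when store_bootstraps=True:
    train n_bootstraps learners on rows resampled with replacement."""
    rng = np.random.default_rng(seed)
    n = len(X)
    models = []
    for _ in range(n_bootstraps):
        idx = rng.integers(0, n, size=n)  # bootstrap sample of row indices
        learner = make_learner()
        learner.fit(X[idx], treatment[idx], y[idx])
        models.append(learner)
    return models

def predict_with_ci(models, X, ci_quantile=0.05):
    """Score every stored model and take percentile bounds; no retraining."""
    preds = np.stack([m.predict(X) for m in models])  # (n_bootstraps, n_samples)
    tau = preds.mean(axis=0)
    lb = np.quantile(preds, ci_quantile / 2, axis=0)
    ub = np.quantile(preds, 1 - ci_quantile / 2, axis=0)
    return tau, lb, ub

# Toy usage: fit the ensemble once, then score a new batch with CIs.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
treatment = rng.integers(0, 2, size=200)
y = rng.normal(size=200) + 0.5 * treatment
models = fit_bootstrap_ensemble(_ToyLearner, X, treatment, y, n_bootstraps=50)
tau, tau_lb, tau_ub = predict_with_ci(models, X[:5])
```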
## Proposed API
```python
learner = BaseTRegressor(LGBMRegressor(), control_name='control')

# Train once, store bootstrap ensemble
learner.fit(
    X_train, treatment_train, y_train,
    store_bootstraps=True,
    n_bootstraps=200
)

# Score new data with CIs — no retraining needed
tau, tau_lb, tau_ub = learner.predict(
    X_test,
    return_ci=True,
    ci_quantile=0.05
)
```

## Design Rationale
- Mirrors how scikit-learn's `BaggingRegressor` exposes `estimators_` for downstream use
- Consistent with EconML's `const_marginal_ate_inference()` pattern, which separates fit from inference
- Fully backward-compatible: `store_bootstraps` defaults to `False`, and the `predict()` signature is unchanged unless opted in
- `bootstrap_models_` is exposed as a documented attribute so users can serialize/deserialize it via `joblib`
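On the last point, the stored ensemble is just a list of fitted estimators, so standard Python persistence works out of the box. The sketch below uses stdlib `pickle` for a self-contained example (`joblib.dump`/`joblib.load` would be used the same way); `StubLearner` and its placeholder contents are hypothetical.

```python
import pickle

class StubLearner:
    """Hypothetical fitted meta-learner holding a stored bootstrap ensemble."""
    def __init__(self):
        # Placeholders standing in for fitted bootstrap models.
        self.bootstrap_models_ = [("fitted_model", i) for i in range(3)]

learner = StubLearner()

# Persist only the ensemble; joblib.dump(learner.bootstrap_models_, path)
# would be the equivalent call for large numpy-backed models.
blob = pickle.dumps(learner.bootstrap_models_)
restored = pickle.loads(blob)
```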
## Affected Classes
- `BaseSRegressor` / `BaseSClassifier`
- `BaseTRegressor` / `BaseTClassifier`
- `BaseXRegressor` / `BaseXClassifier`
- `BaseRRegressor` / `BaseRClassifier`
- `BaseDRRegressor`
## Willingness to Contribute
Happy to open a draft PR starting with `BaseTRegressor` as the reference implementation, with the other learners following the same pattern. I will tag @jeongyoonlee for design feedback before going broad.