[Enhancement] Add return_ci support to predict() for post-fit confidence interval inference in meta-learners #885

@aman-coder03

Problem

All meta-learners (BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor, BaseDRRegressor) support confidence intervals via fit_predict(..., return_ci=True), which runs bootstrap resampling during training. However, the predict() method, which is what you call once a model is already fitted, has no return_ci parameter in any of these classes.
This makes post-fit uncertainty quantification impractical. In real deployments, you train a model once and score new data repeatedly. There is currently no way to attach confidence intervals to predictions on new data without re-running the full bootstrap, retraining included, on every new batch, which defeats the purpose of fitting once.
This was partially noted in #67, which was closed after bootstrap CIs were added to fit_predict(). The predict() gap was never addressed.

Proposed Solution

Add return_ci, n_bootstraps, and bootstrap_size to the predict() signature of all meta-learners, using a stored bootstrap ensemble approach.
During fit(..., store_bootstraps=True) (a new optional flag, default False), train n_bootstraps versions of the model on bootstrap samples and store them in self.bootstrap_models_. Then predict(X, return_ci=True) scores all stored models and returns the percentile CI, with no retraining required.
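To make the stored-ensemble idea concrete, here is a minimal sketch, not the library's implementation. BootstrapTLearnerSketch and LinearModel are hypothetical stand-ins written only for illustration. The sketch assumes a binary 0/1 treatment indicator, plain resampling with replacement, and that ci_quantile is a two-sided alpha (0.05 gives the 2.5th and 97.5th percentiles); the issue does not pin any of these down.

```python
import copy

import numpy as np


class LinearModel:
    """Tiny least-squares stand-in for any estimator with fit/predict."""

    def fit(self, X, y):
        Xb = np.column_stack([np.ones(len(X)), X])
        self.coef_, *_ = np.linalg.lstsq(Xb, y, rcond=None)
        return self

    def predict(self, X):
        Xb = np.column_stack([np.ones(len(X)), X])
        return Xb @ self.coef_


class BootstrapTLearnerSketch:
    """Hypothetical T-learner that keeps its bootstrap ensemble after fit()."""

    def __init__(self, base_model, random_state=0):
        self.base_model = base_model
        self.rng = np.random.default_rng(random_state)

    def _fit_pair(self, X, w, y):
        # One outcome model per arm: control (w == 0) and treatment (w == 1).
        m0 = copy.deepcopy(self.base_model).fit(X[w == 0], y[w == 0])
        m1 = copy.deepcopy(self.base_model).fit(X[w == 1], y[w == 1])
        return m0, m1

    @staticmethod
    def _cate(pair, X):
        m0, m1 = pair
        return m1.predict(X) - m0.predict(X)

    def fit(self, X, w, y, store_bootstraps=False, n_bootstraps=200,
            bootstrap_size=None):
        self.models_ = self._fit_pair(X, w, y)
        self.bootstrap_models_ = []
        if store_bootstraps:
            n = len(X)
            size = bootstrap_size or n
            for _ in range(n_bootstraps):
                # Resample rows with replacement and refit both arm models.
                # Assumes each resample still contains rows from both arms.
                idx = self.rng.integers(0, n, size=size)
                self.bootstrap_models_.append(
                    self._fit_pair(X[idx], w[idx], y[idx])
                )
        return self

    def predict(self, X, return_ci=False, ci_quantile=0.05):
        tau = self._cate(self.models_, X)
        if not return_ci:
            return tau
        if not self.bootstrap_models_:
            raise ValueError(
                "return_ci=True requires fit(..., store_bootstraps=True)"
            )
        # Score every stored model pair; no retraining happens here.
        draws = np.stack([self._cate(p, X) for p in self.bootstrap_models_])
        lb = np.percentile(draws, 100 * ci_quantile / 2, axis=0)
        ub = np.percentile(draws, 100 * (1 - ci_quantile / 2), axis=0)
        return tau, lb, ub
```

The trade-off in this sketch is memory and fit time: storing n_bootstraps model pairs makes fit() roughly n_bootstraps times more expensive, in exchange for predict() calls that only score already-fitted models.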

Proposed API

from causalml.inference.meta import BaseTRegressor
from lightgbm import LGBMRegressor

learner = BaseTRegressor(LGBMRegressor(), control_name='control')

# Train once, store bootstrap ensemble
learner.fit(
    X_train, treatment_train, y_train,
    store_bootstraps=True,
    n_bootstraps=200
)

# Score new data with CIs — no retraining needed
tau, tau_lb, tau_ub = learner.predict(
    X_test,
    return_ci=True,
    ci_quantile=0.05
)

Design Rationale

  • mirrors how scikit-learn's BaggingRegressor exposes estimators_ for downstream use
  • consistent with EconML's const_marginal_ate_inference() pattern, which separates fit from inference
  • fully backward-compatible: store_bootstraps defaults to False, and predict() behaves as before unless return_ci=True is passed
  • bootstrap_models_ exposed as a documented attribute so users can serialize/deserialize via joblib

Affected Classes

  • BaseSRegressor / BaseSClassifier
  • BaseTRegressor / BaseTClassifier
  • BaseXRegressor / BaseXClassifier
  • BaseRRegressor / BaseRClassifier
  • BaseDRRegressor

Willingness to Contribute

Happy to open a draft PR starting with BaseTRegressor as the reference implementation, with the other learners following the same pattern. I will tag @jeongyoonlee for design feedback before going broad.
