44from packaging import version
55from scipy .stats import norm
66import sklearn
7+ from sklearn .base import clone
78from sklearn .exceptions import ConvergenceWarning
89from sklearn .neural_network import MLPRegressor
10+ from joblib import Parallel , delayed
911
1012if version .parse (sklearn .__version__ ) >= version .parse ("0.22.0" ):
1113 from sklearn .utils ._testing import ignore_warnings
@@ -77,9 +79,10 @@ def fit(
7779 y ,
7880 p = None ,
7981 store_bootstraps = False ,
80- n_bootstraps = 1000 ,
82+ n_bootstraps = 200 ,
8183 bootstrap_size = 10000 ,
8284 random_state = None ,
85+ n_jobs = 1 ,
8386 ):
8487 """Fit the inference model
8588
@@ -91,7 +94,12 @@ def fit(
9194 store_bootstraps (bool, optional): if True, trains a bootstrap ensemble
9295 during fit and stores it in self.bootstrap_models_ for post-fit CI
9396 estimation via predict(return_ci=True). Default: False.
94- n_bootstraps (int, optional): number of bootstrap iterations. Default: 1000.
97+ n_bootstraps (int, optional): number of bootstrap iterations. Default: 200.
98+ Note: storing N bootstraps of a GBM-based learner with k treatment
99+ groups holds 2*N*k model objects in memory. Monitor RAM for large N
100+ or heavy base learners.
101+ n_jobs (int, optional): number of parallel jobs for bootstrap fitting.
102+ -1 uses all available cores. Default: 1.
95103 bootstrap_size (int, optional): number of samples per bootstrap. Default: 10000.
96104 random_state (int, optional): random seed for reproducible bootstrap sampling.
97105 """
@@ -118,25 +126,36 @@ def fit(
118126 logger .info (
119127 "Storing bootstrap ensemble ({} iterations)" .format (n_bootstraps )
120128 )
121- self .bootstrap_models_ = []
122- for i in tqdm (range (n_bootstraps )):
123- idxs = rng .choice (np .arange (X .shape [0 ]), size = bootstrap_size )
129+ seeds = rng .randint (0 , np .iinfo (np .int32 ).max , size = n_bootstraps )
130+
def _fit_one_bootstrap(seed):
    """Fit one bootstrap replicate of the per-group outcome models.

    Draws a sample of ``bootstrap_size`` rows (with replacement, the
    standard bootstrap) from (X, treatment, y) using a private
    ``RandomState`` seeded with ``seed``, so each replicate is
    reproducible and independent of worker scheduling under joblib.

    Args:
        seed (int): seed for this replicate's private RNG.

    Returns:
        tuple(dict, dict): ``(models_c_b, models_t_b)`` mapping each
        treatment group to its fitted control / treatment outcome model.
        For a group whose bootstrap sample contains no treated or no
        control units, the globally fitted models are reused by
        reference (not copied), which can understate CI width — a
        warning is logged in that case.
    """
    local_rng = np.random.RandomState(seed)
    # Passing the int X.shape[0] lets choice() sample from an implicit
    # arange — avoids materializing an O(n) index array per replicate.
    idxs = local_rng.choice(X.shape[0], size=bootstrap_size)
    X_b, treatment_b, y_b = X[idxs], treatment[idxs], y[idxs]
    # clone() yields unfitted estimators with identical hyperparameters;
    # cheaper and safer than deepcopy of a fitted model.
    models_c_b = {group: clone(self.model_c) for group in self.t_groups}
    models_t_b = {group: clone(self.model_t) for group in self.t_groups}
    for group in self.t_groups:
        # Restrict to this group's rows plus control rows.
        mask = (treatment_b == group) | (treatment_b == self.control_name)
        treatment_filt = treatment_b[mask]
        X_filt = X_b[mask]
        y_filt = y_b[mask]
        # w == 1 marks treated rows, w == 0 marks control rows.
        w = (treatment_filt == group).astype(int)
        if w.sum() == 0 or (w == 0).sum() == 0:
            logger.warning(
                "Bootstrap sample has no treated or no control units "
                "for group {}. Falling back to global model — "
                "CI may be underestimated.".format(group)
            )
            # Fallback shares the already-fitted global models by
            # reference across replicates.
            models_c_b[group] = self.models_c[group]
            models_t_b[group] = self.models_t[group]
            continue
        models_c_b[group].fit(X_filt[w == 0], y_filt[w == 0])
        models_t_b[group].fit(X_filt[w == 1], y_filt[w == 1])
    return models_c_b, models_t_b
155+
156+ self .bootstrap_models_ = Parallel (n_jobs = n_jobs )(
157+ delayed (_fit_one_bootstrap )(s ) for s in tqdm (seeds )
158+ )
140159 else :
141160 self .bootstrap_models_ = None
142161
0 commit comments