@@ -55,6 +55,9 @@ def __init__(
5555 else :
5656 self .model_c = control_learner
5757
58+ # Preserve the unfitted template so repeated fit() calls always start fresh.
59+ self ._model_c_template = self .model_c
60+
5861 if treatment_learner is None :
5962 self .model_t = deepcopy (learner )
6063 else :
@@ -85,10 +88,13 @@ def fit(self, X, treatment, y, p=None):
8588 self .models_t = {group : deepcopy (self .model_t ) for group in self .t_groups }
8689
8790 # model_c is trained on the control group, which is identical for every
88- # treatment group, so fit it once and store as a single model (not a dict).
91+ # treatment group, so fit it once. Deepcopy from the unfitted template so
92+ # re-calling fit() always starts from a clean state (safe with warm_start).
8993 control_mask = treatment == self .control_name
90- self .model_c = deepcopy (self .model_c )
94+ self .model_c = deepcopy (self ._model_c_template )
9195 self .model_c .fit (X [control_mask ], y [control_mask ])
96+ # Expose as a shared-reference dict to preserve the public models_c API.
97+ self .models_c = {group : self .model_c for group in self .t_groups }
9298
9399 for group in self .t_groups :
94100 treatment_mask = treatment == group
@@ -112,6 +118,9 @@ def predict(
112118 yhat_ts = {}
113119
114120 yhat_c = self .model_c .predict (X )
121+ # Build a shared-reference dict so return_components callers keep the
122+ # yhat_cs[group] indexing API without duplicating the underlying array.
123+ yhat_cs = {group : yhat_c for group in self .t_groups }
115124
116125 for group in self .t_groups :
117126 yhat_ts [group ] = self .models_t [group ].predict (X )
@@ -136,7 +145,7 @@ def predict(
136145 if not return_components :
137146 return te
138147 else :
139- return te , yhat_c , yhat_ts
148+ return te , yhat_cs , yhat_ts
140149
141150 def fit_predict (
142151 self ,
@@ -195,6 +204,7 @@ def fit_predict(
195204 self .t_groups = t_groups_global
196205 self ._classes = _classes_global
197206 self .model_c = deepcopy (model_c_global )
207+ self .models_c = {group : self .model_c for group in self .t_groups }
198208 self .models_t = deepcopy (models_t_global )
199209
200210 return (te , te_lower , te_upper )
@@ -225,9 +235,9 @@ def estimate_ate(
225235 """
226236 X , treatment , y = convert_pd_to_np (X , treatment , y )
227237 if pretrain :
228- te , yhat_c , yhat_ts = self .predict (X , treatment , y , return_components = True )
238+ te , yhat_cs , yhat_ts = self .predict (X , treatment , y , return_components = True )
229239 else :
230- te , yhat_c , yhat_ts = self .fit_predict (
240+ te , yhat_cs , yhat_ts = self .fit_predict (
231241 X , treatment , y , return_components = True
232242 )
233243
@@ -244,14 +254,14 @@ def estimate_ate(
244254 w = (treatment_filt == group ).astype (int )
245255 prob_treatment = float (sum (w )) / w .shape [0 ]
246256
247- yhat_c_g = yhat_c [mask ]
257+ yhat_c = yhat_cs [ group ] [mask ]
248258 yhat_t = yhat_ts [group ][mask ]
249259
250260 se = np .sqrt (
251261 (
252- (y_filt [w == 0 ] - yhat_c_g [w == 0 ]).var () / (1 - prob_treatment )
262+ (y_filt [w == 0 ] - yhat_c [w == 0 ]).var () / (1 - prob_treatment )
253263 + (y_filt [w == 1 ] - yhat_t [w == 1 ]).var () / prob_treatment
254- + (yhat_t - yhat_c_g ).var ()
264+ + (yhat_t - yhat_c ).var ()
255265 )
256266 / y_filt .shape [0 ]
257267 )
@@ -289,6 +299,7 @@ def estimate_ate(
289299 self .t_groups = t_groups_global
290300 self ._classes = _classes_global
291301 self .model_c = deepcopy (model_c_global )
302+ self .models_c = {group : self .model_c for group in self .t_groups }
292303 self .models_t = deepcopy (models_t_global )
293304
294305 return ate , ate_lower , ate_upper
@@ -371,6 +382,7 @@ def predict(
371382 yhat_ts = {}
372383
373384 yhat_c = self .model_c .predict_proba (X )[:, 1 ]
385+ yhat_cs = {group : yhat_c for group in self .t_groups }
374386
375387 for group in self .t_groups :
376388 yhat_ts [group ] = self .models_t [group ].predict_proba (X )[:, 1 ]
@@ -395,7 +407,7 @@ def predict(
395407 if not return_components :
396408 return te
397409 else :
398- return te , yhat_c , yhat_ts
410+ return te , yhat_cs , yhat_ts
399411
400412
401413class XGBTRegressor (BaseTRegressor ):
0 commit comments