Skip to content

Commit be2523c

Browse files
committed
fix incorrect bootstrap mean calculation in meta-learners (#828)
1 parent 21a2bba commit be2523c

7 files changed

Lines changed: 13 additions & 10 deletions

File tree

causalml/inference/meta/base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from abc import ABCMeta, abstractclassmethod
1+
from abc import ABCMeta, abstractmethod
22
import logging
33
import numpy as np
44
import pandas as pd
@@ -12,11 +12,13 @@
1212

1313

1414
class BaseLearner(metaclass=ABCMeta):
15-
@abstractclassmethod
15+
@classmethod
16+
@abstractmethod
1617
def fit(self, X, treatment, y, p=None):
1718
pass
1819

19-
@abstractclassmethod
20+
@classmethod
21+
@abstractmethod
2022
def predict(
2123
self, X, treatment=None, y=None, p=None, return_components=False, verbose=True
2224
):
@@ -37,7 +39,8 @@ def fit_predict(
3739
self.fit(X, treatment, y, p)
3840
return self.predict(X, treatment, y, p, return_components, verbose)
3941

40-
@abstractclassmethod
42+
@classmethod
43+
@abstractmethod
4144
def estimate_ate(
4245
self,
4346
X,

causalml/inference/meta/drlearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ def estimate_ate(
434434
cate_b = self.bootstrap(
435435
X, treatment, y, p, size=bootstrap_size, seed=seed
436436
)
437-
ate_bootstraps[:, n] = cate_b.mean()
437+
ate_bootstraps[:, n] = cate_b.mean(axis=0)
438438

439439
ate_lower = np.percentile(
440440
ate_bootstraps, (self.ate_alpha / 2) * 100, axis=1

causalml/inference/meta/rlearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def estimate_ate(
324324
else:
325325
p = self._format_p(p, self.t_groups)
326326
cate_b = self.bootstrap(X, treatment, y, p, size=bootstrap_size)
327-
ate_bootstraps[:, n] = cate_b.mean()
327+
ate_bootstraps[:, n] = cate_b.mean(axis=0)
328328

329329
ate_lower = np.percentile(
330330
ate_bootstraps, (self.ate_alpha / 2) * 100, axis=1

causalml/inference/meta/slearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def estimate_ate(
278278

279279
for n in tqdm(range(n_bootstraps)):
280280
ate_b = self.bootstrap(X, treatment, y, size=bootstrap_size)
281-
ate_bootstraps[:, n] = ate_b.mean()
281+
ate_bootstraps[:, n] = ate_b.mean(axis=0)
282282

283283
ate_lower = np.percentile(
284284
ate_bootstraps, (self.ate_alpha / 2) * 100, axis=1

causalml/inference/meta/tlearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def estimate_ate(
280280

281281
for n in tqdm(range(n_bootstraps)):
282282
ate_b = self.bootstrap(X, treatment, y, size=bootstrap_size)
283-
ate_bootstraps[:, n] = ate_b.mean()
283+
ate_bootstraps[:, n] = ate_b.mean(axis=0)
284284

285285
ate_lower = np.percentile(
286286
ate_bootstraps, (self.ate_alpha / 2) * 100, axis=1

causalml/inference/meta/xlearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def estimate_ate(
390390

391391
for n in tqdm(range(n_bootstraps)):
392392
cate_b = self.bootstrap(X, treatment, y, p, size=bootstrap_size)
393-
ate_bootstraps[:, n] = cate_b.mean()
393+
ate_bootstraps[:, n] = cate_b.mean(axis=0)
394394

395395
ate_lower = np.percentile(
396396
ate_bootstraps, (self.ate_alpha / 2) * 100, axis=1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "causalml"
3-
version = "0.15.3"
3+
version = "0.15.4dev"
44
description = "Python Package for Uplift Modeling and Causal Inference with Machine Learning Algorithms"
55
readme = { file = "README.md", content-type = "text/markdown" }
66

0 commit comments

Comments
 (0)