Skip to content

Commit d00947a

Browse files
authored
Prepare release 0.9.0 (#76)
* Update changelog. * Return tuple of arrays instead of concatenated arrays.
1 parent 1c4f019 commit d00947a

File tree

4 files changed

+16
-18
lines changed

4 files changed

+16
-18
lines changed

CHANGELOG.rst

+6-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Changelog
88
=========
99

10-
0.9.0 (2024-07-xx)
10+
0.9.0 (2024-08-02)
1111
------------------
1212

1313
**New features**
@@ -16,19 +16,16 @@ Changelog
1616

1717
* Add :class:`metalearners.utils.FixedBinaryPropensity`.
1818

19-
* Added ``_build_onnx`` to :class:`metalearners.MetaLearner` abstract class and implement it
19+
* Add ``_build_onnx`` to :class:`metalearners.MetaLearner` abstract class and implement it
2020
for :class:`metalearners.TLearner`, :class:`metalearners.XLearner`, :class:`metalearners.RLearner`
2121
and :class:`metalearners.DRLearner`.
2222

23-
* Added ``_necessary_onnx_models`` to :class:`metalearners.MetaLearner`.
23+
* Add ``_necessary_onnx_models`` to :class:`metalearners.MetaLearner`.
2424

25-
* Added :meth:`metalearners.metalearner.DRLearner.average_treatment_effect` to
26-
compute AIPW point estimate and standard error for _average
27-
treatment effects (ATE)_
28-
without requiring a full model fit (which is required for CATE
29-
estimation). A new notebook contains examples.
25+
* Add :meth:`metalearners.metalearner.DRLearner.average_treatment_effect` to
26+
compute the AIPW point estimate and standard error for
27+
_average treatment effects (ATE)_ without requiring a full model fit.
3028

31-
* Added :meth:`metalearners.metalearner.DRLearner.treatment_effect` to compute AIPW point estimate and standard error for _average treatment effects (ATE)_ without requiring a full model fit (which is required for CATE estimation). A new notebook contains examples.
3229

3330
0.8.0 (2024-07-22)
3431
------------------

docs/examples/example_estimating_ates.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@
310310
" np.c_[\n",
311311
" naive_est,\n",
312312
" linreg_est,\n",
313-
" metalearners_est.flatten(),\n",
313+
" np.hstack(metalearners_est),\n",
314314
" doubleml_est,\n",
315315
" econml_est,\n",
316316
"], index = ['est', 'se'],\n",

metalearners/drlearner.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,11 @@ def average_treatment_effect(
344344
y: Vector,
345345
w: Vector,
346346
is_oos: bool,
347-
) -> np.ndarray:
347+
) -> tuple[np.ndarray, np.ndarray]:
348348
"""Compute Average Treatment Effect (ATE) for each treatment variant using the
349349
Augmented IPW estimator (Robins et al 1994). Does not require fitting a second-
350-
stage treatment model: it uses the pseudo-outcome alone and computes the average
351-
and SE. Can be used following the
350+
stage treatment model: it uses the pseudo-outcome alone and computes the point
351+
estimate and standard error. Can be used following the
352352
:meth:`~metalearners.drlearner.DRLearner.fit_all_nuisance` method.
353353
354354
Args:
@@ -358,7 +358,8 @@ def average_treatment_effect(
358358
is_oos (bool): indicator whether data is out of sample
359359
360360
Returns:
361-
np.ndarray: Treatment effect and standard error for each treatment variant.
361+
np.ndarray: Treatment effect for each treatment variant.
362+
np.ndarray: Standard error for each treatment variant.
362363
"""
363364
if not self._nuisance_models_fit:
364365
raise ValueError(
@@ -375,7 +376,7 @@ def average_treatment_effect(
375376
)
376377
treatment_effect = gamma_matrix.mean(axis=0)
377378
standard_error = gamma_matrix.std(axis=0) / np.sqrt(len(X))
378-
return np.c_[treatment_effect, standard_error]
379+
return treatment_effect, standard_error
379380

380381
def _pseudo_outcome(
381382
self,

tests/test_drlearner.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_drlearner_onnx(
135135
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx, atol=5e-4)
136136

137137

138-
def test_treatment_effect(
138+
def test_average_treatment_effect(
139139
numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
140140
):
141141
X, _, W, Y, _, tau = (
@@ -150,5 +150,5 @@ def test_treatment_effect(
150150
n_folds=2,
151151
)
152152
ml.fit_all_nuisance(X, Y, W)
153-
est = ml.average_treatment_effect(X, Y, W, is_oos=False)
154-
np.testing.assert_almost_equal(est[:, 0], tau.mean(), decimal=1)
153+
ate_estimate, _ = ml.average_treatment_effect(X, Y, W, is_oos=False)
154+
np.testing.assert_almost_equal(ate_estimate, tau.mean(), decimal=1)

0 commit comments

Comments
 (0)