Skip to content

Commit 98b2bf3

Browse files
authored
Vasilis/policy (#377)
* added policy learning module * added cython policy tree and policy forest * extended policy cate interpreter to interpret multiple treatments using the new policy tree * added doubly robust policy learning tree and doubly robust policy learning forest * fixed randomness in weightedkfold, that was causing tests to fail due to non-fixed-randomness behavior * added notebook on policy learning
1 parent cac4c3e commit 98b2bf3

34 files changed

+4758
-1340
lines changed

doc/reference.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@ Sieve Methods
104104
econml.iv.sieve.HermiteFeatures
105105
econml.iv.sieve.DPolynomialFeatures
106106

107+
.. _policy_api:
108+
109+
Policy Learning
110+
---------------
111+
112+
.. autosummary::
113+
:toctree: _autosummary
114+
115+
econml.policy.DRPolicyForest
116+
econml.policy.DRPolicyTree
117+
econml.policy.PolicyForest
118+
econml.policy.PolicyTree
107119

108120
.. _interpreters_api:
109121

doc/spec/estimation/dr.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ Usage FAQs
438438

439439
.. testcode::
440440

441-
from econml.drlearner import DRLearner
441+
from econml.dr import DRLearner
442442
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
443443
from sklearn.model_selection import GridSearchCV
444444
model_reg = lambda: GridSearchCV(

econml/__init__.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4-
__all__ = ['automated_ml', 'bootstrap',
5-
'cate_interpreter', 'causal_forest',
6-
'data', 'deepiv', 'dml', 'dr', 'drlearner',
7-
'inference', 'iv',
8-
'metalearners', 'ortho_forest', 'orf', 'ortho_iv',
9-
'score', 'sklearn_extensions', 'tree',
10-
'two_stage_least_squares', 'utilities', "dowhy", "__version__"]
4+
__all__ = ['automated_ml',
5+
'bootstrap',
6+
'cate_interpreter',
7+
'causal_forest',
8+
'data',
9+
'deepiv',
10+
'dml',
11+
'dr',
12+
'drlearner',
13+
'inference',
14+
'iv',
15+
'metalearners',
16+
'ortho_forest',
17+
'orf',
18+
'ortho_iv',
19+
'policy',
20+
'score',
21+
'sklearn_extensions',
22+
'tree',
23+
'two_stage_least_squares',
24+
'utilities',
25+
'dowhy',
26+
'__version__']
1127

1228
__version__ = '0.9.2'

econml/_ensemble/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
from ._ensemble import BaseEnsemble, _partition_estimators
5+
from ._utilities import (_get_n_samples_subsample, _accumulate_prediction, _accumulate_prediction_var,
6+
_accumulate_prediction_and_var, _accumulate_oob_preds)
7+
8+
__all__ = ["BaseEnsemble",
9+
"_partition_estimators",
10+
"_get_n_samples_subsample",
11+
"_accumulate_prediction",
12+
"_accumulate_prediction_var",
13+
"_accumulate_prediction_and_var",
14+
"_accumulate_oob_preds"]
File renamed without changes.

econml/_ensemble/_utilities.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import numbers
5+
import numpy as np
6+
7+
8+
def _get_n_samples_subsample(n_samples, max_samples):
    """
    Get the number of samples in a sub-sample without replacement.

    Parameters
    ----------
    n_samples : int
        Number of samples in the dataset.
    max_samples : int or float
        The maximum number of samples to draw from the total available:
            - if float, this indicates a fraction of the total and should be
              in the interval `(0, 1]`;
            - if int, this indicates the exact number of samples and should
              be in the range `[1, n_samples]`;
            - if None, this indicates the total number of samples.

    Returns
    -------
    n_samples_subsample : int
        The total number of samples to draw for the subsample.

    Raises
    ------
    ValueError
        If `max_samples` is an int outside `[1, n_samples]` or a float
        outside `(0, 1]`.
    TypeError
        If `max_samples` is not None, an int or a float.
    """
    if max_samples is None:
        return n_samples

    # Integers are checked before reals because every Integral is also a Real.
    if isinstance(max_samples, numbers.Integral):
        if not (1 <= max_samples <= n_samples):
            msg = "`max_samples` must be in range 1 to {} but got value {}"
            raise ValueError(msg.format(n_samples, max_samples))
        return max_samples

    if isinstance(max_samples, numbers.Real):
        # The upper bound is inclusive: a fraction of exactly 1 selects every sample,
        # so the error message reports the half-open interval that is actually enforced.
        if not (0 < max_samples <= 1):
            msg = "`max_samples` must be in range (0, 1] but got value {}"
            raise ValueError(msg.format(max_samples))
        # NOTE: the product is rounded to the nearest integer, so a very small
        # fraction of a small dataset can yield 0.
        return int(round(n_samples * max_samples))

    msg = "`max_samples` should be int or float, but got type '{}'"
    raise TypeError(msg.format(type(max_samples)))
43+
44+
45+
def _accumulate_prediction(predict, X, out, lock, *args, **kwargs):
    """
    Add one estimator's prediction to shared accumulators (joblib helper).

    This lives at module level rather than inside ForestClassifier or
    ForestRegressor because joblib complains that it cannot pickle it when
    it is defined there.

    `predict` is invoked as ``predict(X, *args, check_input=False, **kwargs)``.
    While holding `lock`, the result is added in place to `out`: when `out`
    has a single element the whole result is added to ``out[0]``; otherwise
    the i-th component of the result is added to ``out[i]``.
    """
    tree_pred = predict(X, *args, check_input=False, **kwargs)
    with lock:
        n_slots = len(out)
        if n_slots == 1:
            # Single accumulator: the prediction is added as a whole.
            out[0] += tree_pred
            return
        # One accumulator per component of the prediction.
        for slot in range(n_slots):
            out[slot] += tree_pred[slot]
58+
59+
60+
def _accumulate_prediction_var(predict, X, out, lock, *args, **kwargs):
    """
    Add the per-sample outer product of one estimator's prediction to shared
    accumulators (joblib helper).

    This lives at module level rather than inside ForestClassifier or
    ForestRegressor because joblib complains that it cannot pickle it when
    it is defined there.

    `predict` is invoked as ``predict(X, *args, check_input=False, **kwargs)``
    and is assumed to return an array of shape (n_samples, d) or a tuple of
    such arrays. For each sample, the (d, d) outer product of the prediction
    with itself is added in place to ``out[0]`` (an (n_samples, d, d)
    placeholder), or, when `out` has several elements, the outer product of
    the i-th returned array is added to ``out[i]``. All updates happen while
    holding `lock`.
    """
    def _outer(pred):
        # View the prediction as (n, d, 1) and (n, 1, d); contracting the
        # length-one axis gives the (n, d, d) per-sample outer product.
        as_columns = pred.reshape(*pred.shape, 1)
        as_rows = pred.reshape(-1, 1, *pred.shape[1:])
        return np.einsum('ijk,ikm->ijm', as_columns, as_rows)

    tree_pred = predict(X, *args, check_input=False, **kwargs)
    with lock:
        n_slots = len(out)
        if n_slots == 1:
            out[0] += _outer(tree_pred)
            return
        for slot in range(n_slots):
            out[slot] += _outer(tree_pred[slot])
83+
84+
85+
def _accumulate_prediction_and_var(predict, X, out, out_var, lock, *args, **kwargs):
    """
    This is a utility function for joblib's Parallel.
    It can't go locally in ForestClassifier or ForestRegressor, because joblib
    complains that it cannot pickle it when placed there.

    Combines `_accumulate_prediction` and `_accumulate_prediction_var` in a single
    parallel run: the prediction is added to `out` and its per-sample outer
    product to `out_var`, so each tree is evaluated only once.

    `predict` is invoked as ``predict(X, *args, check_input=False, **kwargs)``
    and is assumed to return an array of shape (n_samples, d) or a tuple of
    such arrays. When `out` has a single element the whole prediction goes to
    ``out[0]`` / ``out_var[0]``; otherwise the i-th returned array goes to
    ``out[i]`` / ``out_var[i]``. All updates are in place and happen while
    holding `lock`. This function only adds to the accumulators; it does not
    divide by the number of trees.
    """
    prediction = predict(X, *args, check_input=False, **kwargs)
    with lock:
        if len(out) == 1:
            out[0] += prediction
            # (n, d, 1) x (n, 1, d) contracted over the length-one axis gives
            # the (n, d, d) per-sample outer product.
            out_var[0] += np.einsum('ijk,ikm->ijm',
                                    prediction.reshape(prediction.shape + (1,)),
                                    prediction.reshape((-1, 1) + prediction.shape[1:]))
        else:
            for i in range(len(out)):
                pred_i = prediction[i]
                # Add only the i-th component; adding the whole tuple of
                # predictions here would corrupt the accumulator or fail.
                out[i] += pred_i
                out_var[i] += np.einsum('ijk,ikm->ijm',
                                        pred_i.reshape(pred_i.shape + (1,)),
                                        pred_i.reshape((-1, 1) + pred_i.shape[1:]))
108+
109+
110+
def _accumulate_oob_preds(tree, X, subsample_inds, alpha_hat, jac_hat, counts, lock):
111+
mask = np.ones(X.shape[0], dtype=bool)
112+
mask[subsample_inds] = False
113+
alpha, jac = tree.predict_alpha_and_jac(X[mask])
114+
with lock:
115+
alpha_hat[mask] += alpha
116+
jac_hat[mask] += jac
117+
counts[mask] += 1

0 commit comments

Comments
 (0)