Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 64c4d95

Browse files
committed May 19, 2021
Add treatment recommendations to causal analysis
1 parent aa56e7d commit 64c4d95

File tree

2 files changed

+44
-39
lines changed

2 files changed

+44
-39
lines changed
 

‎econml/solutions/causal_analysis/_causal_analysis.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -997,10 +997,12 @@ def _tree(self, is_policy, Xtest, feature_index, *, treatment_cost=0,
997997
if is_policy:
998998
intrp.interpret(result.estimator, Xtest,
999999
sample_treatment_costs=treatment_cost)
1000+
treat = intrp.treat(Xtest)
10001001
else: # no treatment cost for CATE trees
10011002
intrp.interpret(result.estimator, Xtest)
1003+
treat = None
10021004

1003-
return intrp, result.X_transformer.get_feature_names(self.feature_names_), treatment_names
1005+
return intrp, result.X_transformer.get_feature_names(self.feature_names_), treatment_names, treat
10041006

10051007
# TODO: it seems like it would be better to just return the tree itself rather than plot it;
10061008
# however, the tree can't store the feature and treatment names we compute here...
@@ -1027,18 +1029,21 @@ def plot_policy_tree(self, Xtest, feature_index, *, treatment_cost=0,
10271029
Confidence level of the confidence intervals displayed in the leaf nodes.
10281030
A (1-alpha)*100% confidence interval is displayed.
10291031
"""
1030-
intrp, feature_names, treatment_names = self._tree(True, Xtest, feature_index,
1031-
treatment_cost=treatment_cost,
1032-
max_depth=max_depth,
1033-
min_samples_leaf=min_samples_leaf,
1034-
min_impurity_decrease=min_value_increase,
1035-
alpha=alpha)
1032+
intrp, feature_names, treatment_names, _ = self._tree(True, Xtest, feature_index,
1033+
treatment_cost=treatment_cost,
1034+
max_depth=max_depth,
1035+
min_samples_leaf=min_samples_leaf,
1036+
min_impurity_decrease=min_value_increase,
1037+
alpha=alpha)
10361038
return intrp.plot(feature_names=feature_names, treatment_names=treatment_names)
10371039

1038-
def _policy_tree_string(self, Xtest, feature_index, *, treatment_cost=0,
1040+
def _policy_tree_output(self, Xtest, feature_index, *, treatment_cost=0,
10391041
max_depth=3, min_samples_leaf=2, min_value_increase=1e-4, alpha=.1):
10401042
"""
1041-
Get a recommended policy tree in graphviz format as a string.
1043+
Get a tuple of policy outputs.
1044+
1045+
The first item in the tuple is the recommended policy tree in graphviz format as a string.
1046+
The second item is the recommended treatment for each sample as a list.
10421047
10431048
Parameters
10441049
----------
@@ -1060,18 +1065,18 @@ def _policy_tree_string(self, Xtest, feature_index, *, treatment_cost=0,
10601065
10611066
Returns
10621067
-------
1063-
tree : string
1064-
The policy tree represented as a graphviz string
1068+
tree : tuple of string, list of int
1069+
The policy tree represented as a graphviz string and the recommended treatment for each row
10651070
"""
10661071

1067-
intrp, feature_names, treatment_names = self._tree(True, Xtest, feature_index,
1068-
treatment_cost=treatment_cost,
1069-
max_depth=max_depth,
1070-
min_samples_leaf=min_samples_leaf,
1071-
min_impurity_decrease=min_value_increase,
1072-
alpha=alpha)
1072+
intrp, feature_names, treatment_names, treat = self._tree(True, Xtest, feature_index,
1073+
treatment_cost=treatment_cost,
1074+
max_depth=max_depth,
1075+
min_samples_leaf=min_samples_leaf,
1076+
min_impurity_decrease=min_value_increase,
1077+
alpha=alpha)
10731078
return intrp.export_graphviz(feature_names=feature_names,
1074-
treatment_names=treatment_names)
1079+
treatment_names=treatment_names), treat.tolist()
10751080

10761081
# TODO: it seems like it would be better to just return the tree itself rather than plot it;
10771082
# however, the tree can't store the feature and treatment names we compute here...
@@ -1099,11 +1104,11 @@ def plot_heterogeneity_tree(self, Xtest, feature_index, *,
10991104
A (1-alpha)*100% confidence interval is displayed.
11001105
"""
11011106

1102-
intrp, feature_names, treatment_names = self._tree(False, Xtest, feature_index,
1103-
max_depth=max_depth,
1104-
min_samples_leaf=min_samples_leaf,
1105-
min_impurity_decrease=min_impurity_decrease,
1106-
alpha=alpha)
1107+
intrp, feature_names, treatment_names, _ = self._tree(False, Xtest, feature_index,
1108+
max_depth=max_depth,
1109+
min_samples_leaf=min_samples_leaf,
1110+
min_impurity_decrease=min_impurity_decrease,
1111+
alpha=alpha)
11071112
return intrp.plot(feature_names=feature_names,
11081113
treatment_names=treatment_names)
11091114

@@ -1131,10 +1136,10 @@ def _heterogeneity_tree_string(self, Xtest, feature_index, *,
11311136
A (1-alpha)*100% confidence interval is displayed.
11321137
"""
11331138

1134-
intrp, feature_names, treatment_names = self._tree(False, Xtest, feature_index,
1135-
max_depth=max_depth,
1136-
min_samples_leaf=min_samples_leaf,
1137-
min_impurity_decrease=min_impurity_decrease,
1138-
alpha=alpha)
1139+
intrp, feature_names, treatment_names, _ = self._tree(False, Xtest, feature_index,
1140+
max_depth=max_depth,
1141+
min_samples_leaf=min_samples_leaf,
1142+
min_impurity_decrease=min_impurity_decrease,
1143+
alpha=alpha)
11391144
return intrp.export_graphviz(feature_names=feature_names,
11401145
treatment_names=treatment_names)

‎econml/tests/test_causal_analysis.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ def test_basic_array(self):
4444
coh_point_est = np.array(coh_dict[_CausalInsightsConstants.PointEstimateKey])
4545
loc_point_est = np.array(loc_dict[_CausalInsightsConstants.PointEstimateKey])
4646

47-
ca._policy_tree_string(X, 1)
47+
ca._policy_tree_output(X, 1)
4848
ca._heterogeneity_tree_string(X, 1)
4949
ca._heterogeneity_tree_string(X, 3)
5050

5151
# Can't handle multi-dimensional treatments
5252
with self.assertRaises(AssertionError):
53-
ca._policy_tree_string(X, 3)
53+
ca._policy_tree_output(X, 3)
5454

5555
# global shape is (d_y, sum(d_t))
5656
assert glo_point_est.shape == coh_point_est.shape == (1, 5)
@@ -133,13 +133,13 @@ def test_basic_pandas(self):
133133
assert glo_point_est.shape == coh_point_est.shape == (1, 5)
134134
assert loc_point_est.shape == (2,) + glo_point_est.shape
135135

136-
ca._policy_tree_string(X, inds[1])
136+
ca._policy_tree_output(X, inds[1])
137137
ca._heterogeneity_tree_string(X, inds[1])
138138
ca._heterogeneity_tree_string(X, inds[3])
139139

140140
# Can't handle multi-dimensional treatments
141141
with self.assertRaises(AssertionError):
142-
ca._policy_tree_string(X, inds[3])
142+
ca._policy_tree_output(X, inds[3])
143143

144144
if not classification:
145145
# ExitStack can be used as a "do nothing" ContextManager
@@ -199,13 +199,13 @@ def test_automl_first_stage(self):
199199
coh_point_est = np.array(coh_dict[_CausalInsightsConstants.PointEstimateKey])
200200
loc_point_est = np.array(loc_dict[_CausalInsightsConstants.PointEstimateKey])
201201

202-
ca._policy_tree_string(X, 1)
202+
ca._policy_tree_output(X, 1)
203203
ca._heterogeneity_tree_string(X, 1)
204204
ca._heterogeneity_tree_string(X, 3)
205205

206206
# Can't handle multi-dimensional treatments
207207
with self.assertRaises(AssertionError):
208-
ca._policy_tree_string(X, 3)
208+
ca._policy_tree_output(X, 3)
209209

210210
# global shape is (d_y, sum(d_t))
211211
assert glo_point_est.shape == coh_point_est.shape == (1, 5)
@@ -279,7 +279,7 @@ def test_one_feature(self):
279279
assert glo_point_est.shape == coh_point_est.shape == (1, 1)
280280
assert loc_point_est.shape == (2,) + glo_point_est.shape
281281

282-
ca._policy_tree_string(X, inds[0])
282+
ca._policy_tree_output(X, inds[0])
283283
ca._heterogeneity_tree_string(X, inds[0])
284284

285285
def test_final_models(self):
@@ -302,13 +302,13 @@ def test_final_models(self):
302302
coh_dict = ca._cohort_causal_effect_dict(X[:2])
303303
loc_dict = ca._local_causal_effect_dict(X[:2])
304304

305-
ca._policy_tree_string(X, 1)
305+
ca._policy_tree_output(X, 1)
306306
ca._heterogeneity_tree_string(X, 1)
307307
ca._heterogeneity_tree_string(X, 3)
308308

309309
# Can't handle multi-dimensional treatments
310310
with self.assertRaises(AssertionError):
311-
ca._policy_tree_string(X, 3)
311+
ca._policy_tree_output(X, 3)
312312

313313
if not classification:
314314
# ExitStack can be used as a "do nothing" ContextManager
@@ -370,13 +370,13 @@ def test_forest_with_pandas(self):
370370
assert glo_point_est.shape == coh_point_est.shape == (1, 5)
371371
assert loc_point_est.shape == (2,) + glo_point_est.shape
372372

373-
ca._policy_tree_string(X, inds[1])
373+
ca._policy_tree_output(X, inds[1])
374374
ca._heterogeneity_tree_string(X, inds[1])
375375
ca._heterogeneity_tree_string(X, inds[3])
376376

377377
# Can't handle multi-dimensional treatments
378378
with self.assertRaises(AssertionError):
379-
ca._policy_tree_string(X, inds[3])
379+
ca._policy_tree_output(X, inds[3])
380380

381381
def test_warm_start(self):
382382
for classification in [True, False]:

0 commit comments

Comments
 (0)
Please sign in to comment.