@@ -997,10 +997,12 @@ def _tree(self, is_policy, Xtest, feature_index, *, treatment_cost=0,
997
997
if is_policy :
998
998
intrp .interpret (result .estimator , Xtest ,
999
999
sample_treatment_costs = treatment_cost )
1000
+ treat = intrp .treat (Xtest )
1000
1001
else : # no treatment cost for CATE trees
1001
1002
intrp .interpret (result .estimator , Xtest )
1003
+ treat = None
1002
1004
1003
- return intrp , result .X_transformer .get_feature_names (self .feature_names_ ), treatment_names
1005
+ return intrp , result .X_transformer .get_feature_names (self .feature_names_ ), treatment_names , treat
1004
1006
1005
1007
# TODO: it seems like it would be better to just return the tree itself rather than plot it;
1006
1008
# however, the tree can't store the feature and treatment names we compute here...
@@ -1027,18 +1029,21 @@ def plot_policy_tree(self, Xtest, feature_index, *, treatment_cost=0,
1027
1029
Confidence level of the confidence intervals displayed in the leaf nodes.
1028
1030
A (1-alpha)*100% confidence interval is displayed.
1029
1031
"""
1030
- intrp , feature_names , treatment_names = self ._tree (True , Xtest , feature_index ,
1031
- treatment_cost = treatment_cost ,
1032
- max_depth = max_depth ,
1033
- min_samples_leaf = min_samples_leaf ,
1034
- min_impurity_decrease = min_value_increase ,
1035
- alpha = alpha )
1032
+ intrp , feature_names , treatment_names , _ = self ._tree (True , Xtest , feature_index ,
1033
+ treatment_cost = treatment_cost ,
1034
+ max_depth = max_depth ,
1035
+ min_samples_leaf = min_samples_leaf ,
1036
+ min_impurity_decrease = min_value_increase ,
1037
+ alpha = alpha )
1036
1038
return intrp .plot (feature_names = feature_names , treatment_names = treatment_names )
1037
1039
1038
- def _policy_tree_string (self , Xtest , feature_index , * , treatment_cost = 0 ,
1040
+ def _policy_tree_output (self , Xtest , feature_index , * , treatment_cost = 0 ,
1039
1041
max_depth = 3 , min_samples_leaf = 2 , min_value_increase = 1e-4 , alpha = .1 ):
1040
1042
"""
1041
- Get a recommended policy tree in graphviz format as a string.
1043
+ Get a tuple of policy outputs.
1044
+
1045
+ The first item in the tuple is the recommended policy tree in graphviz format as a string.
1046
+ The second item is the recommended treatment for each sample as a list.
1042
1047
1043
1048
Parameters
1044
1049
----------
@@ -1060,18 +1065,18 @@ def _policy_tree_string(self, Xtest, feature_index, *, treatment_cost=0,
1060
1065
1061
1066
Returns
1062
1067
-------
1063
- tree : string
1064
- The policy tree represented as a graphviz string
1068
+ tree : tuple of string, list of int
1069
+ The policy tree represented as a graphviz string and the recommended treatment for each row
1065
1070
"""
1066
1071
1067
- intrp , feature_names , treatment_names = self ._tree (True , Xtest , feature_index ,
1068
- treatment_cost = treatment_cost ,
1069
- max_depth = max_depth ,
1070
- min_samples_leaf = min_samples_leaf ,
1071
- min_impurity_decrease = min_value_increase ,
1072
- alpha = alpha )
1072
+ intrp , feature_names , treatment_names , treat = self ._tree (True , Xtest , feature_index ,
1073
+ treatment_cost = treatment_cost ,
1074
+ max_depth = max_depth ,
1075
+ min_samples_leaf = min_samples_leaf ,
1076
+ min_impurity_decrease = min_value_increase ,
1077
+ alpha = alpha )
1073
1078
return intrp .export_graphviz (feature_names = feature_names ,
1074
- treatment_names = treatment_names )
1079
+ treatment_names = treatment_names ), treat . tolist ()
1075
1080
1076
1081
# TODO: it seems like it would be better to just return the tree itself rather than plot it;
1077
1082
# however, the tree can't store the feature and treatment names we compute here...
@@ -1099,11 +1104,11 @@ def plot_heterogeneity_tree(self, Xtest, feature_index, *,
1099
1104
A (1-alpha)*100% confidence interval is displayed.
1100
1105
"""
1101
1106
1102
- intrp , feature_names , treatment_names = self ._tree (False , Xtest , feature_index ,
1103
- max_depth = max_depth ,
1104
- min_samples_leaf = min_samples_leaf ,
1105
- min_impurity_decrease = min_impurity_decrease ,
1106
- alpha = alpha )
1107
+ intrp , feature_names , treatment_names , _ = self ._tree (False , Xtest , feature_index ,
1108
+ max_depth = max_depth ,
1109
+ min_samples_leaf = min_samples_leaf ,
1110
+ min_impurity_decrease = min_impurity_decrease ,
1111
+ alpha = alpha )
1107
1112
return intrp .plot (feature_names = feature_names ,
1108
1113
treatment_names = treatment_names )
1109
1114
@@ -1131,10 +1136,10 @@ def _heterogeneity_tree_string(self, Xtest, feature_index, *,
1131
1136
A (1-alpha)*100% confidence interval is displayed.
1132
1137
"""
1133
1138
1134
- intrp , feature_names , treatment_names = self ._tree (False , Xtest , feature_index ,
1135
- max_depth = max_depth ,
1136
- min_samples_leaf = min_samples_leaf ,
1137
- min_impurity_decrease = min_impurity_decrease ,
1138
- alpha = alpha )
1139
+ intrp , feature_names , treatment_names , _ = self ._tree (False , Xtest , feature_index ,
1140
+ max_depth = max_depth ,
1141
+ min_samples_leaf = min_samples_leaf ,
1142
+ min_impurity_decrease = min_impurity_decrease ,
1143
+ alpha = alpha )
1139
1144
return intrp .export_graphviz (feature_names = feature_names ,
1140
1145
treatment_names = treatment_names )
0 commit comments