Skip to content

Commit 50c640d

Browse files
authored
#775: Support multiple treatments in CausalTreeRegressor and CausalRandomForestRegressor (#852)
* Add outcome vector y preparation for multiple treatment groups
* Update CausalTreeRegressor class
* Add multiple groups support in cython part of causal trees building
* Add min_group_size parameter to control tree building
* Update python part of causal trees, keep consistent namings for tree builder
* Add an option to pass custom matplotlib axes in charts
* Update Jupyter notebooks with causal trees and forests
* Keep consistent codestyle with black
* Update validate_data arguments support for sklearn <1.6 & >=1.6
* Extend causal tree and forest tests for multiple treatment groups
* Keep consistent codestyle with black
* Fix description in causal trees notebook
* Keep consistent codestyle in tests
* Remove unused cython var in criterion header
* Add separate function for check_y_params
1 parent 6f4ec71 commit 50c640d

18 files changed

Lines changed: 5407 additions & 1137 deletions

causalml/inference/tree/_tree/_criterion.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ cdef class Criterion:
2424

2525
# Internal structures
2626
cdef const float64_t[:, ::1] y # Values of y
27-
cdef const int32_t[:] treatment # Treatment assignment: 1 for treatment, 0 for control
2827
cdef const float64_t[:] sample_weight # Sample weights
2928

3029
cdef const intp_t[:] sample_indices # Sample indices in X, y
@@ -50,7 +49,6 @@ cdef class Criterion:
5049
cdef int init(
5150
self,
5251
const float64_t[:, ::1] y,
53-
const int32_t[:] treatment,
5452
const float64_t[:] sample_weight,
5553
float64_t weighted_n_samples,
5654
const intp_t[:] sample_indices,

causalml/inference/tree/_tree/_criterion.pyx

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ cdef class Criterion:
4949
cdef int init(
5050
self,
5151
const float64_t[:, ::1] y,
52-
const int32_t[:] treatment,
5352
const float64_t[:] sample_weight,
5453
float64_t weighted_n_samples,
5554
const intp_t[:] sample_indices,
@@ -357,7 +356,6 @@ cdef class ClassificationCriterion(Criterion):
357356
cdef int init(
358357
self,
359358
const float64_t[:, ::1] y,
360-
const int32_t[:] treatment,
361359
const float64_t[:] sample_weight,
362360
float64_t weighted_n_samples,
363361
const intp_t[:] sample_indices,
@@ -871,7 +869,6 @@ cdef class RegressionCriterion(Criterion):
871869
cdef int init(
872870
self,
873871
const float64_t[:, ::1] y,
874-
const int32_t[:] treatment,
875872
const float64_t[:] sample_weight,
876873
float64_t weighted_n_samples,
877874
const intp_t[:] sample_indices,
@@ -1250,7 +1247,6 @@ cdef class MAE(RegressionCriterion):
12501247
cdef int init(
12511248
self,
12521249
const float64_t[:, ::1] y,
1253-
const int32_t[:] treatment,
12541250
const float64_t[:] sample_weight,
12551251
float64_t weighted_n_samples,
12561252
const intp_t[:] sample_indices,

causalml/inference/tree/_tree/_splitter.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ cdef class Splitter:
7070
# +1: monotonic increase
7171
cdef const int8_t[:] monotonic_cst
7272
cdef bint with_monotonic_cst
73-
cdef const int32_t[:] treatment
7473
cdef const float64_t[:] sample_weight
7574

7675
# The samples vector `samples` is maintained by the Splitter object such
@@ -94,7 +93,6 @@ cdef class Splitter:
9493
self,
9594
object X,
9695
const float64_t[:, ::1] y,
97-
const int32_t[:] treatment,
9896
const float64_t[:] sample_weight,
9997
const unsigned char[::1] missing_values_in_feature_mask,
10098
) except -1

causalml/inference/tree/_tree/_splitter.pyx

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ cdef class Splitter:
125125
self,
126126
object X,
127127
const float64_t[:, ::1] y,
128-
const int32_t[:] treatment,
129128
const float64_t[:] sample_weight,
130129
const unsigned char[::1] missing_values_in_feature_mask,
131130
) except -1:
@@ -145,9 +144,6 @@ cdef class Splitter:
145144
This is the vector of targets, or true labels, for the samples represented
146145
as a Cython memoryview.
147146
148-
treatment : ndarray, dtype=int32_t
149-
The treatment labels for each sample, represented as a Cython memoryview.
150-
151147
sample_weight : ndarray, dtype=float64_t
152148
The weights of the samples, where higher weighted samples are fit
153149
closer than lower weight samples. If not provided, all samples
@@ -194,7 +190,6 @@ cdef class Splitter:
194190

195191
self.y = y
196192

197-
self.treatment = treatment
198193
self.sample_weight = sample_weight
199194
if missing_values_in_feature_mask is not None:
200195
self.criterion.init_sum_missing()
@@ -226,7 +221,6 @@ cdef class Splitter:
226221

227222
self.criterion.init(
228223
self.y,
229-
self.treatment,
230224
self.sample_weight,
231225
self.weighted_n_samples,
232226
self.samples,
@@ -1515,11 +1509,10 @@ cdef class BestSplitter(Splitter):
15151509
self,
15161510
object X,
15171511
const float64_t[:, ::1] y,
1518-
const int32_t[:] treatment,
15191512
const float64_t[:] sample_weight,
15201513
const unsigned char[::1] missing_values_in_feature_mask,
15211514
) except -1:
1522-
Splitter.init(self, X, y, treatment, sample_weight, missing_values_in_feature_mask)
1515+
Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask)
15231516
self.partitioner = DensePartitioner(
15241517
X, self.samples, self.feature_values, missing_values_in_feature_mask
15251518
)
@@ -1546,11 +1539,10 @@ cdef class BestSparseSplitter(Splitter):
15461539
self,
15471540
object X,
15481541
const float64_t[:, ::1] y,
1549-
const int32_t[:] treatment,
15501542
const float64_t[:] sample_weight,
15511543
const unsigned char[::1] missing_values_in_feature_mask,
15521544
) except -1:
1553-
Splitter.init(self, X, y, treatment, sample_weight, missing_values_in_feature_mask)
1545+
Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask)
15541546
self.partitioner = SparsePartitioner(
15551547
X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask
15561548
)
@@ -1577,11 +1569,10 @@ cdef class RandomSplitter(Splitter):
15771569
self,
15781570
object X,
15791571
const float64_t[:, ::1] y,
1580-
const int32_t[:] treatment,
15811572
const float64_t[:] sample_weight,
15821573
const unsigned char[::1] missing_values_in_feature_mask,
15831574
) except -1:
1584-
Splitter.init(self, X, y, treatment, sample_weight, missing_values_in_feature_mask)
1575+
Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask)
15851576
self.partitioner = DensePartitioner(
15861577
X, self.samples, self.feature_values, missing_values_in_feature_mask
15871578
)
@@ -1608,11 +1599,10 @@ cdef class RandomSparseSplitter(Splitter):
16081599
self,
16091600
object X,
16101601
const float64_t[:, ::1] y,
1611-
const int32_t[:] treatment,
16121602
const float64_t[:] sample_weight,
16131603
const unsigned char[::1] missing_values_in_feature_mask,
16141604
) except -1:
1615-
Splitter.init(self, X, y, treatment, sample_weight, missing_values_in_feature_mask)
1605+
Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask)
16161606
self.partitioner = SparsePartitioner(
16171607
X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask
16181608
)

causalml/inference/tree/_tree/_tree.pxd

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,19 @@ cdef class TreeBuilder:
106106
# This class controls the various stopping criteria and the node splitting
107107
# evaluation order, e.g. depth-first or best-first.
108108

109-
cdef Splitter splitter # Splitting algorithm
109+
cdef Splitter splitter # Splitting algorithm
110110

111-
cdef intp_t min_samples_split # Minimum number of samples in an internal node
112-
cdef intp_t min_samples_leaf # Minimum number of samples in a leaf
111+
cdef intp_t min_samples_split # Minimum number of samples in an internal node
112+
cdef intp_t min_samples_leaf # Minimum number of samples in a leaf
113113
cdef float64_t min_weight_leaf # Minimum weight in a leaf
114-
cdef intp_t max_depth # Maximal tree depth
114+
cdef intp_t max_depth # Maximal tree depth
115115
cdef float64_t min_impurity_decrease # Impurity threshold for early stopping
116116

117117
cpdef build(
118118
self,
119119
Tree tree,
120120
object X,
121121
const float64_t[:, ::1] y,
122-
const int32_t[:] treatment,
123122
const float64_t[:] sample_weight=*,
124123
const unsigned char[::1] missing_values_in_feature_mask=*,
125124
)
@@ -128,7 +127,6 @@ cdef class TreeBuilder:
128127
self,
129128
object X,
130129
const float64_t[:, ::1] y,
131-
const int32_t[:] treatment,
132130
const float64_t[:] sample_weight,
133131
)
134132

causalml/inference/tree/_tree/_tree.pyx

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ cdef class TreeBuilder:
106106
Tree tree,
107107
object X,
108108
const float64_t[:, ::1] y,
109-
const int32_t[:] treatment,
110109
const float64_t[:] sample_weight=None,
111110
const unsigned char[::1] missing_values_in_feature_mask=None,
112111
):
@@ -117,7 +116,6 @@ cdef class TreeBuilder:
117116
self,
118117
object X,
119118
const float64_t[:, ::1] y,
120-
const int32_t[:] treatment,
121119
const float64_t[:] sample_weight,
122120
):
123121
"""Check input dtype, layout and format"""
@@ -142,9 +140,6 @@ cdef class TreeBuilder:
142140
if y.base.dtype != DOUBLE or not y.base.flags.contiguous:
143141
y = np.ascontiguousarray(y, dtype=DOUBLE)
144142

145-
if treatment.base.dtype != INT or not treatment.base.flags.contiguous:
146-
treatment = np.ascontiguousarray(treatment, dtype=INT)
147-
148143
if (
149144
sample_weight is not None and
150145
(
@@ -154,15 +149,16 @@ cdef class TreeBuilder:
154149
):
155150
sample_weight = np.asarray(sample_weight, dtype=DOUBLE, order="C")
156151

157-
return X, y, treatment, sample_weight
152+
return X, y, sample_weight
158153

159154
# Depth first builder ---------------------------------------------------------
160155
cdef class DepthFirstTreeBuilder(TreeBuilder):
161156
"""Build a decision tree in depth-first fashion."""
162157

163158
def __cinit__(self, Splitter splitter, intp_t min_samples_split,
164159
intp_t min_samples_leaf, float64_t min_weight_leaf,
165-
intp_t max_depth, float64_t min_impurity_decrease):
160+
intp_t max_depth, float64_t min_impurity_decrease,
161+
*args, **kwargs):
166162
self.splitter = splitter
167163
self.min_samples_split = min_samples_split
168164
self.min_samples_leaf = min_samples_leaf
@@ -175,14 +171,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
175171
Tree tree,
176172
object X,
177173
const float64_t[:, ::1] y,
178-
const int32_t[:] treatment,
179174
const float64_t[:] sample_weight=None,
180175
const unsigned char[::1] missing_values_in_feature_mask=None,
181176
):
182177
"""Build a decision tree from the training set (X, y)."""
183178

184179
# check input
185-
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)
180+
X, y, sample_weight = self._check_input(X, y, sample_weight)
186181

187182
# Initial capacity
188183
cdef intp_t init_capacity
@@ -203,7 +198,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
203198
cdef float64_t min_impurity_decrease = self.min_impurity_decrease
204199

205200
# Recursive partition (without actual recursion)
206-
splitter.init(X, y, treatment, sample_weight, missing_values_in_feature_mask)
201+
splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
207202

208203
cdef intp_t start
209204
cdef intp_t end
@@ -399,7 +394,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
399394
def __cinit__(self, Splitter splitter, intp_t min_samples_split,
400395
intp_t min_samples_leaf, min_weight_leaf,
401396
intp_t max_depth, intp_t max_leaf_nodes,
402-
float64_t min_impurity_decrease):
397+
float64_t min_impurity_decrease,
398+
*args, **kwargs):
403399
self.splitter = splitter
404400
self.min_samples_split = min_samples_split
405401
self.min_samples_leaf = min_samples_leaf
@@ -413,21 +409,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
413409
Tree tree,
414410
object X,
415411
const float64_t[:, ::1] y,
416-
const int32_t[:] treatment,
417412
const float64_t[:] sample_weight=None,
418413
const unsigned char[::1] missing_values_in_feature_mask=None,
419414
):
420415
"""Build a decision tree from the training set (X, y)."""
421416

422417
# check input
423-
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)
418+
X, y, sample_weight = self._check_input(X, y, sample_weight)
424419

425420
# Parameters
426421
cdef Splitter splitter = self.splitter
427422
cdef intp_t max_leaf_nodes = self.max_leaf_nodes
428423

429424
# Recursive partition (without actual recursion)
430-
splitter.init(X, y, treatment, sample_weight, missing_values_in_feature_mask)
425+
splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
431426

432427
cdef vector[FrontierRecord] frontier
433428
cdef FrontierRecord record

causalml/inference/tree/causal/_builder.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66

77
from .._tree._tree cimport Node, Tree, TreeBuilder
88
from .._tree._splitter cimport Splitter, SplitRecord
9-
from .._tree._tree cimport intp_t, int32_t, float64_t
9+
from .._tree._typedefs cimport intp_t, int32_t, int64_t, float32_t, float64_t
1010
from .._tree._tree cimport FrontierRecord, StackRecord
1111
from .._tree._tree cimport ParentInfo, _init_parent_record

0 commit comments

Comments
 (0)