Skip to content

Commit bd32df6

Browse files
committed
Add separate function for check_y_params
1 parent a4d4e3f commit bd32df6

2 files changed

Lines changed: 16 additions & 9 deletions

File tree

causalml/inference/tree/causal/_tree.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,19 @@
3131
}
3232

3333

34+
def get_check_y_params() -> dict:
35+
"""
36+
Prepares flags depending on the scikit-learn version.
37+
38+
Returns: check_y_params
39+
"""
40+
if Version(sklearn_version) >= Version("1.6"):
41+
check_y_params = dict(ensure_2d=False, dtype=None, ensure_all_finite=False)
42+
else:
43+
check_y_params = dict(ensure_2d=False, dtype=None, force_all_finite=False)
44+
return check_y_params
45+
46+
3447
class BaseCausalDecisionTree(BaseDecisionTree):
3548
"""
3649
Modified base class BaseDecisionTree for causal trees
@@ -66,14 +79,7 @@ def fit(
6679
# Need to validate separately here.
6780
# We can't pass multi_ouput=True because that would allow y to be csr.
6881
check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
69-
if Version(sklearn_version) >= Version("1.6"):
70-
check_y_params = dict(
71-
ensure_2d=False, dtype=None, ensure_all_finite=False
72-
)
73-
else:
74-
check_y_params = dict(
75-
ensure_2d=False, dtype=None, force_all_finite=False
76-
)
82+
check_y_params = get_check_y_params()
7783
X, y = validate_data(
7884
self, X, y, validate_separately=(check_X_params, check_y_params)
7985
)

causalml/inference/tree/causal/causalforest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sklearn.ensemble._forest import _generate_sample_indices, _get_n_samples_bootstrap
2020

2121
from .causaltree import CausalTreeRegressor
22+
from ._tree import get_check_y_params
2223

2324
try:
2425
from packaging.version import parse as Version
@@ -274,7 +275,7 @@ def _fit(
274275
if issparse(y):
275276
raise ValueError("sparse multilabel-indicator for y is not supported.")
276277
check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
277-
check_y_params = dict(ensure_2d=False, dtype=None, ensure_all_finite=False)
278+
check_y_params = get_check_y_params()
278279
X, y = validate_data(
279280
self,
280281
X,

0 commit comments

Comments
 (0)