Describe the bug
After #903 fixed the import error (#902), CausalRandomForestRegressor.fit() still fails on scikit-learn 1.9.
causalml/inference/tree/causal/causalforest.py vendors scikit-learn's BaseForest._fit() internals (to thread the treatment vector through _parallel_build_trees), so it calls several private forest helpers directly. scikit-learn 1.9 (scikit-learn#31529) made sample_weight a required argument on these helpers:
BaseForest._validate_y_class_weight(y) → (y, sample_weight)
_generate_sample_indices(random_state, n_samples, n_samples_bootstrap) → (..., sample_weight)
_get_n_samples_bootstrap(n_samples, max_samples) → (..., sample_weight)
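One way to confirm which signature an installed scikit-learn exposes is to inspect the helper directly. The following is a minimal sketch using only the standard library; `requires_param` and the two stand-in functions are hypothetical names for illustration and are not part of causalml or scikit-learn.

```python
import inspect


def requires_param(func, name="sample_weight"):
    """Return True if `func` has a parameter `name` with no default value."""
    param = inspect.signature(func).parameters.get(name)
    return param is not None and param.default is inspect.Parameter.empty


# Stand-ins that mimic the old and new helper signatures listed above.
def old_get_n_samples_bootstrap(n_samples, max_samples):
    return n_samples


def new_get_n_samples_bootstrap(n_samples, max_samples, sample_weight):
    return n_samples


print(requires_param(old_get_n_samples_bootstrap))  # False
print(requires_param(new_get_n_samples_bootstrap))  # True
```

Against a real install, the same check can be pointed at the private helpers in sklearn.ensemble._forest. Because they are private, their location and signature may change again.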
To Reproduce
pip install -U causalml
pip install 'scikit-learn==1.9.0'
from causalml.dataset import synthetic_data
from causalml.inference.tree import CausalRandomForestRegressor
y, X, w, *_ = synthetic_data(mode=1, n=500, p=5, sigma=1.0)
CausalRandomForestRegressor(control_name=0).fit(X=X, treatment=w, y=y)
TypeError: BaseForest._validate_y_class_weight() missing 1 required positional argument: 'sample_weight'
.../causalml/inference/tree/causal/causalforest.py (in _fit)
Expected behavior
fit() / predict() work on scikit-learn 1.9 with results unchanged from earlier versions.
Fix
Version-guard the three calls, passing sample_weight=None on scikit-learn >= 1.9. On 1.9, _generate_sample_indices(..., None) runs the same uniform randint draw as before and _get_n_samples_bootstrap(..., None) keeps the old unweighted semantics, so causal-forest predictions are bit-for-bit identical across versions (the vendored _parallel_build_trees already applies the user's sample_weight). PR incoming.
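The guard described above could look roughly like the sketch below. This is not the actual patch: `needs_sample_weight` and `extra_helper_args` are hypothetical helper names, and the sketch assumes that on >= 1.9 the helpers take sample_weight as their last positional argument, as the signatures in this report suggest.

```python
import re


def needs_sample_weight(sklearn_version, minimum=(1, 9)):
    """Return True if the version string is >= `minimum` (major, minor)."""
    match = re.match(r"(\d+)\.(\d+)", sklearn_version)
    if match is None:
        raise ValueError(f"unrecognized version string: {sklearn_version!r}")
    return (int(match.group(1)), int(match.group(2))) >= minimum


def extra_helper_args(sklearn_version):
    """Extra positional args to append to the three private helper calls."""
    # Assumption: None selects the unweighted behavior on >= 1.9, while
    # older versions do not accept the argument at all.
    return (None,) if needs_sample_weight(sklearn_version) else ()


print(extra_helper_args("1.8.0"))  # ()
print(extra_helper_args("1.9.0"))  # (None,)
```

In the vendored _fit(), a call would then take a form such as self._validate_y_class_weight(y, *extra_helper_args(sklearn.__version__)), with the same treatment for the other two helpers. Comparing a parsed (major, minor) tuple avoids the string-comparison trap where "1.10" sorts before "1.9".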
Note: calculate_error() has a separate, independent incompatibility via forestci, tracked in #906.