CausalRandomForestRegressor.fit() fails on scikit-learn 1.9 (forest sampler signature change) #905

@jeongyoonlee

Describe the bug

After #903 fixed the import error (#902), CausalRandomForestRegressor.fit() still fails on scikit-learn 1.9.

causalml/inference/tree/causal/causalforest.py vendors scikit-learn's BaseForest._fit() internals (to thread the treatment vector through _parallel_build_trees), so it calls several private forest helpers directly. scikit-learn 1.9 (scikit-learn#31529) made sample_weight a required argument on these helpers:

  • BaseForest._validate_y_class_weight(y) → (y, sample_weight)
  • _generate_sample_indices(random_state, n_samples, n_samples_bootstrap) → (..., sample_weight)
  • _get_n_samples_bootstrap(n_samples, max_samples) → (..., sample_weight)
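
To check which form of a helper an installed scikit-learn exposes, one option is to inspect its signature. The sketch below is a diagnostic only and is not part of the proposed fix. requires_param and the two stand-in functions are hypothetical names that mirror the old and new signatures listed above, so the example runs without scikit-learn installed.

```python
import inspect


def requires_param(func, name):
    """Return True if `func` declares a parameter `name` that has no default."""
    param = inspect.signature(func).parameters.get(name)
    return param is not None and param.default is inspect.Parameter.empty


# Hypothetical stand-ins mirroring the pre-1.9 and 1.9 helper signatures.
def old_get_n_samples_bootstrap(n_samples, max_samples):
    return n_samples


def new_get_n_samples_bootstrap(n_samples, max_samples, sample_weight):
    return n_samples


print(requires_param(old_get_n_samples_bootstrap, "sample_weight"))
print(requires_param(new_get_n_samples_bootstrap, "sample_weight"))
```

The same check could presumably be pointed at the real private helpers in the installed scikit-learn, though private names can move between releases.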

To Reproduce

pip install -U causalml
pip install 'scikit-learn==1.9.0'

from causalml.dataset import synthetic_data
from causalml.inference.tree import CausalRandomForestRegressor

y, X, w, *_ = synthetic_data(mode=1, n=500, p=5, sigma=1.0)
CausalRandomForestRegressor(control_name=0).fit(X=X, treatment=w, y=y)

TypeError: BaseForest._validate_y_class_weight() missing 1 required positional argument: 'sample_weight'
  .../causalml/inference/tree/causal/causalforest.py  (in _fit)

Expected behavior

fit() / predict() work on scikit-learn 1.9 with results unchanged from earlier versions.

Fix

Version-guard the three calls, passing sample_weight=None on scikit-learn >= 1.9. On 1.9, _generate_sample_indices(..., None) runs the same uniform randint draw as before and _get_n_samples_bootstrap(..., None) keeps the old unweighted semantics, so causal-forest predictions are bit-for-bit identical across versions. The user's sample_weight is still honored, because the vendored _parallel_build_trees already applies it. PR incoming.
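
A minimal sketch of what such a version guard could look like, not the actual patch. needs_sample_weight, call_forest_helper and the two stand-in helpers are hypothetical names for illustration. The version string is passed in explicitly so the example runs without scikit-learn; in causalforest.py it would presumably come from sklearn.__version__.

```python
import re


def needs_sample_weight(sklearn_version):
    """Return True for scikit-learn >= 1.9, where the helpers require sample_weight."""
    match = re.match(r"(\d+)\.(\d+)", sklearn_version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {sklearn_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 9)


def call_forest_helper(helper, *args, sklearn_version):
    """Call a private forest helper, adding sample_weight=None only on >= 1.9."""
    if needs_sample_weight(sklearn_version):
        return helper(*args, sample_weight=None)
    return helper(*args)


# Hypothetical stand-ins for the old and new helper signatures.
def old_helper(n_samples, max_samples):
    return ("old", n_samples, max_samples)


def new_helper(n_samples, max_samples, sample_weight):
    return ("new", n_samples, max_samples, sample_weight)


print(call_forest_helper(old_helper, 500, None, sklearn_version="1.8.0"))
print(call_forest_helper(new_helper, 500, None, sklearn_version="1.9.0"))
```

Comparing a parsed (major, minor) tuple avoids the string-comparison trap where "1.10" sorts before "1.9".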

Note: calculate_error() has a separate, independent incompatibility via forestci — tracked in #906.
