Allow use of OptunaSearchCV with cross_val_predict. #173

@yu9824

Expected behavior

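cross_val_predict should accept an OptunaSearchCV instance as its estimator. With scikit-learn >= 1.4.0 the call fails instead: the validate_params check on cross_val_predict rejects the object with an InvalidParameterError before any fitting happens.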

https://github.com/scikit-learn/scikit-learn/blob/46b5f541138458803e39f9ce5810878849e4ecf7/sklearn/model_selection/_validation.py#L1035-L1059
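One plausible mechanism, offered as an assumption rather than something this report confirms: if OptunaSearchCV exposes predict as a property that raises a not-fitted error until fit has been called, then a check of the form callable(getattr(obj, "predict", None)) sees None on an unfitted instance and concludes the method is missing. The sketch below reproduces that behavior with the standard library only. PropertyPredictSearch, the local NotFittedError, and has_methods are hypothetical stand-ins, not the real Optuna or scikit-learn code.

```python
class NotFittedError(ValueError, AttributeError):
    """Local stand-in for a not-fitted error that is also an AttributeError."""


class PropertyPredictSearch:
    """Hypothetical search object whose `predict` is a property, not a method."""

    def __init__(self):
        self.best_estimator_ = None

    def fit(self, X, y=None):
        # Pretend the search finished; the "best estimator" predicts zeros.
        self.best_estimator_ = lambda rows: [0.0 for _ in rows]
        return self

    @property
    def predict(self):
        # Accessing the attribute before fit raises instead of returning a callable.
        if self.best_estimator_ is None:
            raise NotFittedError("call fit first")
        return self.best_estimator_


def has_methods(obj, names):
    # Assumed shape of the validation: every name must resolve to a callable.
    # getattr with a default swallows AttributeError, so a raising property
    # looks exactly like a missing attribute.
    return all(callable(getattr(obj, name, None)) for name in names)


search = PropertyPredictSearch()
before_fit = has_methods(search, ["fit", "predict"])
search.fit([[0.0]], [0.0])
after_fit = has_methods(search, ["fit", "predict"])
print(before_fit, after_fit)
```

If this is indeed the cause, defining predict (and similar delegating attributes) as regular methods that perform the fitted check when called, rather than when accessed, would let the unfitted object pass the check.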

Environment

  • Optuna version: 3.5.0
  • Optuna Integration version: 3.5.0
  • Python version: 3.11.6
  • OS: macOS-14.7-arm64-arm-64bit
  • scikit-learn version: 1.4.0

Error messages, stack traces, or logs

---------------------------------------------------------------------------
InvalidParameterError                     Traceback (most recent call last)
Cell In[1], line 15
      6 X, y = make_regression(n_samples=100, n_features=10, bias=1, random_state=334)
      8 ocv = optuna.integration.OptunaSearchCV(
      9     PLSRegression(),
     10     param_distributions=dict(
   (...)
     13     cv=5,
     14 )
---> 15 y_oof = cross_val_predict(ocv, X, y, cv=5)

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:203, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    200 to_ignore += ["self", "cls"]
    201 params = {k: v for k, v in params.arguments.items() if k not in to_ignore}
--> 203 validate_parameter_constraints(
    204     parameter_constraints, params, caller_name=func.__qualname__
    205 )
    207 try:
    208     with config_context(
    209         skip_parameter_validation=(
    210             prefer_skip_nested_validation or global_skip_validation
    211         )
    212     ):

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:95, in validate_parameter_constraints(parameter_constraints, params, caller_name)
     89 else:
     90     constraints_str = (
     91         f"{', '.join([str(c) for c in constraints[:-1]])} or"
     92         f" {constraints[-1]}"
     93     )
---> 95 raise InvalidParameterError(
     96     f"The {param_name!r} parameter of {caller_name} must be"
     97     f" {constraints_str}. Got {param_val!r} instead."
     98 )

InvalidParameterError: The 'estimator' parameter of cross_val_predict must be an object implementing 'fit' and 'predict'. Got OptunaSearchCV(cv=5, estimator=PLSRegression(), n_jobs=1,
               param_distributions={'n_components': IntDistribution(high=10, log=False, low=1, step=1)}) instead.

Steps to reproduce

import optuna
from sklearn.cross_decomposition import PLSRegression
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_predict

X, y = make_regression(n_samples=100, n_features=10, bias=1, random_state=334)

ocv = optuna.integration.OptunaSearchCV(
    PLSRegression(),
    param_distributions=dict(
        n_components=optuna.distributions.IntDistribution(1, 10)
    ),
    cv=5,
)
y_oof = cross_val_predict(ocv, X, y, cv=5)



Labels: bug, v5
