Skip to content

Commit

Permalink
fix return type
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 5, 2024
1 parent 6689faa commit 9a05670
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ def validate_data(
# trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
**ignored_kwargs,
):
no_val_y = isinstance(y, str) and y == "no_validation"

# NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
if isinstance(y, str) and y == "no_validation":
if no_val_y:
X = check_array(X, accept_sparse=accept_sparse, force_all_finite=ensure_all_finite)
return X # noqa: RET504

# if we reach here, we're validating features and labels
X, y = check_X_y(
X,
y,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)
else:
X, y = check_X_y(
X,
y,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)

# this only needs to be updated at fit() time
_estimator._n_features_in = X.shape[1]
# this only needs to be updated at fit() time
_estimator._n_features_in = X.shape[1]

# raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
n_features = X.shape[1]
Expand All @@ -70,7 +70,10 @@ def validate_data(
f"is expecting {_estimator._n_features} features as input."
)

return X, y
if no_val_y:
return X
else:
return X, y

SKLEARN_INSTALLED = True
_LGBMBaseCrossValidator = BaseCrossValidator
Expand Down

0 comments on commit 9a05670

Please sign in to comment.