Skip to content

Commit

Permalink
fix n_features_in setting
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 5, 2024
1 parent 33fb5b6 commit 6689faa
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def validate_data(
# NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
if isinstance(y, str) and y == "no_validation":
X = check_array(X, accept_sparse=accept_sparse, force_all_finite=ensure_all_finite)
return X
return X # noqa: RET504

# if we reach here, we're validating features and labels
X, y = check_X_y(
Expand Down
37 changes: 30 additions & 7 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,32 @@ def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarra
return weight


def _num_features_for_raw_input(X: _LGBM_ScikitMatrixLike) -> int:
    """Get number of features from raw input data.

    Some ``scikit-learn`` versions accomplish this with a private function
    ``sklearn.utils.validation._num_features()``. This is here to avoid depending on that.

    Parameters
    ----------
    X : list, or any object exposing ``.shape`` or ``__array__()``
        Raw (not yet validated) feature data.

    Returns
    -------
    n_features : int
        ``len(X[0])`` for a list of rows, ``X.shape[1]`` for input with 2 or more dimensions,
        and ``1`` for any input with fewer than 2 dimensions (including an empty list
        or a flat list of scalars).
    """
    if isinstance(X, list):
        # a list is assumed to hold one item per row of training data, with all rows
        # having the same number of features.
        #
        # An empty list or a flat list of scalars is 1-dimensional input: report 1, exactly like
        # the 1-dimensional array case below, instead of failing here with an unhelpful
        # IndexError / TypeError. Rejecting such input is left to whatever validation the caller runs.
        if not X or not hasattr(X[0], "__len__"):
            return 1
        return len(X[0])

    # scikit-learn estimator checks error out if .shape is accessed unconditionally...
    # but allow hard-coding an assumption that anything that isn't a list and doesn't have .shape
    # must be convertible to something following the Array API via a __array__() method
    X_shape = X.shape if hasattr(X, "shape") else X.__array__().shape

    # this condition accounts for training on a 0- or 1-dimensional input,
    # where there is no second axis to read a feature count from
    if len(X_shape) < 2:
        return 1
    return X_shape[1]


class _ObjectiveFunctionWrapper:
"""Proxy class for objective function."""

Expand Down Expand Up @@ -906,13 +932,10 @@ def fit(
params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
params["metric"] = [metric for metric in params["metric"] if metric is not None]

# this needs to be set before calling _LGBMValidateData(), as that function uses
# self.n_features_in_ (the corresponding public attribute) to check that the input has
# the expected number of features
if isinstance(X, list):
self._n_features_in = len(X[0])
else:
self._n_features_in = X.shape[1]
# `sklearn.utils.validation.validate_data()` expects self.n_features_in_ to already be set by the
# time it's called (if you call it with reset=True like LightGBM does), and the scikit-learn
# estimator checks complain if X.shape is accessed unconditionally
self._n_features_in = _num_features_for_raw_input(X)

if not isinstance(X, (pd_DataFrame, dt_DataTable)):
_X, _y = _LGBMValidateData(
Expand Down

0 comments on commit 6689faa

Please sign in to comment.