
Commit 6689faa

fix n_features_in setting
1 parent 33fb5b6 commit 6689faa

File tree

2 files changed: +31 -8 lines changed


python-package/lightgbm/compat.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def validate_data(
     # NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
     if isinstance(y, str) and y == "no_validation":
         X = check_array(X, accept_sparse=accept_sparse, force_all_finite=ensure_all_finite)
-        return X
+        return X  # noqa: RET504

     # if we reach here, we're validating features and labels
     X, y = check_X_y(
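
For context on the `# noqa: RET504` added above: ruff's RET504 rule flags an assignment to a variable that is immediately returned. A minimal sketch of the pattern, not taken from LightGBM's code:

    def normalized(raw: str) -> str:
        # RET504 would suggest `return raw.strip()`; keeping the intermediate
        # variable is a readability choice, so the lint is silenced instead.
        cleaned = raw.strip()
        return cleaned  # noqa: RET504

Here the explicit `X = check_array(...)` assignment stays, and the lint is suppressed on the return line.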

python-package/lightgbm/sklearn.py

Lines changed: 30 additions & 7 deletions
@@ -147,6 +147,32 @@ def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]:
     return weight


+def _num_features_for_raw_input(X: _LGBM_ScikitMatrixLike) -> int:
+    """Get number of features from raw input data.
+
+    Some ``scikit-learn`` versions accomplish this with a private function
+    ``sklearn.utils.validation._num_features()``. This is here to avoid depending on that.
+    """
+    # if a list, assume a list of lists where each item is one row of training data,
+    # and all rows have the same number of features
+    if isinstance(X, list):
+        return len(X[0])
+
+    # scikit-learn estimator checks error out if .shape is accessed unconditionally...
+    # but allow hard-coding an assumption that anything that isn't a list and doesn't have .shape
+    # must be convertible to something following the Array API via a __array__() method
+    if hasattr(X, "shape"):
+        X_shape = X.shape
+    else:
+        X_shape = X.__array__().shape
+
+    # this condition accounts for training on a 1-dimensional input
+    if len(X_shape) == 1:
+        return 1
+    else:
+        return X_shape[1]
+
+
 class _ObjectiveFunctionWrapper:
     """Proxy class for objective function."""

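
To illustrate what the new helper returns for the input types it handles, here is a hypothetical usage sketch (the inputs are made up; the helper is private to lightgbm.sklearn):

    import numpy as np
    from lightgbm.sklearn import _num_features_for_raw_input  # private helper added in this commit

    # list of lists: each inner list is one row, so the width of the first row is used
    _num_features_for_raw_input([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # -> 3

    # anything with a .shape attribute, e.g. a 2-D numpy array
    _num_features_for_raw_input(np.zeros((100, 5)))  # -> 5

    # 1-dimensional input is treated as a single feature
    _num_features_for_raw_input(np.zeros(100))  # -> 1

Inputs without .shape but with an __array__() method are first materialized via __array__() and then handled the same way.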

@@ -906,13 +932,10 @@ def fit(
         params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
         params["metric"] = [metric for metric in params["metric"] if metric is not None]

-        # this needs to be set before calling _LGBMValidateData(), as that function uses
-        # self.n_features_in_ (the corresponding public attribute) to check that the input has
-        # the expected number of features
-        if isinstance(X, list):
-            self._n_features_in = len(X[0])
-        else:
-            self._n_features_in = X.shape[1]
+        # `sklearn.utils.validation.validate_data()` expects self.n_features_in_ to already be set by the
+        # time it's called (if you call it with reset=True like LightGBM does), and the scikit-learn
+        # estimator checks complain if X.shape is accessed unconditionally
+        self._n_features_in = _num_features_for_raw_input(X)

         if not isinstance(X, (pd_DataFrame, dt_DataTable)):
             _X, _y = _LGBMValidateData(
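
The private `_n_features_in` attribute set above backs the public `n_features_in_` attribute that scikit-learn's `validate_data()` and estimator checks read. A minimal sketch of that relationship, simplified and not the actual LGBMModel implementation:

    class SketchEstimator:
        def fit(self, X, y):
            # set the private attribute before validation; per the comment in the diff above,
            # validate_data() expects n_features_in_ to already be set when called with reset=True
            self._n_features_in = _num_features_for_raw_input(X)
            # ... validation and training would follow here ...
            return self

        @property
        def n_features_in_(self) -> int:
            # public scikit-learn attribute, backed by the private one set in fit()
            return self._n_features_in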
