@@ -147,6 +147,32 @@ def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarra
147
147
return weight
148
148
149
149
150
def _num_features_for_raw_input(X: _LGBM_ScikitMatrixLike) -> int:
    """Return the number of feature columns in raw, not-yet-validated input ``X``.

    Some ``scikit-learn`` versions expose this through the private helper
    ``sklearn.utils.validation._num_features()``. The logic is re-implemented
    here so that this module does not depend on that private function.

    Parameters
    ----------
    X : list, object with a ``.shape`` attribute, or object with ``__array__()``
        Raw input data.

    Returns
    -------
    int
        ``len(X[0])`` for a list, ``1`` for a 1-dimensional shape,
        otherwise the second entry of the shape.
    """
    # A list is taken to be a list of rows (one inner sequence per sample), with
    # every row holding the same number of features, so the first row is enough.
    if isinstance(X, list):
        first_row = X[0]
        return len(first_row)

    # scikit-learn estimator checks error out when .shape is accessed unconditionally,
    # so only read it when it exists. Anything that is neither a list nor has .shape
    # is assumed to be convertible to an array through its __array__() method.
    dims = X.shape if hasattr(X, "shape") else X.__array__().shape

    # A 1-dimensional input is treated as a single feature.
    return 1 if len(dims) == 1 else dims[1]
174
+
175
+
150
176
class _ObjectiveFunctionWrapper :
151
177
"""Proxy class for objective function."""
152
178
@@ -906,13 +932,10 @@ def fit(
906
932
params ["metric" ] = [e for e in eval_metrics_builtin if e not in params ["metric" ]] + params ["metric" ]
907
933
params ["metric" ] = [metric for metric in params ["metric" ] if metric is not None ]
908
934
909
- # this needs to be set before calling _LGBMValidateData(), as that function uses
910
- # self.n_features_in_ (the corresponding public attribute) to check that the input has
911
- # the expected number of features
912
- if isinstance (X , list ):
913
- self ._n_features_in = len (X [0 ])
914
- else :
915
- self ._n_features_in = X .shape [1 ]
935
+ # `sklearn.utils.validation.validate_data()` expects self.n_features_in_ to already be set by the
936
+ # time it's called (if you call it with reset=True like LightGBM does), and the scikit-learn
937
+ # estimator checks complain if X.shape is accessed unconditionally
938
+ self ._n_features_in = _num_features_for_raw_input (X )
916
939
917
940
if not isinstance (X , (pd_DataFrame , dt_DataTable )):
918
941
_X , _y = _LGBMValidateData (
0 commit comments