@@ -109,14 +109,7 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None):
109
109
_is_csr (X ) and daal_check_version ((2024 , "P" , 700 ))
110
110
) or not issparse (X )
111
111
112
- _acceptable_sample_weights = True
113
- if sample_weight is not None or not isinstance (sample_weight , numbers .Number ):
114
- sample_weight = _check_sample_weight (
115
- sample_weight , X , dtype = X .dtype if hasattr (X , "dtype" ) else None
116
- )
117
- _acceptable_sample_weights = np .allclose (
118
- sample_weight , np .ones_like (sample_weight )
119
- )
112
+ _acceptable_sample_weights = self ._validate_sample_weight (sample_weight , X )
120
113
121
114
patching_status .and_conditions (
122
115
[
@@ -127,7 +120,7 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None):
127
120
(correct_count , "n_clusters is smaller than number of samples" ),
128
121
(
129
122
_acceptable_sample_weights ,
130
- "oneDAL doesn't support sample_weight, either None or ones are acceptable " ,
123
+ "oneDAL doesn't support sample_weight. Accepted options are None, constant, or equal weights. " ,
131
124
),
132
125
(
133
126
is_data_supported ,
@@ -161,6 +154,9 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
161
154
X ,
162
155
accept_sparse = "csr" ,
163
156
dtype = [np .float64 , np .float32 ],
157
+ order = "C" ,
158
+ copy = self .copy_x ,
159
+ accept_large_sparse = False ,
164
160
)
165
161
166
162
if sklearn_check_version ("1.2" ):
@@ -176,6 +172,22 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
176
172
177
173
self ._save_attributes ()
178
174
175
+ def _validate_sample_weight (self , sample_weight , X ):
176
+ if sample_weight is None :
177
+ return True
178
+ elif isinstance (sample_weight , numbers .Number ):
179
+ return True
180
+ else :
181
+ sample_weight = _check_sample_weight (
182
+ sample_weight ,
183
+ X ,
184
+ dtype = X .dtype if hasattr (X , "dtype" ) else None ,
185
+ )
186
+ if np .all (sample_weight == sample_weight [0 ]):
187
+ return True
188
+ else :
189
+ return False
190
+
179
191
def _onedal_predict_supported (self , method_name , X , sample_weight = None ):
180
192
class_name = self .__class__ .__name__
181
193
is_data_supported = (
@@ -194,12 +206,9 @@ def _onedal_predict_supported(self, method_name, X, sample_weight=None):
194
206
)
195
207
196
208
_acceptable_sample_weights = True
197
- if sample_weight is not None or not isinstance (sample_weight , numbers .Number ):
198
- sample_weight = _check_sample_weight (
199
- sample_weight , X , dtype = X .dtype if hasattr (X , "dtype" ) else None
200
- )
201
- _acceptable_sample_weights = np .allclose (
202
- sample_weight , np .ones_like (sample_weight )
209
+ if not sklearn_check_version ("1.5" ):
210
+ _acceptable_sample_weights = self ._validate_sample_weight (
211
+ sample_weight , X
203
212
)
204
213
205
214
patching_status .and_conditions (
@@ -214,7 +223,7 @@ def _onedal_predict_supported(self, method_name, X, sample_weight=None):
214
223
),
215
224
(
216
225
_acceptable_sample_weights ,
217
- "oneDAL doesn't support sample_weight, None or ones are acceptable " ,
226
+ "oneDAL doesn't support sample_weight. Acceptable options are None, constant, or equal weights. " ,
218
227
),
219
228
]
220
229
)
0 commit comments