Skip to content

Commit 5b149cb

Browse files
md-shafiul-alam md.shafiul.alam
and
md.shafiul.alam
authored
Optimize and fix sample weight checks for Kmeans predict (uxlfoundation#2042)
* sample weight modification
* refactor
* ci fix
* refactor
* refactor
* ci fix
* add back the checks with version check
* refactor
* ci fix
* lint
* ci fix

---------

Co-authored-by: md.shafiul.alam <[email protected]>
1 parent 5df8206 commit 5b149cb

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

sklearnex/cluster/k_means.py

+25-16
Original file line number | Diff line number | Diff line change
@@ -109,14 +109,7 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None):
109109
_is_csr(X) and daal_check_version((2024, "P", 700))
110110
) or not issparse(X)
111111

112-
_acceptable_sample_weights = True
113-
if sample_weight is not None or not isinstance(sample_weight, numbers.Number):
114-
sample_weight = _check_sample_weight(
115-
sample_weight, X, dtype=X.dtype if hasattr(X, "dtype") else None
116-
)
117-
_acceptable_sample_weights = np.allclose(
118-
sample_weight, np.ones_like(sample_weight)
119-
)
112+
_acceptable_sample_weights = self._validate_sample_weight(sample_weight, X)
120113

121114
patching_status.and_conditions(
122115
[
@@ -127,7 +120,7 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None):
127120
(correct_count, "n_clusters is smaller than number of samples"),
128121
(
129122
_acceptable_sample_weights,
130-
"oneDAL doesn't support sample_weight, either None or ones are acceptable",
123+
"oneDAL doesn't support sample_weight. Accepted options are None, constant, or equal weights.",
131124
),
132125
(
133126
is_data_supported,
@@ -161,6 +154,9 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
161154
X,
162155
accept_sparse="csr",
163156
dtype=[np.float64, np.float32],
157+
order="C",
158+
copy=self.copy_x,
159+
accept_large_sparse=False,
164160
)
165161

166162
if sklearn_check_version("1.2"):
@@ -176,6 +172,22 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
176172

177173
self._save_attributes()
178174

175+
def _validate_sample_weight(self, sample_weight, X):
    """Report whether ``sample_weight`` is acceptable for the oneDAL path.

    ``None`` and scalar weights are accepted outright. Any other value is
    passed through ``_check_sample_weight`` together with ``X`` and is
    accepted only when every resulting entry equals the first one.

    Parameters
    ----------
    sample_weight : None, number, or array-like
        Weights supplied by the caller.
    X : array-like
        Data the weights belong to; its ``dtype`` (when present) is
        forwarded to ``_check_sample_weight``.

    Returns
    -------
    bool
        ``True`` when the weights are ``None``, a scalar, or all equal.
    """
    # No weights, or a single scalar weight: nothing further to check.
    if sample_weight is None or isinstance(sample_weight, numbers.Number):
        return True

    weights = _check_sample_weight(
        sample_weight,
        X,
        dtype=getattr(X, "dtype", None),
    )
    # Exact comparison against the first entry: only uniform weights pass.
    return bool(np.all(weights == weights[0]))
190+
179191
def _onedal_predict_supported(self, method_name, X, sample_weight=None):
180192
class_name = self.__class__.__name__
181193
is_data_supported = (
@@ -194,12 +206,9 @@ def _onedal_predict_supported(self, method_name, X, sample_weight=None):
194206
)
195207

196208
_acceptable_sample_weights = True
197-
if sample_weight is not None or not isinstance(sample_weight, numbers.Number):
198-
sample_weight = _check_sample_weight(
199-
sample_weight, X, dtype=X.dtype if hasattr(X, "dtype") else None
200-
)
201-
_acceptable_sample_weights = np.allclose(
202-
sample_weight, np.ones_like(sample_weight)
209+
if not sklearn_check_version("1.5"):
210+
_acceptable_sample_weights = self._validate_sample_weight(
211+
sample_weight, X
203212
)
204213

205214
patching_status.and_conditions(
@@ -214,7 +223,7 @@ def _onedal_predict_supported(self, method_name, X, sample_weight=None):
214223
),
215224
(
216225
_acceptable_sample_weights,
217-
"oneDAL doesn't support sample_weight, None or ones are acceptable",
226+
"oneDAL doesn't support sample_weight. Acceptable options are None, constant, or equal weights.",
218227
),
219228
]
220229
)

0 commit comments

Comments
 (0)