|
17 | 17 | import warnings
|
18 | 18 |
|
19 | 19 | import numpy as np
|
| 20 | +from scipy.sparse import issparse |
20 | 21 | from sklearn.base import BaseEstimator
|
21 | 22 | from sklearn.utils import check_array
|
22 | 23 | from sklearn.utils.validation import _check_sample_weight
|
23 | 24 |
|
24 | 25 | from daal4py.sklearn._n_jobs_support import control_n_jobs
|
25 |
| -from daal4py.sklearn._utils import sklearn_check_version |
| 26 | +from daal4py.sklearn._utils import daal_check_version, sklearn_check_version |
26 | 27 | from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
|
| 28 | +from onedal.utils import _is_csr |
27 | 29 |
|
28 | 30 | from .._device_offload import dispatch
|
29 | 31 | from .._utils import IntelEstimator, PatchingConditionsChain
|
@@ -62,13 +64,13 @@ class BasicStatistics(IntelEstimator, BaseEstimator):
|
62 | 64 | mean_ : ndarray of shape (n_features,)
|
63 | 65 | Mean of each feature over all samples.
|
64 | 66 | variance_ : ndarray of shape (n_features,)
|
65 |
| - Variance of each feature over all samples. |
| 67 | + Variance of each feature over all samples. Bessel's correction is used. |
66 | 68 | variation_ : ndarray of shape (n_features,)
|
67 |
| - Variation of each feature over all samples. |
| 69 | + Variation of each feature over all samples. Bessel's correction is used. |
68 | 70 | sum_squares_ : ndarray of shape (n_features,)
|
69 | 71 | Sum of squares for each feature over all samples.
|
70 | 72 | standard_deviation_ : ndarray of shape (n_features,)
|
71 |
| - Standard deviation of each feature over all samples. |
| 73 | + Standard deviation of each feature over all samples. Bessel's correction is used. |
72 | 74 | sum_squares_centered_ : ndarray of shape (n_features,)
|
73 | 75 | Centered sum of squares for each feature over all samples.
|
74 | 76 | second_order_raw_moment_ : ndarray of shape (n_features,)
|
@@ -166,21 +168,50 @@ def __getattr__(self, attr):
|
166 | 168 | f"'{self.__class__.__name__}' object has no attribute '{attr}'"
|
167 | 169 | )
|
168 | 170 |
|
169 |
| - def _onedal_supported(self, method_name, *data): |
| 171 | + def _onedal_cpu_supported(self, method_name, *data): |
170 | 172 | patching_status = PatchingConditionsChain(
|
171 | 173 | f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
|
172 | 174 | )
|
173 | 175 | return patching_status
|
174 | 176 |
|
175 |
| - _onedal_cpu_supported = _onedal_supported |
176 |
| - _onedal_gpu_supported = _onedal_supported |
| 177 | + def _onedal_gpu_supported(self, method_name, *data): |
| 178 | + patching_status = PatchingConditionsChain( |
| 179 | + f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}" |
| 180 | + ) |
| 181 | + X, sample_weight = data |
| 182 | + |
| 183 | + is_data_supported = not issparse(X) or ( |
| 184 | + _is_csr(X) and daal_check_version((2025, "P", 200)) |
| 185 | + ) |
| 186 | + |
| 187 | + is_sample_weight_supported = sample_weight is None or not issparse(X) |
| 188 | + |
| 189 | + patching_status.and_conditions( |
| 190 | + [ |
| 191 | + ( |
| 192 | + is_sample_weight_supported, |
| 193 | + "Sample weights are not supported for CSR data format", |
| 194 | + ), |
| 195 | + ( |
| 196 | + is_data_supported, |
| 197 | + "Supported data formats: Dense, CSR (oneDAL version >= 2025.2.0).", |
| 198 | + ), |
| 199 | + ] |
| 200 | + ) |
| 201 | + return patching_status |
177 | 202 |
|
178 | 203 | def _onedal_fit(self, X, sample_weight=None, queue=None):
|
179 | 204 | if sklearn_check_version("1.2"):
|
180 | 205 | self._validate_params()
|
181 | 206 |
|
182 | 207 | if sklearn_check_version("1.0"):
|
183 |
| - X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False) |
| 208 | + X = validate_data( |
| 209 | + self, |
| 210 | + X, |
| 211 | + dtype=[np.float64, np.float32], |
| 212 | + ensure_2d=False, |
| 213 | + accept_sparse="csr", |
| 214 | + ) |
184 | 215 | else:
|
185 | 216 | X = check_array(X, dtype=[np.float64, np.float32])
|
186 | 217 |
|
|
0 commit comments