Skip to content

Commit e639a3b

Browse files
authored
Add sklearnex.BasicStatistics API for CSR inputs on GPU and a test for it (#2253)
- sklearnex.BasicStatistics API was extended to accept CSR inputs on GPU - The tests for CSR inputs on CPU and GPU were added
1 parent 5feeec0 commit e639a3b

File tree

7 files changed

+251
-29
lines changed

7 files changed

+251
-29
lines changed

onedal/basic_statistics/basic_statistics.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,13 @@ def fit(self, data, sample_weight=None, queue=None):
8282
sample_weight = _check_array(sample_weight, ensure_2d=False)
8383

8484
is_single_dim = data.ndim == 1
85-
data_table, weights_table = to_table(data, sample_weight, queue=queue)
85+
86+
data_table = to_table(data, queue=queue)
87+
weights_table = (
88+
to_table(sample_weight, queue=queue)
89+
if sample_weight is not None
90+
else to_table(None)
91+
)
8692

8793
dtype = data_table.dtype
8894
raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)

onedal/basic_statistics/tests/utils.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,35 @@
1616

1717
import numpy as np
1818

19+
20+
# Compute the coefficient of variation (standard deviation with ddof=1 divided by the mean) for each column of array-like X
21+
def variation(X):
    """Return the coefficient of variation of array-like X along axis 0.

    The coefficient is the sample standard deviation (``ddof=1``, i.e. with
    Bessel's correction) divided by the mean, both computed along axis 0.
    Positions whose mean is exactly zero are reported as ``np.nan`` instead
    of dividing by zero.

    The division is done with a single masked ``np.divide`` call, so the
    function also works for one-dimensional X (where the mean and standard
    deviation are scalars that cannot be iterated over); in that case a
    NumPy scalar is returned.
    """
    X_mean = np.mean(X, axis=0)
    X_std = np.std(X, axis=0, ddof=1)

    # Pre-fill the output with NaN: entries skipped by ``where`` (zero mean)
    # keep this value, so no division by zero is ever evaluated.
    result = np.full(
        np.shape(X_std), np.nan, dtype=np.result_type(X_std, X_mean)
    )
    np.divide(X_std, X_mean, out=result, where=X_mean != 0)

    # ``[()]`` unwraps a 0-d array into a scalar and is a no-op view otherwise.
    return result[()]
33+
34+
1935
options_and_tests = {
2036
"sum": (lambda X: np.sum(X, axis=0), (5e-4, 1e-7)),
2137
"min": (lambda X: np.min(X, axis=0), (1e-7, 1e-7)),
2238
"max": (lambda X: np.max(X, axis=0), (1e-7, 1e-7)),
2339
"mean": (lambda X: np.mean(X, axis=0), (5e-7, 1e-7)),
24-
"variance": (lambda X: np.var(X, axis=0), (2e-3, 2e-3)),
25-
"variation": (lambda X: np.std(X, axis=0) / np.mean(X, axis=0), (5e-2, 5e-2)),
40+
# sklearnex computes the unbiased variance and standard deviation, which is why ddof=1 is used
41+
"variance": (lambda X: np.var(X, axis=0, ddof=1), (2e-4, 1e-7)),
42+
"variation": (lambda X: variation(X), (1e-3, 1e-6)),
2643
"sum_squares": (lambda X: np.sum(np.square(X), axis=0), (2e-4, 1e-7)),
2744
"sum_squares_centered": (
2845
lambda X: np.sum(np.square(X - np.mean(X, axis=0)), axis=0),
29-
(2e-4, 1e-7),
46+
(1e-3, 1e-7),
3047
),
31-
"standard_deviation": (lambda X: np.std(X, axis=0), (2e-3, 2e-3)),
48+
"standard_deviation": (lambda X: np.std(X, axis=0, ddof=1), (2e-3, 1e-7)),
3249
"second_order_raw_moment": (lambda X: np.mean(np.square(X), axis=0), (1e-6, 1e-7)),
3350
}

sklearnex/basic_statistics/basic_statistics.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
import warnings
1818

1919
import numpy as np
20+
from scipy.sparse import issparse
2021
from sklearn.base import BaseEstimator
2122
from sklearn.utils import check_array
2223
from sklearn.utils.validation import _check_sample_weight
2324

2425
from daal4py.sklearn._n_jobs_support import control_n_jobs
25-
from daal4py.sklearn._utils import sklearn_check_version
26+
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
2627
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
28+
from onedal.utils import _is_csr
2729

2830
from .._device_offload import dispatch
2931
from .._utils import IntelEstimator, PatchingConditionsChain
@@ -62,13 +64,13 @@ class BasicStatistics(IntelEstimator, BaseEstimator):
6264
mean_ : ndarray of shape (n_features,)
6365
Mean of each feature over all samples.
6466
variance_ : ndarray of shape (n_features,)
65-
Variance of each feature over all samples.
67+
Variance of each feature over all samples. Bessel's correction is used.
6668
variation_ : ndarray of shape (n_features,)
67-
Variation of each feature over all samples.
69+
Variation of each feature over all samples. Bessel's correction is used.
6870
sum_squares_ : ndarray of shape (n_features,)
6971
Sum of squares for each feature over all samples.
7072
standard_deviation_ : ndarray of shape (n_features,)
71-
Standard deviation of each feature over all samples.
73+
Unbiased standard deviation of each feature over all samples. Bessel's correction is used.
7274
sum_squares_centered_ : ndarray of shape (n_features,)
7375
Centered sum of squares for each feature over all samples.
7476
second_order_raw_moment_ : ndarray of shape (n_features,)
@@ -166,21 +168,50 @@ def __getattr__(self, attr):
166168
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
167169
)
168170

169-
def _onedal_supported(self, method_name, *data):
171+
def _onedal_cpu_supported(self, method_name, *data):
    """Return the patching-conditions chain for ``method_name`` on CPU.

    The chain is labelled with this estimator's class name and the method
    name; no conditions are registered on it here, and ``data`` is not
    inspected.
    """
    chain_label = f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
    return PatchingConditionsChain(chain_label)
174176

175-
_onedal_cpu_supported = _onedal_supported
176-
_onedal_gpu_supported = _onedal_supported
177+
def _onedal_gpu_supported(self, method_name, *data):
    """Return the patching-conditions chain for ``method_name`` on GPU.

    ``data`` must unpack to exactly ``(X, sample_weight)``. Two conditions
    are registered on the chain:

    * sample weights are accepted only when X is not sparse;
    * X must be dense, or a CSR matrix when ``daal_check_version`` accepts
      ``(2025, "P", 200)``.
    """
    chain = PatchingConditionsChain(
        f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
    )
    X, sample_weight = data

    if issparse(X):
        # Sparse input: only CSR on a new enough oneDAL, and without weights.
        data_ok = _is_csr(X) and daal_check_version((2025, "P", 200))
        weights_ok = sample_weight is None
    else:
        # Dense input is always accepted, with or without weights.
        data_ok = True
        weights_ok = True

    conditions = [
        (weights_ok, "Sample weights are not supported for CSR data format"),
        (
            data_ok,
            "Supported data formats: Dense, CSR (oneDAL version >= 2025.2.0).",
        ),
    ]
    chain.and_conditions(conditions)
    return chain
177202

178203
def _onedal_fit(self, X, sample_weight=None, queue=None):
179204
if sklearn_check_version("1.2"):
180205
self._validate_params()
181206

182207
if sklearn_check_version("1.0"):
183-
X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False)
208+
X = validate_data(
209+
self,
210+
X,
211+
dtype=[np.float64, np.float32],
212+
ensure_2d=False,
213+
accept_sparse="csr",
214+
)
184215
else:
185216
X = check_array(X, dtype=[np.float64, np.float32])
186217

0 commit comments

Comments
 (0)