Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sklearnex.BasicStatistics API for CSR inputs on GPU and a test for it #2253

Merged
merged 18 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onedal/basic_statistics/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

options_and_tests = {
"sum": (lambda X: np.sum(X, axis=0), (5e-4, 1e-7)),
"min": (lambda X: np.min(X, axis=0), (1e-7, 1e-7)),
"min": (lambda X: np.min(X, axis=0), (0, 0)),
"max": (lambda X: np.max(X, axis=0), (1e-7, 1e-7)),
"mean": (lambda X: np.mean(X, axis=0), (5e-7, 1e-7)),
"variance": (lambda X: np.var(X, axis=0), (2e-3, 2e-3)),
Expand Down
30 changes: 28 additions & 2 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
import warnings

import numpy as np
from scipy.sparse import issparse
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from sklearn.utils.validation import _check_sample_weight

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
from daal4py.sklearn._utils import (
sklearn_check_version,
daal_check_version,
)
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
from onedal.utils import _is_csr

from .._device_offload import dispatch
from .._utils import IntelEstimator, PatchingConditionsChain
Expand Down Expand Up @@ -167,9 +172,30 @@ def __getattr__(self, attr):
)

def _onedal_supported(self, method_name, *data):

patching_status = PatchingConditionsChain(
f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
)

X, sample_weight = data
is_data_supported = (
_is_csr(X) and daal_check_version((2025, "P", 200))
) or not issparse(X)

is_sample_weight_supported = sample_weight is None or not issparse(X)

patching_status.and_conditions(
[
(
is_sample_weight_supported,
"Sample weights are not supported for CSR data format",
),
(
is_data_supported,
"Supported data formats: Dense, CSR (oneDAL version >= 2024.7.0).",
),
]
)
return patching_status

_onedal_cpu_supported = _onedal_supported
Expand All @@ -180,7 +206,7 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
self._validate_params()

if sklearn_check_version("1.0"):
X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False)
X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False, accept_sparse="csr")
else:
X = check_array(X, dtype=[np.float64, np.float32])

Expand Down
50 changes: 49 additions & 1 deletion sklearnex/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy import sparse as sp

from daal4py.sklearn._utils import daal_check_version
from onedal.basic_statistics.tests.utils import options_and_tests
from onedal.tests.utils._dataframes_support import (
_convert_to_dataframe,
get_dataframes_and_queues,
get_queues,
)
from sklearnex.basic_statistics import BasicStatistics


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_basic_statistics(dataframe, queue):
X = np.array([[0, 0], [1, 1]])
Expand Down Expand Up @@ -173,6 +174,53 @@
assert_allclose(gtr_sum, res_sum, atol=tol)


# @pytest.mark.skipif(not hasattr(sp, "random_array"), reason="requires scipy>=1.12.0")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_on_random_sparse_data(
queue, row_count, column_count, dtype
):
seed = 77
random_state = 42

gen = np.random.default_rng(seed)

X_sparse = sp.random_array(
shape=(row_count, column_count),
density=0.05,
format="csr",
dtype=dtype,
random_state=gen,
)
X_dense = X_sparse.toarray()

options = [
"sum",

Check notice on line 200 in sklearnex/basic_statistics/tests/test_basic_statistics.py

View check run for this annotation

codefactor.io / CodeFactor

sklearnex/basic_statistics/tests/test_basic_statistics.py#L200

Unresolved comment '# TODO: There is a bug in oneDAL's max computations on GPU'. (C100)
# TODO: There is a bug in oneDAL's max computations on GPU
# "max",
"min",
"mean",
"standard_deviation",
"variance",
"second_order_raw_moment",
]
basicstat = BasicStatistics(result_options=options)

result = basicstat.fit(X_sparse)

for result_option in options_and_tests:
function, tols = options_and_tests[result_option]
if not result_option in options:
continue
fp32tol, fp64tol = tols
res = getattr(result, result_option)
gtr = function(X_dense)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
Expand Down
6 changes: 6 additions & 0 deletions sklearnex/tests/test_run_to_run_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ def _run_test(estimator, method, datasets):


_sparse_instances = [SVC()]
if daal_check_version((2025, "P", 200)): # Test for >= 2025.2.0
_sparse_instances.extend(
[
BasicStatistics(),
]
)
if daal_check_version((2024, "P", 700)): # Test for > 2024.7.0
_sparse_instances.extend(
[
Expand Down
Loading