diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py index 7e299f07e0..34893af44b 100644 --- a/sklearnex/_device_offload.py +++ b/sklearnex/_device_offload.py @@ -16,6 +16,7 @@ from functools import wraps +from daal4py.sklearn._utils import sklearn_check_version from onedal._device_offload import _copy_to_usm, _get_global_queue, _transfer_to_host from onedal.utils._array_api import _asarray from onedal.utils._dpep_helpers import dpnp_available @@ -25,6 +26,7 @@ from onedal.utils._array_api import _convert_to_dpnp from ._config import get_config +from .utils import get_tags def _get_backend(obj, queue, method_name, *data): @@ -59,33 +61,42 @@ def _get_backend(obj, queue, method_name, *data): def dispatch(obj, method_name, branches, *args, **kwargs): q = _get_global_queue() + + array_api_offload = ( + "array_api_dispatch" in get_config() and get_config()["array_api_dispatch"] + ) + + onedal_array_api = array_api_offload and get_tags(obj)["onedal_array_api"] + sklearn_array_api = array_api_offload and get_tags(obj)["array_api_support"] + + # We need to avoid a copy to host here if zero_copy supported + backend = "" + if onedal_array_api: + backend, q, patching_status = _get_backend(obj, q, method_name, *args) + if backend == "onedal": + patching_status.write_log(queue=q, transferred_to_host=False) + return branches[backend](obj, *args, **kwargs, queue=q) + if sklearn_array_api and backend == "sklearn": + patching_status.write_log(transferred_to_host=False) + return branches[backend](obj, *args, **kwargs) + + # move to host because it is necessary for checking + # we only guarantee onedal_cpu_supported and onedal_gpu_supported are generalized to non-numpy inputs + # for zero copy estimators. 
this will eventually be deprecated when all estimators are zero-copy generalized has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args) has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) - - backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs) has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs + + if not backend: + backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs) + if backend == "onedal": - # Host args only used before onedal backend call. - # Device will be offloaded when onedal backend will be called. patching_status.write_log(queue=q, transferred_to_host=False) return branches[backend](obj, *hostargs, **hostkwargs, queue=q) if backend == "sklearn": - if ( - "array_api_dispatch" in get_config() - and get_config()["array_api_dispatch"] - and "array_api_support" in obj._get_tags() - and obj._get_tags()["array_api_support"] - and not has_usm_data - ): - # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is - # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant, - # except for the linalg module. There is no guarantee that stock scikit-learn will - # work with such input data. The condition will be updated after DPNP.ndarray and - # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance - # of the fallback cases. - # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, - # then raw inputs are used for the fallback. + if sklearn_array_api and not has_usm_data: + # dpnp fallback is not handled properly yet. 
patching_status.write_log(transferred_to_host=False) return branches[backend](obj, *args, **kwargs) else: diff --git a/sklearnex/base.py b/sklearnex/base.py new file mode 100644 index 0000000000..9d83f98f46 --- /dev/null +++ b/sklearnex/base.py @@ -0,0 +1,57 @@ +# =============================================================================== +# Copyright contributors to the oneDAL project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== + +from abc import ABC + +from daal4py.sklearn._utils import sklearn_check_version + + +class IntelEstimator(ABC): + + if sklearn_check_version("1.6"): + # Starting in sklearn 1.6, _more_tags is deprecated. An IntelEstimator + # is defined to handle the various versioning issues with the tags and + # with the ongoing rollout of sklearn's array_api support. This will make + # maintenance easier, and centralize tag changes to a single location. 
+ + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.onedal_array_api = False + return tags + + elif sklearn_check_version("1.3"): + + def _more_tags(self): + return {"onedal_array_api": False} + + else: + # array_api_support tag was added in sklearn 1.3 via scikit-learn/scikit-learn#26372 + def _more_tags(self): + return {"array_api_support": False, "onedal_array_api": False} + + if sklearn_check_version("1.4"): + + def _get_doc_link(self) -> str: + # This method is meant to generate a clickable doc link for classes + # in sklearnex that are not part of base scikit-learn. It should be + # inherited before inheriting from a scikit-learn estimator, otherwise + # will get overridden by the estimator's original. + url = super()._get_doc_link() + if not url: + module_path, _ = self.__class__.__module__.rsplit(".", 1) + class_name = self.__class__.__name__ + url = f"https://intel.github.io/scikit-learn-intelex/latest/non-scikit-algorithms.html#{module_path}.{class_name}" + return url diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 26f78ac16e..0138783cc7 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -28,7 +28,8 @@ from onedal.utils import _is_csr from .._device_offload import dispatch -from .._utils import IntelEstimator, PatchingConditionsChain +from .._utils import PatchingConditionsChain +from ..base import IntelEstimator if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py index d1ddcd55dc..664d0fd811 100644 --- a/sklearnex/basic_statistics/incremental_basic_statistics.py +++ b/sklearnex/basic_statistics/incremental_basic_statistics.py @@ -26,7 +26,8 @@ ) from .._device_offload import dispatch -from .._utils import IntelEstimator, 
PatchingConditionsChain +from .._utils import PatchingConditionsChain +from ..base import IntelEstimator if sklearn_check_version("1.2"): from sklearn.utils._param_validation import Interval, StrOptions diff --git a/sklearnex/cluster/dbscan.py b/sklearnex/cluster/dbscan.py index ef5f6b78d9..0d81eb1bd6 100755 --- a/sklearnex/cluster/dbscan.py +++ b/sklearnex/cluster/dbscan.py @@ -27,6 +27,7 @@ from .._device_offload import dispatch from .._utils import PatchingConditionsChain +from ..base import IntelEstimator if sklearn_check_version("1.1") and not sklearn_check_version("1.2"): from sklearn.utils import check_scalar @@ -37,7 +38,7 @@ validate_data = _sklearn_DBSCAN._validate_data -class BaseDBSCAN(ABC): +class BaseDBSCAN(IntelEstimator): def _onedal_dbscan(self, **onedal_params): return onedal_DBSCAN(**onedal_params) @@ -51,7 +52,7 @@ def _save_attributes(self): @control_n_jobs(decorated_methods=["fit"]) -class DBSCAN(_sklearn_DBSCAN, BaseDBSCAN): +class DBSCAN(BaseDBSCAN, _sklearn_DBSCAN): __doc__ = _sklearn_DBSCAN.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/cluster/k_means.py b/sklearnex/cluster/k_means.py index 4ba75ca5b8..c092dae9c3 100644 --- a/sklearnex/cluster/k_means.py +++ b/sklearnex/cluster/k_means.py @@ -40,6 +40,7 @@ from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain + from ..base import IntelEstimator if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -47,7 +48,7 @@ validate_data = _sklearn_KMeans._validate_data @control_n_jobs(decorated_methods=["fit", "fit_transform", "predict", "score"]) - class KMeans(_sklearn_KMeans): + class KMeans(IntelEstimator, _sklearn_KMeans): __doc__ = _sklearn_KMeans.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/covariance/incremental_covariance.py b/sklearnex/covariance/incremental_covariance.py index 89ed92b601..4b9bf99563 100644 --- a/sklearnex/covariance/incremental_covariance.py +++ 
b/sklearnex/covariance/incremental_covariance.py @@ -33,7 +33,8 @@ from sklearnex import config_context from .._device_offload import dispatch, wrap_output_data -from .._utils import IntelEstimator, PatchingConditionsChain, register_hyperparameters +from .._utils import PatchingConditionsChain, register_hyperparameters +from ..base import IntelEstimator from ..metrics import pairwise_distances from ..utils._array_api import get_namespace diff --git a/sklearnex/decomposition/pca.py b/sklearnex/decomposition/pca.py index 143587aa16..3e5616d7ae 100755 --- a/sklearnex/decomposition/pca.py +++ b/sklearnex/decomposition/pca.py @@ -32,6 +32,7 @@ from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain + from ..base import IntelEstimator from ..utils._array_api import get_namespace if sklearn_check_version("1.1") and not sklearn_check_version("1.2"): @@ -50,7 +51,7 @@ validate_data = _sklearn_PCA._validate_data @control_n_jobs(decorated_methods=["fit", "transform", "fit_transform"]) - class PCA(_sklearn_PCA): + class PCA(IntelEstimator, _sklearn_PCA): __doc__ = _sklearn_PCA.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/ensemble/_forest.py b/sklearnex/ensemble/_forest.py index 2a04962645..bd485726b4 100644 --- a/sklearnex/ensemble/_forest.py +++ b/sklearnex/ensemble/_forest.py @@ -62,6 +62,7 @@ from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain +from ..base import IntelEstimator from ..utils._array_api import get_namespace if sklearn_check_version("1.2"): @@ -75,7 +76,7 @@ validate_data = BaseEstimator._validate_data -class BaseForest(ABC): +class BaseForest(IntelEstimator): _onedal_factory = None def _onedal_fit(self, X, y, sample_weight=None, queue=None): @@ -402,7 +403,7 @@ def base_estimator(self, estimator): self.estimator = estimator -class ForestClassifier(_sklearn_ForestClassifier, BaseForest): +class ForestClassifier(BaseForest, 
_sklearn_ForestClassifier): # Surprisingly, even though scikit-learn warns against using # their ForestClassifier directly, it actually has a more stable # API than the user-facing objects (over time). If they change it @@ -851,7 +852,7 @@ def _onedal_score(self, X, y, sample_weight=None, queue=None): ) -class ForestRegressor(_sklearn_ForestRegressor, BaseForest): +class ForestRegressor(BaseForest, _sklearn_ForestRegressor): _err = "out_of_bag_error_r2|out_of_bag_error_prediction" _get_tree_state = staticmethod(get_tree_state_reg) diff --git a/sklearnex/linear_model/incremental_linear.py b/sklearnex/linear_model/incremental_linear.py index 622b15ef6c..d79c59667a 100644 --- a/sklearnex/linear_model/incremental_linear.py +++ b/sklearnex/linear_model/incremental_linear.py @@ -40,7 +40,8 @@ from onedal.common.hyperparameters import get_hyperparameters from .._device_offload import dispatch, wrap_output_data -from .._utils import IntelEstimator, PatchingConditionsChain, register_hyperparameters +from .._utils import PatchingConditionsChain, register_hyperparameters +from ..base import IntelEstimator @register_hyperparameters( diff --git a/sklearnex/linear_model/incremental_ridge.py b/sklearnex/linear_model/incremental_ridge.py index 39097d3e8d..639f6eaf7c 100644 --- a/sklearnex/linear_model/incremental_ridge.py +++ b/sklearnex/linear_model/incremental_ridge.py @@ -33,6 +33,7 @@ from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain +from ..base import IntelEstimator if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -43,7 +44,7 @@ @control_n_jobs( decorated_methods=["fit", "partial_fit", "predict", "score", "_onedal_finalize_fit"] ) -class IncrementalRidge(MultiOutputMixin, RegressorMixin, BaseEstimator): +class IncrementalRidge(IntelEstimator, MultiOutputMixin, RegressorMixin, BaseEstimator): """ Incremental estimator for Ridge Regression. 
Allows to train Ridge Regression if data is splitted into batches. diff --git a/sklearnex/linear_model/linear.py b/sklearnex/linear_model/linear.py index fb7eca8cf1..71f85d2a15 100644 --- a/sklearnex/linear_model/linear.py +++ b/sklearnex/linear_model/linear.py @@ -28,6 +28,7 @@ from .._config import get_config from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain, get_patch_message, register_hyperparameters +from ..base import IntelEstimator if sklearn_check_version("1.0") and not sklearn_check_version("1.2"): from sklearn.linear_model._base import _deprecate_normalize @@ -47,7 +48,7 @@ @register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")}) @control_n_jobs(decorated_methods=["fit", "predict", "score"]) -class LinearRegression(_sklearn_LinearRegression): +class LinearRegression(IntelEstimator, _sklearn_LinearRegression): __doc__ = _sklearn_LinearRegression.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/linear_model/logistic_regression.py b/sklearnex/linear_model/logistic_regression.py index 01e944c74f..be195aedb7 100644 --- a/sklearnex/linear_model/logistic_regression.py +++ b/sklearnex/linear_model/logistic_regression.py @@ -39,6 +39,7 @@ from .._config import get_config from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain, get_patch_message + from ..base import IntelEstimator if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -47,7 +48,7 @@ _sparsity_enabled = daal_check_version((2024, "P", 700)) - class BaseLogisticRegression(ABC): + class BaseLogisticRegression(IntelEstimator): def _onedal_gpu_save_attributes(self): assert hasattr(self, "_onedal_estimator") self.classes_ = self._onedal_estimator.classes_ @@ -65,7 +66,7 @@ def _onedal_gpu_save_attributes(self): "score", ] ) - class LogisticRegression(_sklearn_LogisticRegression, BaseLogisticRegression): + class 
LogisticRegression(BaseLogisticRegression, _sklearn_LogisticRegression): __doc__ = _sklearn_LogisticRegression.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/linear_model/ridge.py b/sklearnex/linear_model/ridge.py index 85d6714905..74ff42cdbe 100644 --- a/sklearnex/linear_model/ridge.py +++ b/sklearnex/linear_model/ridge.py @@ -39,6 +39,7 @@ from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain + from ..base import IntelEstimator if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -46,7 +47,7 @@ validate_data = _sklearn_Ridge._validate_data @control_n_jobs(decorated_methods=["fit", "predict", "score"]) - class Ridge(_sklearn_Ridge): + class Ridge(IntelEstimator, _sklearn_Ridge): __doc__ = _sklearn_Ridge.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/neighbors/common.py b/sklearnex/neighbors/common.py index 0ad5a62dd1..d405b2c28f 100644 --- a/sklearnex/neighbors/common.py +++ b/sklearnex/neighbors/common.py @@ -28,10 +28,11 @@ from onedal.utils import _check_array, _num_features, _num_samples from .._utils import PatchingConditionsChain +from ..base import IntelEstimator from ..utils._array_api import get_namespace -class KNeighborsDispatchingBase: +class KNeighborsDispatchingBase(IntelEstimator): def _fit_validation(self, X, y=None): if sklearn_check_version("1.2"): self._validate_params() diff --git a/sklearnex/neighbors/knn_classification.py b/sklearnex/neighbors/knn_classification.py index 3b9871b4cf..2c7d6e5a79 100755 --- a/sklearnex/neighbors/knn_classification.py +++ b/sklearnex/neighbors/knn_classification.py @@ -26,6 +26,7 @@ from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier from .._device_offload import dispatch, wrap_output_data +from ..utils import get_tags from .common import KNeighborsDispatchingBase if sklearn_check_version("1.6"): @@ -184,7 +185,7 @@ def _onedal_fit(self, X, y, queue=None): } try: - 
requires_y = self._get_tags()["requires_y"] + requires_y = get_tags(self)["requires_y"] except KeyError: requires_y = False diff --git a/sklearnex/neighbors/knn_regression.py b/sklearnex/neighbors/knn_regression.py index 5889687498..1112eb7f16 100755 --- a/sklearnex/neighbors/knn_regression.py +++ b/sklearnex/neighbors/knn_regression.py @@ -25,6 +25,7 @@ from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor from .._device_offload import dispatch, wrap_output_data +from ..utils import get_tags from .common import KNeighborsDispatchingBase if sklearn_check_version("1.6"): @@ -166,7 +167,7 @@ def _onedal_fit(self, X, y, queue=None): } try: - requires_y = self._get_tags()["requires_y"] + requires_y = get_tags(self)["requires_y"] except KeyError: requires_y = False diff --git a/sklearnex/neighbors/knn_unsupervised.py b/sklearnex/neighbors/knn_unsupervised.py index d76e220cae..fa32b9cb7f 100755 --- a/sklearnex/neighbors/knn_unsupervised.py +++ b/sklearnex/neighbors/knn_unsupervised.py @@ -22,6 +22,7 @@ from onedal.neighbors import NearestNeighbors as onedal_NearestNeighbors from .._device_offload import dispatch, wrap_output_data +from ..utils import get_tags from .common import KNeighborsDispatchingBase if sklearn_check_version("1.6"): @@ -140,7 +141,7 @@ def _onedal_fit(self, X, y=None, queue=None): } try: - requires_y = self._get_tags()["requires_y"] + requires_y = get_tags(self)["requires_y"] except KeyError: requires_y = False diff --git a/sklearnex/preview/covariance/covariance.py b/sklearnex/preview/covariance/covariance.py index 04bdc0be8d..668a6e617a 100644 --- a/sklearnex/preview/covariance/covariance.py +++ b/sklearnex/preview/covariance/covariance.py @@ -30,6 +30,7 @@ from ..._device_offload import dispatch, wrap_output_data from ..._utils import PatchingConditionsChain, register_hyperparameters +from ...base import IntelEstimator if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -39,7 +40,7 @@ 
@register_hyperparameters({"fit": get_hyperparameters("covariance", "compute")}) @control_n_jobs(decorated_methods=["fit", "mahalanobis"]) -class EmpiricalCovariance(_sklearn_EmpiricalCovariance): +class EmpiricalCovariance(IntelEstimator, _sklearn_EmpiricalCovariance): __doc__ = _sklearn_EmpiricalCovariance.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/preview/decomposition/incremental_pca.py b/sklearnex/preview/decomposition/incremental_pca.py index fdf13e0817..51aeaa0c51 100644 --- a/sklearnex/preview/decomposition/incremental_pca.py +++ b/sklearnex/preview/decomposition/incremental_pca.py @@ -24,6 +24,7 @@ from ..._device_offload import dispatch, wrap_output_data from ..._utils import PatchingConditionsChain +from ...base import IntelEstimator if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -34,7 +35,7 @@ @control_n_jobs( decorated_methods=["fit", "partial_fit", "transform", "_onedal_finalize_fit"] ) -class IncrementalPCA(_sklearn_IncrementalPCA): +class IncrementalPCA(IntelEstimator, _sklearn_IncrementalPCA): def __init__(self, n_components=None, *, whiten=False, copy=True, batch_size=None): super().__init__( diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 4b481314ae..0971b405d1 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -30,6 +30,8 @@ from .._config import config_context, get_config from .._utils import PatchingConditionsChain +from ..base import IntelEstimator +from ..utils import get_tags if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -37,7 +39,7 @@ validate_data = BaseEstimator._validate_data -class BaseSVM(BaseEstimator, ABC): +class BaseSVM(IntelEstimator): @property def _dual_coef_(self): @@ -156,7 +158,7 @@ def _onedal_fit_checks(self, X, y, sample_weight=None): ) if y is None: - if self._get_tags()["requires_y"]: + if get_tags(self)["requires_y"]: raise ValueError( f"This {self.__class__.__name__} 
estimator " f"requires y to be passed, but the target y is None." diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index 301d90ccc4..fa9f77295d 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -39,13 +39,13 @@ if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVC._validate_data + validate_data = _sklearn_NuSVC._validate_data @control_n_jobs( decorated_methods=["fit", "predict", "_predict_proba", "decision_function", "score"] ) -class NuSVC(_sklearn_NuSVC, BaseSVC): +class NuSVC(BaseSVC, _sklearn_NuSVC): __doc__ = _sklearn_NuSVC.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py index 6c746174ac..4979fde0a3 100644 --- a/sklearnex/svm/nusvr.py +++ b/sklearnex/svm/nusvr.py @@ -32,11 +32,11 @@ if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVR._validate_data + validate_data = _sklearn_NuSVR._validate_data @control_n_jobs(decorated_methods=["fit", "predict", "score"]) -class NuSVR(_sklearn_NuSVR, BaseSVR): +class NuSVR(BaseSVR, _sklearn_NuSVR): __doc__ = _sklearn_NuSVR.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index bf5e7f32fc..8e8ced9eee 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -41,13 +41,13 @@ if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVC._validate_data + validate_data = _sklearn_SVC._validate_data @control_n_jobs( decorated_methods=["fit", "predict", "_predict_proba", "decision_function", "score"] ) -class SVC(_sklearn_SVC, BaseSVC): +class SVC(BaseSVC, _sklearn_SVC): __doc__ = _sklearn_SVC.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py index ff2641bea0..fdf1d5d7c7 100644 --- a/sklearnex/svm/svr.py +++ b/sklearnex/svm/svr.py @@ -28,11 +28,11 @@ if 
sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVR._validate_data + validate_data = _sklearn_SVR._validate_data @control_n_jobs(decorated_methods=["fit", "predict", "score"]) -class SVR(_sklearn_SVR, BaseSVR): +class SVR(BaseSVR, _sklearn_SVR): __doc__ = _sklearn_SVR.__doc__ if sklearn_check_version("1.2"): diff --git a/sklearnex/utils/__init__.py b/sklearnex/utils/__init__.py index 686e089adf..0d30f6074b 100755 --- a/sklearnex/utils/__init__.py +++ b/sklearnex/utils/__init__.py @@ -14,6 +14,23 @@ # limitations under the License. # =============================================================================== +from daal4py.sklearn._utils import sklearn_check_version + from .validation import assert_all_finite -__all__ = ["assert_all_finite"] +# Not an ideal solution, but this converts the outputs of newer sklearnex tags +# into dicts to match how tags had been used. Someone more clever than me will +# have to find a way of converting older tags into newer ones instead (with +# minimal impact on performance). + +if sklearn_check_version("1.6"): + from sklearn.utils import get_tags as _sklearn_get_tags + + get_tags = lambda estimator: _sklearn_get_tags(estimator).__dict__ + +else: + from sklearn.base import BaseEstimator + + get_tags = BaseEstimator._get_tags + +__all__ = ["assert_all_finite", "get_tags"] diff --git a/sklearnex/utils/tests/test_array_api.py b/sklearnex/utils/tests/test_array_api.py new file mode 100644 index 0000000000..ed929c4c34 --- /dev/null +++ b/sklearnex/utils/tests/test_array_api.py @@ -0,0 +1,60 @@ +# ============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from daal4py.sklearn._utils import sklearn_check_version +from onedal.tests.utils._dataframes_support import ( + _convert_to_dataframe, + get_dataframes_and_queues, +) + +# TODO: +# add test suite for dpctl.tensor, dpnp.ndarray, numpy.ndarray without config_context(array_api_dispatch=True). +# TODO: +# extend for DPNP inputs. + + +@pytest.mark.skipif( + not sklearn_check_version("1.4"), + reason="Array API dispatch requires sklearn 1.4 version", +) +@pytest.mark.parametrize( + "dataframe,queue", + get_dataframes_and_queues( + dataframe_filter_="numpy,dpctl,array_api", device_filter_="cpu,gpu" + ), +) +def test_get_namespace_with_config_context(dataframe, queue): + """Test `get_namespace` with `array_api_dispatch` enabled.""" + from sklearnex import config_context + from sklearnex.utils._array_api import get_namespace + + array_api_compat = pytest.importorskip("array_api_compat") + + X_np = np.asarray([[1, 2, 3]]) + X = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe) + + with config_context(array_api_dispatch=True): + xp_out, is_array_api_compliant = get_namespace(X) + assert is_array_api_compliant + if not dataframe in "numpy,array_api": + # Rather than array_api_compat.get_namespace raw output + # `get_namespace` has specific wrapper classes for `numpy.ndarray` + # or `array-api-strict`. + assert xp_out == array_api_compat.get_namespace(X)