From 1ebc83ce7a8b0637a0a7e48ed4a326e944bf99ca Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 09:55:57 +0100 Subject: [PATCH 01/28] first steps --- onedal/svm/svm.py | 40 +++++++++++++++++----------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index f4184a40ac..17bb13e4b0 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -22,6 +22,7 @@ from onedal import _backend +from ..common._base import BaseEstimator from ..common._estimator_checks import _check_is_fitted from ..common._mixin import ClassifierMixin, RegressorMixin from ..common._policy import _get_policy @@ -35,14 +36,7 @@ ) -class SVMtype(Enum): - c_svc = 0 - epsilon_svr = 1 - nu_svc = 2 - nu_svr = 3 - - -class BaseSVM(metaclass=ABCMeta): +class BaseSVM(BaseEstimator, metaclass=ABCMeta): @abstractmethod def __init__( self, @@ -63,8 +57,6 @@ def __init__( decision_function_shape, break_ties, algorithm, - svm_type=None, - **kwargs, ): self.C = C self.nu = nu @@ -82,21 +74,20 @@ def __init__( self.decision_function_shape = decision_function_shape self.break_ties = break_ties self.algorithm = algorithm - self.svm_type = svm_type def _validate_targets(self, y, dtype): self.class_weight_ = None self.classes_ = None return _column_or_1d(y, warn=True).astype(dtype, copy=False) - def _get_onedal_params(self, data): + def _get_onedal_params(self, dtype): max_iter = 10000 if self.max_iter == -1 else self.max_iter # TODO: remove this workaround # when oneDAL SVM starts support of 'n_iterations' result self.n_iter_ = 1 if max_iter < 1 else max_iter class_count = 0 if self.classes_ is None else len(self.classes_) return { - "fptype": data.dtype, + "fptype": dtype, "method": self.algorithm, "kernel": self.kernel, "c": self.C, @@ -174,9 +165,9 @@ def _fit(self, X, y, sample_weight, module, queue): self._scale_, self._sigma_ = _gamma, np.sqrt(0.5 / _gamma) policy = _get_policy(queue, *data) - X = _convert_to_supported(policy, X) - params = self._get_onedal_params(X) - result = module.train(policy, params, *to_table(*data)) + data_t = to_table(*_convert_to_supported(policy, *data)) + params = self._get_onedal_params(data_t[0].dtype) + result = module.train(policy, params, *data_t) if self._sparse: self.dual_coef_ = sp.csr_matrix(from_table(result.coeffs).T) @@ -206,9 +197,6 @@ def _create_model(self, module): m.support_vectors = to_table(self.support_vectors_) m.coeffs = to_table(self.dual_coef_.T) m.biases = to_table(self.intercept_) - - if self.svm_type is SVMtype.c_svc or self.svm_type is SVMtype.nu_svc: - m.first_class_response, m.second_class_response = 0, 1 return m def _predict(self, X, module, queue): @@ -370,7 +358,6 @@ def __init__( break_ties=False, algorithm=algorithm, ) - self.svm_type = SVMtype.epsilon_svr def fit(self, X, y, sample_weight=None, queue=None): return super()._fit(X, y, sample_weight, _backend.svm.regression, queue) @@ -422,7 +409,11 @@ def __init__( break_ties=break_ties, algorithm=algorithm, ) - self.svm_type = SVMtype.c_svc + + def _create_model(self, module): + m = super()._create_model(module) + m.first_class_response, m.second_class_response = 0, 1 + return m def _validate_targets(self, y, dtype): y, self.class_weight_, self.classes_ = _validate_targets( @@ -483,7 +474,6 @@ def __init__( break_ties=False, algorithm=algorithm, ) - self.svm_type = SVMtype.nu_svr def fit(self, X, y, sample_weight=None, queue=None): return super()._fit(X, y, sample_weight, _backend.svm.nu_regression, queue) @@ -535,7 +525,11 @@ def __init__( break_ties=break_ties, algorithm=algorithm, ) - self.svm_type = SVMtype.nu_svc + + def _create_model(self, module): + m = super()._create_model(module) + m.first_class_response, m.second_class_response = 0, 1 + return m def _validate_targets(self, y, dtype): y, self.class_weight_, self.classes_ = _validate_targets( From 6ba816c96a08e6c7234731f607d72247fdc40d09 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 10:49:20 +0100 Subject: [PATCH 02/28] remove SVMType entirely --- onedal/svm/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onedal/svm/__init__.py b/onedal/svm/__init__.py index 6bcf140a4a..edff432c69 100644 --- a/onedal/svm/__init__.py +++ b/onedal/svm/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. # ============================================================================== -from .svm import SVC, SVR, NuSVC, NuSVR, SVMtype +from .svm import SVC, SVR, NuSVC, NuSVR -__all__ = ["SVC", "SVR", "NuSVC", "NuSVR", "SVMtype"] +__all__ = ["SVC", "SVR", "NuSVC", "NuSVR"] From b2f1bd7e2a74f3d2be20662c17263fa81a0de4e8 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 12:03:45 +0100 Subject: [PATCH 03/28] movement? --- onedal/svm/svm.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index 17bb13e4b0..0871fa011a 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -35,7 +35,6 @@ _validate_targets, ) - class BaseSVM(BaseEstimator, metaclass=ABCMeta): @abstractmethod def __init__( @@ -241,14 +240,14 @@ def _predict(self, X, module, queue): ) policy = _get_policy(queue, X) - X = _convert_to_supported(policy, X) - params = self._get_onedal_params(X) + X = to_table(_convert_to_supported(policy, X)) + params = self._get_onedal_params(X.dtype) if hasattr(self, "_onedal_model"): model = self._onedal_model else: model = self._create_model(module) - result = module.infer(policy, params, model, to_table(X)) + result = module.infer(policy, params, model, X) y = from_table(result.responses) return y @@ -278,12 +277,12 @@ def _decision_function(self, X, module, queue): ) _check_n_features(self, X, False) - if self._sparse and not sp.isspmatrix(X): - X = sp.csr_matrix(X) if self._sparse: - X.sort_indices() - - if sp.issparse(X) and not self._sparse and not callable(self.kernel): + if not sp.isspmatrix(X): + X = sp.csr_matrix(X) + else: + X.sort_indices() + elif sp.issparse(X) and not callable(self.kernel): raise ValueError( "cannot use sparse input in %r trained on dense data" % type(self).__name__ @@ -298,14 +297,14 @@ def _decision_function(self, X, module, queue): ) policy = _get_policy(queue, X) - X = _convert_to_supported(policy, X) - params = self._get_onedal_params(X) + X = to_table(_convert_to_supported(policy, X)) + params = self._get_onedal_params(X.dtype) if hasattr(self, "_onedal_model"): model = self._onedal_model else: model = self._create_model(module) - result = module.infer(policy, params, model, to_table(X)) + result = module.infer(policy, params, model, X) decision_function = from_table(result.decision_function) if len(self.classes_) == 2: From 66cac2501d6794171257ee8d307e4ae6b1b8d0be Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 13:31:44 +0100 Subject: [PATCH 04/28] centralize predict --- onedal/svm/svm.py | 67 +++++++++++----------------------------- sklearnex/svm/_common.py | 51 +++++++++++++++++++++++++++++- sklearnex/svm/nusvr.py | 19 ------------ sklearnex/svm/svr.py | 19 ------------ 4 files changed, 68 insertions(+), 88 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index 0871fa011a..120478edf0 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -35,6 +35,7 @@ _validate_targets, ) + class BaseSVM(BaseEstimator, metaclass=ABCMeta): @abstractmethod def __init__( @@ -200,56 +201,28 @@ def _create_model(self, module): def _predict(self, X, module, queue): _check_is_fitted(self) - if self.break_ties and self.decision_function_shape == "ovo": - raise ValueError( - "break_ties must be False when " "decision_function_shape is 'ovo'" - ) - if module in [_backend.svm.classification, _backend.svm.nu_classification]: - sv = self.support_vectors_ - if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]: - raise ValueError( - "The internal representation " - f"of {self.__class__.__name__} was altered" - ) + if self._sparse and not sp.isspmatrix(X): + X = sp.csr_matrix(X) + if self._sparse: + X.sort_indices() - if ( - self.break_ties - and self.decision_function_shape == "ovr" - and len(self.classes_) > 2 - ): - y = np.argmax(self.decision_function(X), axis=1) - else: - X = _check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=True, - accept_sparse="csr", + if sp.issparse(X) and not self._sparse and not callable(self.kernel): + raise ValueError( + "cannot use sparse input in %r trained on dense data" + % type(self).__name__ ) - _check_n_features(self, X, False) - if self._sparse and not sp.isspmatrix(X): - X = sp.csr_matrix(X) - if self._sparse: - X.sort_indices() - - if sp.issparse(X) and not self._sparse and not callable(self.kernel): - raise ValueError( - "cannot use sparse input in %r trained on dense data" - % type(self).__name__ - ) - - policy = _get_policy(queue, X) - X = to_table(_convert_to_supported(policy, X)) - params = self._get_onedal_params(X.dtype) + policy = _get_policy(queue, X) + X = to_table(_convert_to_supported(policy, X)) + params = self._get_onedal_params(X.dtype) - if hasattr(self, "_onedal_model"): - model = self._onedal_model - else: - model = self._create_model(module) - result = module.infer(policy, params, model, X) - y = from_table(result.responses) - return y + if hasattr(self, "_onedal_model"): + model = self._onedal_model + else: + model = self._create_model(module) + result = module.infer(policy, params, model, X) + return from_table(result.responses) def _ovr_decision_function(self, predictions, confidences, n_classes): n_samples = predictions.shape[0] @@ -337,7 +310,6 @@ def __init__( max_iter=-1, tau=1e-12, algorithm="thunder", - **kwargs, ): super().__init__( C=C, @@ -388,7 +360,6 @@ def __init__( decision_function_shape="ovr", break_ties=False, algorithm="thunder", - **kwargs, ): super().__init__( C=C, @@ -453,7 +424,6 @@ def __init__( max_iter=-1, tau=1e-12, algorithm="thunder", - **kwargs, ): super().__init__( C=C, @@ -504,7 +474,6 @@ def __init__( decision_function_shape="ovr", break_ties=False, algorithm="thunder", - **kwargs, ): super().__init__( C=1.0, diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 4b481314ae..8c5b1bb85f 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -22,14 +22,16 @@ from scipy import sparse as sp from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.calibration import CalibratedClassifierCV -from sklearn.metrics import r2_score +from sklearn.metrics import accuracy_score, r2_score from sklearn.preprocessing import LabelEncoder +from sklearn.utils.validation import check_array from daal4py.sklearn._utils import sklearn_check_version from onedal.utils import _check_array, _check_X_y, _column_or_1d from .._config import config_context, get_config from .._utils import PatchingConditionsChain +from ..utils._array_api import get_namespace if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data @@ -238,6 +240,40 @@ def _get_sample_weight(self, X, y, sample_weight): return sample_weight + def _onedal_predict(self, X, queue=None): + xp, _ = get_namespace(X) + + if self.break_ties and self.decision_function_shape == "ovo": + raise ValueError( + "break_ties must be False when " "decision_function_shape is 'ovo'" + ) + + if sklearn_check_version("1.0"): + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + reset=False, + ) + else: + X = check_array( + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + ) + + if ( + self.break_ties + and self.decision_function_shape == "ovr" + and len(self.classes_) > 2 + ): + return xp.argmax( + self._onedal_estimator.decision_function(X, queue=queue), axis=1 + ) + + return self._onedal_estimator.predict(X, queue=queue) + class BaseSVC(BaseSVM): def _compute_balanced_class_weight(self, y): @@ -311,6 +347,19 @@ def _save_attributes(self): length = int(len(self.classes_) * (len(self.classes_) - 1) / 2) self.n_iter_ = np.full((length,), self._onedal_estimator.n_iter_) + def _onedal_predict(self, X, queue=None): + sv = self.support_vectors_ + if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]: + raise ValueError( + "The internal representation " f"of {self.__class__.__name__} was altered" + ) + return super()._onedal_predict(X, queue=queue) + + def _onedal_score(self, X, y, sample_weight=None, queue=None): + return accuracy_score( + y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight + ) + class BaseSVR(BaseSVM): def _save_attributes(self): diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py index 6c746174ac..6401901155 100644 --- a/sklearnex/svm/nusvr.py +++ b/sklearnex/svm/nusvr.py @@ -146,25 +146,6 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._onedal_estimator.fit(X, y, sample_weight, queue=queue) self._save_attributes() - def _onedal_predict(self, X, queue=None): - if sklearn_check_version("1.0"): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - reset=False, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - ) - return self._onedal_estimator.predict(X, queue=queue) - fit.__doc__ = _sklearn_NuSVR.fit.__doc__ predict.__doc__ = _sklearn_NuSVR.predict.__doc__ score.__doc__ = _sklearn_NuSVR.score.__doc__ diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py index ff2641bea0..810b5582b4 100644 --- a/sklearnex/svm/svr.py +++ b/sklearnex/svm/svr.py @@ -143,25 +143,6 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._onedal_estimator.fit(X, y, sample_weight, queue=queue) self._save_attributes() - def _onedal_predict(self, X, queue=None): - if sklearn_check_version("1.0"): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - reset=False, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - ) - return self._onedal_estimator.predict(X, queue=queue) - fit.__doc__ = _sklearn_SVR.fit.__doc__ predict.__doc__ = _sklearn_SVR.predict.__doc__ score.__doc__ = _sklearn_SVR.score.__doc__ From 1b3e266d5f53a0eb0548d0d0adb2cd71fa4d8366 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 13:33:23 +0100 Subject: [PATCH 05/28] further removal --- sklearnex/svm/svc.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index bf5e7f32fc..a96ffa17e9 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -335,26 +335,6 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._save_attributes() - def _onedal_predict(self, X, queue=None): - if sklearn_check_version("1.0"): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - ensure_2d=False, - accept_sparse="csr", - reset=False, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - ) - return self._onedal_estimator.predict(X, queue=queue) - def _onedal_predict_proba(self, X, queue=None): if getattr(self, "clf_prob", None) is None: raise NotFittedError( @@ -388,11 +368,6 @@ def _onedal_decision_function(self, X, queue=None): ) return self._onedal_estimator.decision_function(X, queue=queue) - def _onedal_score(self, X, y, sample_weight=None, queue=None): - return accuracy_score( - y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight - ) - fit.__doc__ = _sklearn_SVC.fit.__doc__ predict.__doc__ = _sklearn_SVC.predict.__doc__ decision_function.__doc__ = _sklearn_SVC.decision_function.__doc__ From 8686d5688743c4a7aa0de57cbebe500ee7cfb89d Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 13:33:33 +0100 Subject: [PATCH 06/28] further removal --- sklearnex/svm/nusvc.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index 301d90ccc4..c621e4559e 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -16,7 +16,6 @@ import numpy as np from sklearn.exceptions import NotFittedError -from sklearn.metrics import accuracy_score from sklearn.svm import NuSVC as _sklearn_NuSVC from sklearn.utils.validation import ( _deprecate_positional_args, @@ -305,26 +304,6 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._save_attributes() - def _onedal_predict(self, X, queue=None): - if sklearn_check_version("1.0"): - validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - ensure_2d=False, - accept_sparse="csr", - reset=False, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - ) - - return self._onedal_estimator.predict(X, queue=queue) def _onedal_predict_proba(self, X, queue=None): if getattr(self, "clf_prob", None) is None: @@ -360,11 +339,6 @@ def _onedal_decision_function(self, X, queue=None): return self._onedal_estimator.decision_function(X, queue=queue) - def _onedal_score(self, X, y, sample_weight=None, queue=None): - return accuracy_score( - y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight - ) - fit.__doc__ = _sklearn_NuSVC.fit.__doc__ predict.__doc__ = _sklearn_NuSVC.predict.__doc__ decision_function.__doc__ = _sklearn_NuSVC.decision_function.__doc__ From 01ac9231f6aabb5d61ca1723c36c5f4d04bd9284 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 13:46:54 +0100 Subject: [PATCH 07/28] attempt to deal with other oddities --- onedal/svm/svm.py | 16 ++++------------ sklearnex/svm/_common.py | 13 +++++++++---- sklearnex/svm/nusvc.py | 1 - 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index 120478edf0..86d75db1ff 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -334,8 +334,7 @@ def fit(self, X, y, sample_weight=None, queue=None): return super()._fit(X, y, sample_weight, _backend.svm.regression, queue) def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.regression, queue) - return y.ravel() + return super()._predict(X, _backend.svm.regression, queue) class SVC(ClassifierMixin, BaseSVM): @@ -395,10 +394,7 @@ def fit(self, X, y, sample_weight=None, queue=None): return super()._fit(X, y, sample_weight, _backend.svm.classification, queue) def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.classification, queue) - if len(self.classes_) == 2: - y = y.ravel() - return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel() + return super()._predict(X, _backend.svm.classification, queue) def decision_function(self, X, queue=None): return super()._decision_function(X, _backend.svm.classification, queue) @@ -448,8 +444,7 @@ def fit(self, X, y, sample_weight=None, queue=None): return super()._fit(X, y, sample_weight, _backend.svm.nu_regression, queue) def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.nu_regression, queue) - return y.ravel() + return super()._predict(X, _backend.svm.nu_regression, queue) class NuSVC(ClassifierMixin, BaseSVM): @@ -509,10 +504,7 @@ def fit(self, X, y, sample_weight=None, queue=None): return super()._fit(X, y, sample_weight, _backend.svm.nu_classification, queue) def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.nu_classification, queue) - if len(self.classes_) == 2: - y = y.ravel() - return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel() + return super()._predict(X, _backend.svm.nu_classification, queue) def decision_function(self, X, queue=None): return super()._decision_function(X, _backend.svm.nu_classification, queue) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 8c5b1bb85f..0b37976dcc 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -240,8 +240,9 @@ def _get_sample_weight(self, X, y, sample_weight): return sample_weight - def _onedal_predict(self, X, queue=None): - xp, _ = get_namespace(X) + def _onedal_predict(self, X, queue=None, xp=None): + if xp is None: + xp, _ = get_namespace(X) if self.break_ties and self.decision_function_shape == "ovo": raise ValueError( @@ -272,7 +273,7 @@ def _onedal_predict(self, X, queue=None): self._onedal_estimator.decision_function(X, queue=queue), axis=1 ) - return self._onedal_estimator.predict(X, queue=queue) + return xp.squeeze(self._onedal_estimator.predict(X, queue=queue)) class BaseSVC(BaseSVM): @@ -353,7 +354,11 @@ def _onedal_predict(self, X, queue=None): raise ValueError( "The internal representation " f"of {self.__class__.__name__} was altered" ) - return super()._onedal_predict(X, queue=queue) + xp, _ = get_namespace(X) + res = super()._onedal_predict(X, queue=queue, xp=xp) + if len(self.classes_) != 2: + res = xp.take(self.classes_, xp.asarray(res, dtype=xp.int32)) + return res def _onedal_score(self, X, y, sample_weight=None, queue=None): return accuracy_score( diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index c621e4559e..b471c19a17 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -304,7 +304,6 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._save_attributes() - def _onedal_predict_proba(self, X, queue=None): if getattr(self, "clf_prob", None) is None: raise NotFittedError( From d6a5c6e9fc5a5b94130e43184adea45dcb54b7b6 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 14:38:29 +0100 Subject: [PATCH 08/28] move onedal_decision_function --- sklearnex/svm/_common.py | 60 ++++++++++++++++++++++++++++------------ sklearnex/svm/nusvc.py | 3 +- sklearnex/svm/nusvr.py | 3 +- sklearnex/svm/svc.py | 22 ++------------- sklearnex/svm/svr.py | 3 +- 5 files changed, 50 insertions(+), 41 deletions(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 0b37976dcc..963fa04c2c 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -20,10 +20,12 @@ import numpy as np from scipy import sparse as sp -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import ClassifierMixin, RegressorMixin from sklearn.calibration import CalibratedClassifierCV from sklearn.metrics import accuracy_score, r2_score from sklearn.preprocessing import LabelEncoder +from sklearn.svm._base import BaseLibSVM as _sklearn_BaseLibSVM +from sklearn.svm._base import BaseSVC as _sklearn_BaseSVC from sklearn.utils.validation import check_array from daal4py.sklearn._utils import sklearn_check_version @@ -39,7 +41,9 @@ validate_data = BaseEstimator._validate_data -class BaseSVM(BaseEstimator, ABC): +class BaseSVM(_sklearn_BaseLibSVM): + + _onedal_factory = None @property def _dual_coef_(self): @@ -244,11 +248,6 @@ def _onedal_predict(self, X, queue=None, xp=None): if xp is None: xp, _ = get_namespace(X) - if self.break_ties and self.decision_function_shape == "ovo": - raise ValueError( - "break_ties must be False when " "decision_function_shape is 'ovo'" - ) - if sklearn_check_version("1.0"): X = validate_data( self, @@ -264,19 +263,10 @@ def _onedal_predict(self, X, queue=None, xp=None): accept_sparse="csr", ) - if ( - self.break_ties - and self.decision_function_shape == "ovr" - and len(self.classes_) > 2 - ): - return xp.argmax( - self._onedal_estimator.decision_function(X, queue=queue), axis=1 - ) - return xp.squeeze(self._onedal_estimator.predict(X, queue=queue)) -class BaseSVC(BaseSVM): +class BaseSVC(BaseSVM, _sklearn_BaseSVC): def _compute_balanced_class_weight(self, y): y_ = _column_or_1d(y) classes, _ = np.unique(y_, return_inverse=True) @@ -354,19 +344,53 @@ def _onedal_predict(self, X, queue=None): raise ValueError( "The internal representation " f"of {self.__class__.__name__} was altered" ) + + if self.break_ties and self.decision_function_shape == "ovo": + raise ValueError( + "break_ties must be False when " "decision_function_shape is 'ovo'" + ) + + if ( + self.break_ties + and self.decision_function_shape == "ovr" + and len(self.classes_) > 2 + ): + return xp.argmax(self._onedal_decision_function(X, queue=queue), axis=1) + xp, _ = get_namespace(X) res = super()._onedal_predict(X, queue=queue, xp=xp) if len(self.classes_) != 2: res = xp.take(self.classes_, xp.asarray(res, dtype=xp.int32)) return res + def _onedal_decision_function(self, X, queue=None): + xp, _ = get_namespace(X) + if sklearn_check_version("1.0"): + validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + force_all_finite=False, + accept_sparse="csr", + reset=False, + ) + else: + X = check_array( + X, + dtype=[xp.float64, xp.float32], + force_all_finite=False, + accept_sparse="csr", + ) + + return self._onedal_estimator.decision_function(X, queue=queue) + def _onedal_score(self, X, y, sample_weight=None, queue=None): return accuracy_score( y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight ) -class BaseSVR(BaseSVM): +class BaseSVR(BaseSVM, RegressorMixin): def _save_attributes(self): self.support_vectors_ = self._onedal_estimator.support_vectors_ self.n_features_in_ = self._onedal_estimator.n_features_in_ diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index b471c19a17..1b04b23d1c 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -46,6 +46,7 @@ ) class NuSVC(_sklearn_NuSVC, BaseSVC): __doc__ = _sklearn_NuSVC.__doc__ + _onedal_factory = onedal_NuSVC if sklearn_check_version("1.2"): _parameter_constraints: dict = {**_sklearn_NuSVC._parameter_constraints} @@ -291,7 +292,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): "decision_function_shape": self.decision_function_shape, } - self._onedal_estimator = onedal_NuSVC(**onedal_params) + self._onedal_estimator = self._onedal_factory(**onedal_params) self._onedal_estimator.fit(X, y, weights, queue=queue) if self.probability: diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py index 6401901155..174ed1327e 100644 --- a/sklearnex/svm/nusvr.py +++ b/sklearnex/svm/nusvr.py @@ -38,6 +38,7 @@ @control_n_jobs(decorated_methods=["fit", "predict", "score"]) class NuSVR(_sklearn_NuSVR, BaseSVR): __doc__ = _sklearn_NuSVR.__doc__ + _onedal_factory = onedal_NuSVR if sklearn_check_version("1.2"): _parameter_constraints: dict = {**_sklearn_NuSVR._parameter_constraints} @@ -142,7 +143,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): "max_iter": self.max_iter, } - self._onedal_estimator = onedal_NuSVR(**onedal_params) + self._onedal_estimator = self._onedal_factory(**onedal_params) self._onedal_estimator.fit(X, y, sample_weight, queue=queue) self._save_attributes() diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index a96ffa17e9..78052ad74c 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -49,6 +49,7 @@ ) class SVC(_sklearn_SVC, BaseSVC): __doc__ = _sklearn_SVC.__doc__ + _onedal_factory = onedal_SVC if sklearn_check_version("1.2"): _parameter_constraints: dict = {**_sklearn_SVC._parameter_constraints} @@ -322,7 +323,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): "decision_function_shape": self.decision_function_shape, } - self._onedal_estimator = onedal_SVC(**onedal_params) + self._onedal_estimator = self._onedal_factory(**onedal_params) self._onedal_estimator.fit(X, y, weights, queue=queue) if self.probability: @@ -349,25 +350,6 @@ def _onedal_predict_proba(self, X, queue=None): with config_context(**cfg): return self.clf_prob.predict_proba(X) - def _onedal_decision_function(self, X, queue=None): - if sklearn_check_version("1.0"): - X = validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - reset=False, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - ) - return self._onedal_estimator.decision_function(X, queue=queue) - fit.__doc__ = _sklearn_SVC.fit.__doc__ predict.__doc__ = _sklearn_SVC.predict.__doc__ decision_function.__doc__ = _sklearn_SVC.decision_function.__doc__ diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py index 810b5582b4..d104bc9302 100644 --- a/sklearnex/svm/svr.py +++ b/sklearnex/svm/svr.py @@ -34,6 +34,7 @@ @control_n_jobs(decorated_methods=["fit", "predict", "score"]) class SVR(_sklearn_SVR, BaseSVR): __doc__ = _sklearn_SVR.__doc__ + _onedal_factory = onedal_SVR if sklearn_check_version("1.2"): _parameter_constraints: dict = {**_sklearn_SVR._parameter_constraints} @@ -139,7 +140,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): "max_iter": self.max_iter, } - self._onedal_estimator = onedal_SVR(**onedal_params) + self._onedal_estimator = self._onedal_factory(**onedal_params) self._onedal_estimator.fit(X, y, sample_weight, queue=queue) self._save_attributes() From 9301fbb94a76011c6fc9a22f91666bb0845c4a5b Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 15:09:57 +0100 Subject: [PATCH 09/28] centralization --- sklearnex/svm/_common.py | 203 ++++++++++++++++++++++++++++++++++++++- sklearnex/svm/nusvc.py | 187 +----------------------------------- sklearnex/svm/nusvr.py | 32 +----- sklearnex/svm/svc.py | 160 +----------------------------- sklearnex/svm/svr.py | 32 +----- 5 files changed, 203 insertions(+), 411 deletions(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 963fa04c2c..aa1f8f1d90 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -20,28 +20,33 @@ import numpy as np from scipy import sparse as sp -from sklearn.base import ClassifierMixin, RegressorMixin +from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin from sklearn.calibration import CalibratedClassifierCV +from sklearn.exceptions import NotFittedError from sklearn.metrics import accuracy_score, r2_score from sklearn.preprocessing import LabelEncoder from sklearn.svm._base import BaseLibSVM as _sklearn_BaseLibSVM from sklearn.svm._base import BaseSVC as _sklearn_BaseSVC -from sklearn.utils.validation import check_array +from sklearn.utils.validation import check_array, check_is_fitted from daal4py.sklearn._utils import sklearn_check_version from onedal.utils import _check_array, _check_X_y, _column_or_1d from .._config import config_context, get_config +from .._device_offload import dispatch, wrap_output_data from .._utils import PatchingConditionsChain from ..utils._array_api import get_namespace +if sklearn_check_version("1.0"): + from sklearn.utils.metaestimators import available_if + if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: validate_data = BaseEstimator._validate_data -class BaseSVM(_sklearn_BaseLibSVM): +class BaseSVM(BaseEstimator): _onedal_factory = None @@ -267,6 +272,145 @@ def _onedal_predict(self, X, queue=None, xp=None): class BaseSVC(BaseSVM, _sklearn_BaseSVC): + + @wrap_output_data + def predict(self, X): + check_is_fitted(self) + return dispatch( + self, + "predict", + { + "onedal": self.__class__._onedal_predict, + "sklearn": _sklearn_BaseSVC.predict, + }, + X, + ) + + @wrap_output_data + def score(self, X, y, sample_weight=None): + check_is_fitted(self) + return dispatch( + self, + "score", + { + "onedal": self.__class__._onedal_score, + "sklearn": _sklearn_BaseSVC.score, + }, + X, + y, + sample_weight=sample_weight, + ) + + @wrap_output_data + def decision_function(self, X): + check_is_fitted(self) + return dispatch( + self, + "decision_function", + { + "onedal": self.__class__._onedal_decision_function, + "sklearn": _sklearn_BaseSVC.decision_function, + }, + X, + ) + + if sklearn_check_version("1.0"): + + @available_if(_sklearn_BaseSVC._check_proba) + def predict_proba(self, X): + """ + Compute probabilities of possible outcomes for samples in X. + + The model need to have probability information computed at training + time: fit with attribute `probability` set to True. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + For kernel="precomputed", the expected shape of X is + (n_samples_test, n_samples_train). + + Returns + ------- + T : ndarray of shape (n_samples, n_classes) + Returns the probability of the sample for each class in + the model. The columns correspond to the classes in sorted + order, as they appear in the attribute :term:`classes_`. + + Notes + ----- + The probability model is created using cross validation, so + the results can be slightly different than those obtained by + predict. Also, it will produce meaningless results on very small + datasets. + """ + check_is_fitted(self) + return self._predict_proba(X) + + @available_if(_sklearn_BaseSVC._check_proba) + def predict_log_proba(self, X): + """Compute log probabilities of possible outcomes for samples in X. + + The model need to have probability information computed at training + time: fit with attribute `probability` set to True. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) or \ + (n_samples_test, n_samples_train) + For kernel="precomputed", the expected shape of X is + (n_samples_test, n_samples_train). + + Returns + ------- + T : ndarray of shape (n_samples, n_classes) + Returns the log-probabilities of the sample for each class in + the model. The columns correspond to the classes in sorted + order, as they appear in the attribute :term:`classes_`. + + Notes + ----- + The probability model is created using cross validation, so + the results can be slightly different than those obtained by + predict. Also, it will produce meaningless results on very small + datasets. + """ + xp, _ = get_namespace(X) + + return xp.log(self.predict_proba(X)) + + else: + + @property + def predict_proba(self): + self._check_proba() + check_is_fitted(self) + return self._predict_proba + + def _predict_log_proba(self, X): + xp, _ = get_namespace(X) + return xp.log(self.predict_proba(X)) + + predict_proba.__doc__ = _sklearn_NuSVC.predict_proba.__doc__ + + @wrap_output_data + def _predict_proba(self, X): + sklearn_pred_proba = ( + _sklearn_NuSVC.predict_proba + if sklearn_check_version("1.0") + else _sklearn_NuSVC._predict_proba + ) + + return dispatch( + self, + "predict_proba", + { + "onedal": self.__class__._onedal_predict_proba, + "sklearn": sklearn_pred_proba, + }, + X, + ) + def _compute_balanced_class_weight(self, y): y_ = _column_or_1d(y) classes, _ = np.unique(y_, return_inverse=True) @@ -316,7 +460,7 @@ def _save_attributes(self): self.dual_coef_ = self._onedal_estimator.dual_coef_ self.shape_fit_ = self._onedal_estimator.class_weight_ self.classes_ = self._onedal_estimator.classes_ - if isinstance(self, ClassifierMixin) or not sklearn_check_version("1.2"): + if not sklearn_check_version("1.2"): self.class_weight_ = self._onedal_estimator.class_weight_ self.support_ = self._onedal_estimator.support_ @@ -384,13 +528,59 @@ def _onedal_decision_function(self, X, queue=None): return self._onedal_estimator.decision_function(X, queue=queue) + def _onedal_predict_proba(self, X, queue=None): + if getattr(self, "clf_prob", None) is None: + raise NotFittedError( + "predict_proba is not available when fitted with probability=False" + ) + from .._config import config_context, get_config + + # We use stock metaestimators below, so the only way + # to pass a queue is using config_context. + cfg = get_config() + cfg["target_offload"] = queue + with config_context(**cfg): + return self.clf_prob.predict_proba(X) + def _onedal_score(self, X, y, sample_weight=None, queue=None): return accuracy_score( y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight ) + predict.__doc__ = _sklearn_BaseSVC.predict.__doc__ + decision_function.__doc__ = _sklearn_BaseSVC.decision_function.__doc__ + score.__doc__ = _sklearn_BaseSVC.score.__doc__ + + +class BaseSVR(BaseSVM, _sklearn_BaseLibSVM, RegressorMixin): + @wrap_output_data + def predict(self, X): + check_is_fitted(self) + return dispatch( + self, + "predict", + { + "onedal": self.__class__._onedal_predict, + "sklearn": _sklearn_BaseLibSVM.predict, + }, + X, + ) + + @wrap_output_data + def score(self, X, y, sample_weight=None): + check_is_fitted(self) + return dispatch( + self, + "score", + { + "onedal": self.__class__._onedal_score, + "sklearn": _sklearn_BaseLibSVM.score, + }, + X, + y, + sample_weight=sample_weight, + ) -class BaseSVR(BaseSVM, RegressorMixin): def _save_attributes(self): self.support_vectors_ = self._onedal_estimator.support_vectors_ self.n_features_in_ = self._onedal_estimator.n_features_in_ @@ -415,3 +605,6 @@ def _onedal_score(self, X, y, sample_weight=None, queue=None): return r2_score( y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight ) + + predict.__doc__ = _sklearn_BaseLibSVM.predict.__doc__ + score.__doc__ = _sklearn_BaseLibSVM.score.__doc__ diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index 1b04b23d1c..d613bf8ca7 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -15,7 +15,6 @@ # ============================================================================== import numpy as np -from sklearn.exceptions import NotFittedError from sklearn.svm import NuSVC as _sklearn_NuSVC from sklearn.utils.validation import ( _deprecate_positional_args, @@ -25,16 +24,11 @@ from daal4py.sklearn._n_jobs_support import control_n_jobs from daal4py.sklearn._utils import sklearn_check_version +from onedal.svm import NuSVC as onedal_NuSVC from .._device_offload import dispatch, wrap_output_data -from ..utils._array_api import get_namespace from ._common import BaseSVC -if sklearn_check_version("1.0"): - from sklearn.utils.metaestimators import available_if - -from onedal.svm import NuSVC as onedal_NuSVC - if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: @@ -44,7 +38,7 @@ @control_n_jobs( decorated_methods=["fit", "predict", "_predict_proba", "decision_function", "score"] ) -class NuSVC(_sklearn_NuSVC, BaseSVC): +class NuSVC(BaseSVC, _sklearn_NuSVC): __doc__ = _sklearn_NuSVC.__doc__ _onedal_factory = onedal_NuSVC @@ -117,146 +111,6 @@ def fit(self, X, y, sample_weight=None): return self - @wrap_output_data - def predict(self, X): - check_is_fitted(self) - return dispatch( - self, - "predict", - { - "onedal": self.__class__._onedal_predict, - "sklearn": _sklearn_NuSVC.predict, - }, - X, - ) - - @wrap_output_data - def score(self, X, y, sample_weight=None): - check_is_fitted(self) - return dispatch( - self, - "score", - { - "onedal": self.__class__._onedal_score, - "sklearn": _sklearn_NuSVC.score, - }, - X, - y, - sample_weight=sample_weight, - ) - - if sklearn_check_version("1.0"): - - @available_if(_sklearn_NuSVC._check_proba) - def predict_proba(self, X): - """ - Compute probabilities of possible outcomes for samples in X. - - The model need to have probability information computed at training - time: fit with attribute `probability` set to True. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - For kernel="precomputed", the expected shape of X is - (n_samples_test, n_samples_train). - - Returns - ------- - T : ndarray of shape (n_samples, n_classes) - Returns the probability of the sample for each class in - the model. The columns correspond to the classes in sorted - order, as they appear in the attribute :term:`classes_`. - - Notes - ----- - The probability model is created using cross validation, so - the results can be slightly different than those obtained by - predict. Also, it will produce meaningless results on very small - datasets. - """ - check_is_fitted(self) - return self._predict_proba(X) - - @available_if(_sklearn_NuSVC._check_proba) - def predict_log_proba(self, X): - """Compute log probabilities of possible outcomes for samples in X. - - The model need to have probability information computed at training - time: fit with attribute `probability` set to True. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) or \ - (n_samples_test, n_samples_train) - For kernel="precomputed", the expected shape of X is - (n_samples_test, n_samples_train). - - Returns - ------- - T : ndarray of shape (n_samples, n_classes) - Returns the log-probabilities of the sample for each class in - the model. The columns correspond to the classes in sorted - order, as they appear in the attribute :term:`classes_`. - - Notes - ----- - The probability model is created using cross validation, so - the results can be slightly different than those obtained by - predict. Also, it will produce meaningless results on very small - datasets. - """ - xp, _ = get_namespace(X) - - return xp.log(self.predict_proba(X)) - - else: - - @property - def predict_proba(self): - self._check_proba() - check_is_fitted(self) - return self._predict_proba - - def _predict_log_proba(self, X): - xp, _ = get_namespace(X) - return xp.log(self.predict_proba(X)) - - predict_proba.__doc__ = _sklearn_NuSVC.predict_proba.__doc__ - - @wrap_output_data - def _predict_proba(self, X): - sklearn_pred_proba = ( - _sklearn_NuSVC.predict_proba - if sklearn_check_version("1.0") - else _sklearn_NuSVC._predict_proba - ) - - return dispatch( - self, - "predict_proba", - { - "onedal": self.__class__._onedal_predict_proba, - "sklearn": sklearn_pred_proba, - }, - X, - ) - - @wrap_output_data - def decision_function(self, X): - check_is_fitted(self) - return dispatch( - self, - "decision_function", - { - "onedal": self.__class__._onedal_decision_function, - "sklearn": _sklearn_NuSVC.decision_function, - }, - X, - ) - - decision_function.__doc__ = _sklearn_NuSVC.decision_function.__doc__ - def _get_sample_weight(self, X, y, sample_weight=None): sample_weight = super()._get_sample_weight(X, y, sample_weight) if sample_weight is None: @@ -305,41 +159,4 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._save_attributes() - def _onedal_predict_proba(self, X, queue=None): - if getattr(self, "clf_prob", None) is None: - raise NotFittedError( - "predict_proba is not available when fitted with probability=False" - ) - from .._config import config_context, get_config - - # We use stock metaestimators below, so the only way - # to pass a queue is using config_context. - cfg = get_config() - cfg["target_offload"] = queue - with config_context(**cfg): - return self.clf_prob.predict_proba(X) - - def _onedal_decision_function(self, X, queue=None): - if sklearn_check_version("1.0"): - validate_data( - self, - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - reset=False, - ) - else: - X = check_array( - X, - dtype=[np.float64, np.float32], - force_all_finite=False, - accept_sparse="csr", - ) - - return self._onedal_estimator.decision_function(X, queue=queue) - fit.__doc__ = _sklearn_NuSVC.fit.__doc__ - predict.__doc__ = _sklearn_NuSVC.predict.__doc__ - decision_function.__doc__ = _sklearn_NuSVC.decision_function.__doc__ - score.__doc__ = _sklearn_NuSVC.score.__doc__ diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py index 174ed1327e..3b822b3c6a 100644 --- a/sklearnex/svm/nusvr.py +++ b/sklearnex/svm/nusvr.py @@ -36,7 +36,7 @@ @control_n_jobs(decorated_methods=["fit", "predict", "score"]) -class NuSVR(_sklearn_NuSVR, BaseSVR): +class NuSVR(BaseSVR, _sklearn_NuSVR): __doc__ = _sklearn_NuSVR.__doc__ _onedal_factory = onedal_NuSVR @@ -100,34 +100,6 @@ def fit(self, X, y, sample_weight=None): ) return self - @wrap_output_data - def predict(self, X): - check_is_fitted(self) - return dispatch( - self, - "predict", - { - "onedal": self.__class__._onedal_predict, - "sklearn": _sklearn_NuSVR.predict, - }, - X, - ) - - @wrap_output_data - def score(self, X, y, sample_weight=None): - check_is_fitted(self) - return dispatch( - self, - "score", - { - "onedal": self.__class__._onedal_score, - "sklearn": _sklearn_NuSVR.score, - }, - X, - y, - sample_weight=sample_weight, - ) - def _onedal_fit(self, X, y, sample_weight=None, queue=None): X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight) onedal_params = { @@ -148,5 +120,3 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._save_attributes() fit.__doc__ = _sklearn_NuSVR.fit.__doc__ - predict.__doc__ = _sklearn_NuSVR.predict.__doc__ - score.__doc__ = _sklearn_NuSVR.score.__doc__ diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index 78052ad74c..8b9330e887 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -16,7 +16,6 @@ import numpy as np from scipy import sparse as sp -from sklearn.exceptions import NotFittedError from sklearn.metrics import accuracy_score from sklearn.svm import SVC as _sklearn_SVC from sklearn.utils.validation import ( @@ -47,7 +46,7 @@ @control_n_jobs( decorated_methods=["fit", "predict", "_predict_proba", "decision_function", "score"] ) -class SVC(_sklearn_SVC, BaseSVC): +class SVC(BaseSVC, _sklearn_SVC): __doc__ = _sklearn_SVC.__doc__ _onedal_factory = onedal_SVC @@ -120,146 +119,6 @@ def fit(self, X, y, sample_weight=None): return self - @wrap_output_data - def predict(self, X): - check_is_fitted(self) - return dispatch( - self, - "predict", - { - "onedal": self.__class__._onedal_predict, - "sklearn": _sklearn_SVC.predict, - }, - X, - ) - - @wrap_output_data - def score(self, X, y, sample_weight=None): - check_is_fitted(self) - return dispatch( - self, - "score", - { - "onedal": self.__class__._onedal_score, - "sklearn": _sklearn_SVC.score, - }, - X, - y, - sample_weight=sample_weight, - ) - - if sklearn_check_version("1.0"): - - @available_if(_sklearn_SVC._check_proba) - def predict_proba(self, X): - """ - Compute probabilities of possible outcomes for samples in X. - - The model need to have probability information computed at training - time: fit with attribute `probability` set to True. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - For kernel="precomputed", the expected shape of X is - (n_samples_test, n_samples_train). - - Returns - ------- - T : ndarray of shape (n_samples, n_classes) - Returns the probability of the sample for each class in - the model. The columns correspond to the classes in sorted - order, as they appear in the attribute :term:`classes_`. - - Notes - ----- - The probability model is created using cross validation, so - the results can be slightly different than those obtained by - predict. Also, it will produce meaningless results on very small - datasets. - """ - check_is_fitted(self) - return self._predict_proba(X) - - @available_if(_sklearn_SVC._check_proba) - def predict_log_proba(self, X): - """Compute log probabilities of possible outcomes for samples in X. - - The model need to have probability information computed at training - time: fit with attribute `probability` set to True. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) or \ - (n_samples_test, n_samples_train) - For kernel="precomputed", the expected shape of X is - (n_samples_test, n_samples_train). - - Returns - ------- - T : ndarray of shape (n_samples, n_classes) - Returns the log-probabilities of the sample for each class in - the model. The columns correspond to the classes in sorted - order, as they appear in the attribute :term:`classes_`. - - Notes - ----- - The probability model is created using cross validation, so - the results can be slightly different than those obtained by - predict. Also, it will produce meaningless results on very small - datasets. - """ - xp, _ = get_namespace(X) - - return xp.log(self.predict_proba(X)) - - else: - - @property - def predict_proba(self): - self._check_proba() - check_is_fitted(self) - return self._predict_proba - - def _predict_log_proba(self, X): - xp, _ = get_namespace(X) - return xp.log(self.predict_proba(X)) - - predict_proba.__doc__ = _sklearn_SVC.predict_proba.__doc__ - - @wrap_output_data - def _predict_proba(self, X): - sklearn_pred_proba = ( - _sklearn_SVC.predict_proba - if sklearn_check_version("1.0") - else _sklearn_SVC._predict_proba - ) - - return dispatch( - self, - "predict_proba", - { - "onedal": self.__class__._onedal_predict_proba, - "sklearn": sklearn_pred_proba, - }, - X, - ) - - @wrap_output_data - def decision_function(self, X): - check_is_fitted(self) - return dispatch( - self, - "decision_function", - { - "onedal": self.__class__._onedal_decision_function, - "sklearn": _sklearn_SVC.decision_function, - }, - X, - ) - - decision_function.__doc__ = _sklearn_SVC.decision_function.__doc__ - def _onedal_gpu_supported(self, method_name, *data): class_name = self.__class__.__name__ patching_status = PatchingConditionsChain( @@ -336,21 +195,4 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._save_attributes() - def _onedal_predict_proba(self, X, queue=None): - if getattr(self, "clf_prob", None) is None: - raise NotFittedError( - "predict_proba is not available when fitted with probability=False" - ) - from .._config import config_context, get_config - - # We use stock metaestimators below, so the only way - # to pass a queue is using config_context. - cfg = get_config() - cfg["target_offload"] = queue - with config_context(**cfg): - return self.clf_prob.predict_proba(X) - fit.__doc__ = _sklearn_SVC.fit.__doc__ - predict.__doc__ = _sklearn_SVC.predict.__doc__ - decision_function.__doc__ = _sklearn_SVC.decision_function.__doc__ - score.__doc__ = _sklearn_SVC.score.__doc__ diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py index d104bc9302..70965ab219 100644 --- a/sklearnex/svm/svr.py +++ b/sklearnex/svm/svr.py @@ -32,7 +32,7 @@ @control_n_jobs(decorated_methods=["fit", "predict", "score"]) -class SVR(_sklearn_SVR, BaseSVR): +class SVR(BaseSVR, _sklearn_SVR): __doc__ = _sklearn_SVR.__doc__ _onedal_factory = onedal_SVR @@ -97,34 +97,6 @@ def fit(self, X, y, sample_weight=None): return self - @wrap_output_data - def predict(self, X): - check_is_fitted(self) - return dispatch( - self, - "predict", - { - "onedal": self.__class__._onedal_predict, - "sklearn": _sklearn_SVR.predict, - }, - X, - ) - - @wrap_output_data - def score(self, X, y, sample_weight=None): - check_is_fitted(self) - return dispatch( - self, - "score", - { - "onedal": self.__class__._onedal_score, - "sklearn": _sklearn_SVR.score, - }, - X, - y, - sample_weight=sample_weight, - ) - def _onedal_fit(self, X, y, sample_weight=None, queue=None): X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight) onedal_params = { @@ -145,5 +117,3 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._save_attributes() fit.__doc__ = _sklearn_SVR.fit.__doc__ - predict.__doc__ = _sklearn_SVR.predict.__doc__ - score.__doc__ = _sklearn_SVR.score.__doc__ From 66dc2496b63e026d44ac84567af2af2ecdc73e9d Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 22:05:19 +0100 Subject: [PATCH 10/28] swap score information --- sklearnex/svm/_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index aa1f8f1d90..4b7bd31b5f 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -574,7 +574,7 @@ def score(self, X, y, sample_weight=None): "score", { "onedal": self.__class__._onedal_score, - "sklearn": _sklearn_BaseLibSVM.score, + "sklearn": RegressorMixin.score, }, X, y, @@ -607,4 +607,4 @@ def _onedal_score(self, X, y, sample_weight=None, queue=None): ) predict.__doc__ = _sklearn_BaseLibSVM.predict.__doc__ - score.__doc__ = _sklearn_BaseLibSVM.score.__doc__ + score.__doc__ = RegressorMixin.score.__doc__ From 25670226b847419c5896e62e953377e40c267b10 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 22:09:03 +0100 Subject: [PATCH 11/28] further fixes --- sklearnex/svm/_common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 4b7bd31b5f..486d357b43 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -391,14 +391,14 @@ def _predict_log_proba(self, X): xp, _ = get_namespace(X) return xp.log(self.predict_proba(X)) - predict_proba.__doc__ = _sklearn_NuSVC.predict_proba.__doc__ + predict_proba.__doc__ = _sklearn_BaseSVC.predict_proba.__doc__ @wrap_output_data def _predict_proba(self, X): sklearn_pred_proba = ( - _sklearn_NuSVC.predict_proba + _sklearn_BaseSVC.predict_proba if sklearn_check_version("1.0") - else _sklearn_NuSVC._predict_proba + else _sklearn_BaseSVC._predict_proba ) return dispatch( From 94b7452cfb2ed95da6c72caa7c7c7d262389fdfc Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 22:32:59 +0100 Subject: [PATCH 12/28] try to remove method resolution problem temporarily --- sklearnex/svm/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 486d357b43..d4ce248062 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -552,7 +552,7 @@ def _onedal_score(self, X, y, sample_weight=None, queue=None): score.__doc__ = _sklearn_BaseSVC.score.__doc__ -class BaseSVR(BaseSVM, _sklearn_BaseLibSVM, RegressorMixin): +class BaseSVR(BaseSVM, _sklearn_BaseLibSVM): @wrap_output_data def predict(self, X): check_is_fitted(self) From fd11f04e57bb039ecbb760189fbb7cd8efc58d7a Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 23:17:23 +0100 Subject: [PATCH 13/28] make modifications to tests --- onedal/svm/tests/test_csr_svm.py | 6 +++--- onedal/svm/tests/test_nusvc.py | 12 ++++++------ onedal/svm/tests/test_svc.py | 16 ++++++++-------- onedal/svm/tests/test_svr.py | 18 ++++++++++-------- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/onedal/svm/tests/test_csr_svm.py b/onedal/svm/tests/test_csr_svm.py index e4a05a030e..c93833f2ec 100644 --- a/onedal/svm/tests/test_csr_svm.py +++ b/onedal/svm/tests/test_csr_svm.py @@ -61,11 +61,11 @@ def check_svm_model_equal( def _test_simple_dataset(queue, kernel): - X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) + X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float64) sparse_X = sp.lil_matrix(X) - Y = [1, 1, 1, 2, 2, 2] + Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float64) - X2 = np.array([[-1, -1], [2, 2], [3, 2]]) + X2 = np.array([[-1, -1], [2, 2], [3, 2]], dtype=np.float64) sparse_X2 = sp.dok_matrix(X2) dataset = sparse_X, Y, sparse_X2 diff --git a/onedal/svm/tests/test_nusvc.py b/onedal/svm/tests/test_nusvc.py index c8bf99a9d3..13d94fa552 100644 --- a/onedal/svm/tests/test_nusvc.py +++ b/onedal/svm/tests/test_nusvc.py @@ -32,7 +32,7 @@ def _test_libsvm_parameters(queue, array_constr, dtype): X = array_constr([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=dtype) - y = array_constr([1, 1, 1, 2, 2, 2], dtype=dtype) + y = array_constr([0, 0, 0, 1, 1, 1], dtype=dtype) clf = NuSVC(kernel="linear").fit(X, y, queue=queue) assert_array_almost_equal( @@ -41,7 +41,7 @@ def _test_libsvm_parameters(queue, array_constr, dtype): assert_array_equal(clf.support_, [0, 1, 3, 4]) assert_array_equal(clf.support_vectors_, X[clf.support_]) assert_array_equal(clf.intercept_, [0.0]) - assert_array_equal(clf.predict(X, queue=queue), y) + assert_array_equal(clf.predict(X, queue=queue).ravel(), y) @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") @@ -55,12 +55,12 @@ def test_libsvm_parameters(queue, array_constr, dtype): @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_class_weight(queue): - X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) - y = np.array([1, 1, 1, 2, 2, 2]) + X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float64) + y = np.array([0, 0, 0, 1, 1, 1], dtype=np.float64) - clf = NuSVC(class_weight={1: 0.1}) + clf = NuSVC(class_weight={0: 0.1}) clf.fit(X, y, queue=queue) - assert_array_almost_equal(clf.predict(X, queue=queue), [2] * 6) + assert_array_almost_equal(clf.predict(X, queue=queue), [1] * 6) @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") diff --git a/onedal/svm/tests/test_svc.py b/onedal/svm/tests/test_svc.py index 9f7eaa4810..650f153ae5 100644 --- a/onedal/svm/tests/test_svc.py +++ b/onedal/svm/tests/test_svc.py @@ -32,14 +32,14 @@ def _test_libsvm_parameters(queue, array_constr, dtype): X = array_constr([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=dtype) - y = array_constr([1, 1, 1, 2, 2, 2], dtype=dtype) + y = array_constr([0, 0, 0, 1, 1, 1], dtype=dtype) clf = SVC(kernel="linear").fit(X, y, queue=queue) assert_array_equal(clf.dual_coef_, [[-0.25, 0.25]]) assert_array_equal(clf.support_, [1, 3]) assert_array_equal(clf.support_vectors_, (X[1], X[3])) assert_array_equal(clf.intercept_, [0.0]) - assert_array_equal(clf.predict(X), y) + assert_array_equal(clf.predict(X).ravel(), y) @pytest.mark.parametrize("queue", get_queues()) @@ -65,12 +65,12 @@ def test_libsvm_parameters(queue, array_constr, dtype): ], ) def test_class_weight(queue): - X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) - y = np.array([1, 1, 1, 2, 2, 2]) + X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float64) + y = np.array([0, 0, 0, 1, 1, 1], dtype=np.float64) - clf = SVC(class_weight={1: 0.1}) + clf = SVC(class_weight={0: 0.1}) clf.fit(X, y, queue=queue) - assert_array_almost_equal(clf.predict(X, queue=queue), [2] * 6) + assert_array_almost_equal(clf.predict(X, queue=queue), [1] * 6) @pytest.mark.parametrize("queue", get_queues()) @@ -160,9 +160,9 @@ def test_svc_sigmoid(queue, dtype): [[-1, 2], [0, 0], [2, -1], [+1, +1], [+1, +2], [+2, +1]], dtype=dtype ) X_test = np.array([[0, 2], [0.5, 0.5], [0.3, 0.1], [2, 0], [-1, -1]], dtype=dtype) - y_train = np.array([1, 1, 1, 2, 2, 2], dtype=dtype) + y_train = np.array([0, 0, 0, 1, 1, 1], dtype=dtype) svc = SVC(kernel="sigmoid").fit(X_train, y_train, queue=queue) assert_array_equal(svc.dual_coef_, [[-1, -1, -1, 1, 1, 1]]) assert_array_equal(svc.support_, [0, 1, 2, 3, 4, 5]) - assert_array_equal(svc.predict(X_test, queue=queue), [2, 2, 1, 2, 1]) + assert_array_equal(svc.predict(X_test, queue=queue).ravel(), [1, 1, 0, 1, 0]) diff --git a/onedal/svm/tests/test_svr.py b/onedal/svm/tests/test_svr.py index a9000ff5f7..591387165a 100644 --- a/onedal/svm/tests/test_svr.py +++ b/onedal/svm/tests/test_svr.py @@ -206,22 +206,24 @@ def test_synth_poly_compare_with_sklearn(queue, params): def test_sided_sample_weight(queue): clf = SVR(C=1e-2, kernel="linear") - X = [[-2, 0], [-1, -1], [0, -2], [0, 2], [1, 1], [2, 0]] - Y = [1, 1, 1, 2, 2, 2] + X = np.array([[-2, 0], [-1, -1], [0, -2], [0, 2], [1, 1], [2, 0]], dtype=np.float64) + Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float64) - sample_weight = [10.0, 0.1, 0.1, 0.1, 0.1, 10] + X_pred = np.array([[-1.0, 1.0]], dtype=np.float64) + + sample_weight = np.array([10.0, 0.1, 0.1, 0.1, 0.1, 10], dtype=np.float64) clf.fit(X, Y, sample_weight=sample_weight, queue=queue) - y_pred = clf.predict([[-1.0, 1.0]], queue=queue) + y_pred = clf.predict(X_pred, queue=queue) assert y_pred < 1.5 - sample_weight = [1.0, 0.1, 10.0, 10.0, 0.1, 0.1] + sample_weight = np.array([1.0, 0.1, 10.0, 10.0, 0.1, 0.1], dtype=np.float64) clf.fit(X, Y, sample_weight=sample_weight, queue=queue) - y_pred = clf.predict([[-1.0, 1.0]], queue=queue) + y_pred = clf.predict(X_pred, queue=queue) assert y_pred > 1.5 - sample_weight = [1] * 6 + sample_weight = np.array([1] * 6, dtype=np.float64) clf.fit(X, Y, sample_weight=sample_weight, queue=queue) - y_pred = clf.predict([[-1.0, 1.0]], queue=queue) + y_pred = clf.predict(X_pred, queue=queue) assert y_pred == pytest.approx(1.5) From 2ceccf77ae8a73b7d50d4e01648d1de0aa70c183 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 23:18:06 +0100 Subject: [PATCH 14/28] ravels --- onedal/svm/tests/test_nusvc.py | 2 +- onedal/svm/tests/test_svc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onedal/svm/tests/test_nusvc.py b/onedal/svm/tests/test_nusvc.py index 13d94fa552..bc9f43b9b6 100644 --- a/onedal/svm/tests/test_nusvc.py +++ b/onedal/svm/tests/test_nusvc.py @@ -60,7 +60,7 @@ def test_class_weight(queue): clf = NuSVC(class_weight={0: 0.1}) clf.fit(X, y, queue=queue) - assert_array_almost_equal(clf.predict(X, queue=queue), [1] * 6) + assert_array_almost_equal(clf.predict(X, queue=queue).ravel(), [1] * 6) @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") diff --git a/onedal/svm/tests/test_svc.py b/onedal/svm/tests/test_svc.py index 650f153ae5..3210d4e0c4 100644 --- a/onedal/svm/tests/test_svc.py +++ b/onedal/svm/tests/test_svc.py @@ -70,7 +70,7 @@ def test_class_weight(queue): clf = SVC(class_weight={0: 0.1}) clf.fit(X, y, queue=queue) - assert_array_almost_equal(clf.predict(X, queue=queue), [1] * 6) + assert_array_almost_equal(clf.predict(X, queue=queue).ravel(), [1] * 6) @pytest.mark.parametrize("queue", get_queues()) From 806f372e36ee9eef71265932a452cf58d3dd0070 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 23:52:25 +0100 Subject: [PATCH 15/28] fix errors, reduce code, move decision_function --- onedal/svm/svm.py | 74 +++----------------------------- onedal/svm/tests/test_csr_svm.py | 4 +- sklearnex/svm/_common.py | 45 +++++++++++++++++-- 3 files changed, 50 insertions(+), 73 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index 86d75db1ff..b0873eaaa0 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -199,57 +199,9 @@ def _create_model(self, module): m.biases = to_table(self.intercept_) return m - def _predict(self, X, module, queue): + def _infer(self, X, module, queue): _check_is_fitted(self) - if self._sparse and not sp.isspmatrix(X): - X = sp.csr_matrix(X) - if self._sparse: - X.sort_indices() - - if sp.issparse(X) and not self._sparse and not callable(self.kernel): - raise ValueError( - "cannot use sparse input in %r trained on dense data" - % type(self).__name__ - ) - - policy = _get_policy(queue, X) - X = to_table(_convert_to_supported(policy, X)) - params = self._get_onedal_params(X.dtype) - - if hasattr(self, "_onedal_model"): - model = self._onedal_model - else: - model = self._create_model(module) - result = module.infer(policy, params, model, X) - return from_table(result.responses) - - def _ovr_decision_function(self, predictions, confidences, n_classes): - n_samples = predictions.shape[0] - votes = np.zeros((n_samples, n_classes)) - sum_of_confidences = np.zeros((n_samples, n_classes)) - - k = 0 - for i in range(n_classes): - for j in range(i + 1, n_classes): - sum_of_confidences[:, i] -= confidences[:, k] - sum_of_confidences[:, j] += confidences[:, k] - votes[predictions[:, k] == 0, i] += 1 - votes[predictions[:, k] == 1, j] += 1 - k += 1 - - transformed_confidences = sum_of_confidences / ( - 3 * (np.abs(sum_of_confidences) + 1) - ) - return votes + transformed_confidences - - def _decision_function(self, X, module, queue): - _check_is_fitted(self) - X = _check_array( - X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse="csr" - ) - _check_n_features(self, X, False) - if self._sparse: if not sp.isspmatrix(X): X = sp.csr_matrix(X) @@ -261,14 +213,6 @@ def _decision_function(self, X, module, queue): % type(self).__name__ ) - if module in [_backend.svm.classification, _backend.svm.nu_classification]: - sv = self.support_vectors_ - if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]: - raise ValueError( - "The internal representation " - f"of {self.__class__.__name__} was altered" - ) - policy = _get_policy(queue, X) X = to_table(_convert_to_supported(policy, X)) params = self._get_onedal_params(X.dtype) @@ -277,17 +221,13 @@ def _decision_function(self, X, module, queue): model = self._onedal_model else: model = self._create_model(module) - result = module.infer(policy, params, model, X) - decision_function = from_table(result.decision_function) - - if len(self.classes_) == 2: - decision_function = decision_function.ravel() + return module.infer(policy, params, model, X) - if self.decision_function_shape == "ovr" and len(self.classes_) > 2: - decision_function = self._ovr_decision_function( - decision_function < 0, -decision_function, len(self.classes_) - ) - return decision_function + def _predict(self, X, module, queue): + return from_table(self._infer(X, module, queue).responses) + + def _decision_function(self, X, module, queue): + return from_table(self._infer(X, module, queue).decision_function) class SVR(RegressorMixin, BaseSVM): diff --git a/onedal/svm/tests/test_csr_svm.py b/onedal/svm/tests/test_csr_svm.py index c93833f2ec..3720798e8b 100644 --- a/onedal/svm/tests/test_csr_svm.py +++ b/onedal/svm/tests/test_csr_svm.py @@ -62,11 +62,11 @@ def check_svm_model_equal( def _test_simple_dataset(queue, kernel): X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float64) - sparse_X = sp.lil_matrix(X) + sparse_X = sp.csr_matrix(X) Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float64) X2 = np.array([[-1, -1], [2, 2], [3, 2]], dtype=np.float64) - sparse_X2 = sp.dok_matrix(X2) + sparse_X2 = sp.csr_matrix(X2) dataset = sparse_X, Y, sparse_X2 clf0 = SVC(kernel=kernel, gamma=1) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index d4ce248062..9524f14911 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -489,6 +489,8 @@ def _onedal_predict(self, X, queue=None): "The internal representation " f"of {self.__class__.__name__} was altered" ) + xp, _ = get_namespace(X) + if self.break_ties and self.decision_function_shape == "ovo": raise ValueError( "break_ties must be False when " "decision_function_shape is 'ovo'" @@ -501,20 +503,46 @@ def _onedal_predict(self, X, queue=None): ): return xp.argmax(self._onedal_decision_function(X, queue=queue), axis=1) - xp, _ = get_namespace(X) res = super()._onedal_predict(X, queue=queue, xp=xp) if len(self.classes_) != 2: res = xp.take(self.classes_, xp.asarray(res, dtype=xp.int32)) return res + + + def _onedal_ovr_decision_function(self, predictions, confidences, n_classes): + # This function is legacy from the original implementation and needs + # to be refactored. + xp, _ = get_namespace(predictions) + n_samples = predictions.shape[0] + votes = xp.zeros((n_samples, n_classes)) + sum_of_confidences = xp.zeros((n_samples, n_classes)) + + k = 0 + for i in range(n_classes): + for j in range(i + 1, n_classes): + sum_of_confidences[:, i] -= confidences[:, k] + sum_of_confidences[:, j] += confidences[:, k] + votes[predictions[:, k] == 0, i] += 1 + votes[predictions[:, k] == 1, j] += 1 + k += 1 + + transformed_confidences = sum_of_confidences / ( + 3 * (xp.abs(sum_of_confidences) + 1) + ) + return votes + transformed_confidences def _onedal_decision_function(self, X, queue=None): + sv = self.support_vectors_ + if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]: + raise ValueError( + "The internal representation " f"of {self.__class__.__name__} was altered" + ) xp, _ = get_namespace(X) if sklearn_check_version("1.0"): validate_data( self, X, dtype=[xp.float64, xp.float32], - force_all_finite=False, accept_sparse="csr", reset=False, ) @@ -522,11 +550,20 @@ def _onedal_decision_function(self, X, queue=None): X = check_array( X, dtype=[xp.float64, xp.float32], - force_all_finite=False, accept_sparse="csr", ) - return self._onedal_estimator.decision_function(X, queue=queue) + decision_function = self._onedal_estimator.decision_function(X, queue=queue) + + if len(self.classes_) == 2: + decision_function = decision_function.ravel() + elif len(self.classes_) > 2 and self.decision_function_shape == "ovr": + decision_function = self._onedal_ovr_decision_function( + decision_function < 0, -decision_function, len(self.classes_) + ) + + return decision_function + def _onedal_predict_proba(self, X, queue=None): if getattr(self, "clf_prob", None) is None: From 6d863fbd819222a9cd6ca0af669deed02c904dfc Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Wed, 4 Dec 2024 23:53:00 +0100 Subject: [PATCH 16/28] formatting --- onedal/svm/svm.py | 2 +- sklearnex/svm/_common.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index b0873eaaa0..8895e35aef 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -225,7 +225,7 @@ def _infer(self, X, module, queue): def _predict(self, X, module, queue): return from_table(self._infer(X, module, queue).responses) - + def _decision_function(self, X, module, queue): return from_table(self._infer(X, module, queue).decision_function) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 9524f14911..7d20329cfa 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -507,8 +507,7 @@ def _onedal_predict(self, X, queue=None): if len(self.classes_) != 2: res = xp.take(self.classes_, xp.asarray(res, dtype=xp.int32)) return res - - + def _onedal_ovr_decision_function(self, predictions, confidences, n_classes): # This function is legacy from the original implementation and needs # to be refactored. @@ -561,9 +560,8 @@ def _onedal_decision_function(self, X, queue=None): decision_function = self._onedal_ovr_decision_function( decision_function < 0, -decision_function, len(self.classes_) ) - - return decision_function + return decision_function def _onedal_predict_proba(self, X, queue=None): if getattr(self, "clf_prob", None) is None: From 32e8754682d7eae0bed24e31a6f1357e1f0a6e40 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 00:18:15 +0100 Subject: [PATCH 17/28] last fixes of the day --- onedal/svm/svm.py | 1 - onedal/svm/tests/test_nusvc.py | 2 +- onedal/svm/tests/test_svc.py | 2 +- sklearnex/svm/_common.py | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index 8895e35aef..f7224bdc58 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -15,7 +15,6 @@ # ============================================================================== from abc import ABCMeta, abstractmethod -from enum import Enum import numpy as np from scipy import sparse as sp diff --git a/onedal/svm/tests/test_nusvc.py b/onedal/svm/tests/test_nusvc.py index bc9f43b9b6..6e514da17e 100644 --- a/onedal/svm/tests/test_nusvc.py +++ b/onedal/svm/tests/test_nusvc.py @@ -85,7 +85,7 @@ def test_decision_function(queue): rbfs = rbf_kernel(X, clf.support_vectors_, gamma=clf.gamma) dec = np.dot(rbfs, clf.dual_coef_.T) + clf.intercept_ - assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue)) + assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue).ravel()) @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") diff --git a/onedal/svm/tests/test_svc.py b/onedal/svm/tests/test_svc.py index 3210d4e0c4..8e148383fa 100644 --- a/onedal/svm/tests/test_svc.py +++ b/onedal/svm/tests/test_svc.py @@ -95,7 +95,7 @@ def test_decision_function(queue): rbfs = rbf_kernel(X, clf.support_vectors_, gamma=clf.gamma) dec = np.dot(rbfs, clf.dual_coef_.T) + clf.intercept_ - assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue)) + assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue).ravel()) @pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented") diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 7d20329cfa..abd3726a0d 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -538,7 +538,7 @@ def _onedal_decision_function(self, X, queue=None): ) xp, _ = get_namespace(X) if sklearn_check_version("1.0"): - validate_data( + X = validate_data( self, X, dtype=[xp.float64, xp.float32], From e2aabde40345656daceb1863048a45cd4ea53d0e Mon Sep 17 00:00:00 2001 From: Ian Faust Date: Thu, 5 Dec 2024 06:37:30 +0100 Subject: [PATCH 18/28] Update test_nusvc.py --- onedal/svm/tests/test_nusvc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onedal/svm/tests/test_nusvc.py b/onedal/svm/tests/test_nusvc.py index 6e514da17e..692a9d78c4 100644 --- a/onedal/svm/tests/test_nusvc.py +++ b/onedal/svm/tests/test_nusvc.py @@ -77,8 +77,8 @@ def test_sample_weight(queue): @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_decision_function(queue): - X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] - Y = [1, 1, 1, 2, 2, 2] + X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float32) + Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float32) clf = NuSVC(kernel="rbf", gamma=1, decision_function_shape="ovo") clf.fit(X, Y, queue=queue) From 02dce0545d98e496bda1a6b0da8178e9644dcf7c Mon Sep 17 00:00:00 2001 From: Ian Faust Date: Thu, 5 Dec 2024 07:24:28 +0100 Subject: [PATCH 19/28] Update _common.py --- sklearnex/svm/_common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index abd3726a0d..80583f3c00 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -504,9 +504,7 @@ def _onedal_predict(self, X, queue=None): return xp.argmax(self._onedal_decision_function(X, queue=queue), axis=1) res = super()._onedal_predict(X, queue=queue, xp=xp) - if len(self.classes_) != 2: - res = xp.take(self.classes_, xp.asarray(res, dtype=xp.int32)) - return res + return xp.take(self.classes_, xp.asarray(res, dtype=xp.int32)) def _onedal_ovr_decision_function(self, predictions, confidences, n_classes): # This function is legacy from the original implementation and needs From 7cee399b39ff2bb950add84450a34355f9808d74 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 08:33:16 +0100 Subject: [PATCH 20/28] make similar to previous --- sklearnex/svm/_common.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 80583f3c00..7cf87e5024 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -504,7 +504,12 @@ def _onedal_predict(self, X, queue=None): return xp.argmax(self._onedal_decision_function(X, queue=queue), axis=1) res = super()._onedal_predict(X, queue=queue, xp=xp) - return xp.take(self.classes_, xp.asarray(res, dtype=xp.int32)) + return xp.take( + self.classes_, + xp.asarray( + res if len(self.classes_) != 2 else xp.reshape(res, (-1,)), dtype=xp.int32 + ), + ) def _onedal_ovr_decision_function(self, predictions, confidences, n_classes): # This function is legacy from the original implementation and needs From a0f941e391a89b1a53dd64c6e64653c85b837a4f Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 09:21:13 +0100 Subject: [PATCH 21/28] further modification to try and fix sklearn conformance --- sklearnex/svm/_common.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 7cf87e5024..aa4bcc8dab 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -105,7 +105,7 @@ def _onedal_cpu_supported(self, method_name, *data): return patching_status inference_methods = ( ["predict", "score"] - if class_name.endswith("R") + if isinstance(self, RegressorMixin) else ["predict", "predict_proba", "decision_function", "score"] ) if method_name in inference_methods: @@ -268,7 +268,7 @@ def _onedal_predict(self, X, queue=None, xp=None): accept_sparse="csr", ) - return xp.squeeze(self._onedal_estimator.predict(X, queue=queue)) + return self._onedal_estimator.predict(X, queue=queue) class BaseSVC(BaseSVM, _sklearn_BaseSVC): @@ -501,15 +501,17 @@ def _onedal_predict(self, X, queue=None): and self.decision_function_shape == "ovr" and len(self.classes_) > 2 ): - return xp.argmax(self._onedal_decision_function(X, queue=queue), axis=1) - - res = super()._onedal_predict(X, queue=queue, xp=xp) - return xp.take( - self.classes_, - xp.asarray( - res if len(self.classes_) != 2 else xp.reshape(res, (-1,)), dtype=xp.int32 - ), - ) + res = xp.argmax(self._onedal_decision_function(X, queue=queue), axis=1) + else: + res = super()._onedal_predict(X, queue=queue, xp=xp) + + # the extensive reshaping here comes from the previous implementation, and + # should be sorted out, as this is inefficient and likely can be reduced + res = xp.asarray(res, dtype=xp.int32) + if len(self.classes_) == 2: + res = xp.reshape(res, (-1,)) + + return xp.reshape(xp.take(xp.asarray(self.classes_), res), (-1,)) def _onedal_ovr_decision_function(self, predictions, confidences, n_classes): # This function is legacy from the original implementation and needs @@ -564,7 +566,7 @@ def _onedal_decision_function(self, X, queue=None): decision_function < 0, -decision_function, len(self.classes_) ) - return decision_function + return xp.asarray(decision_function) def _onedal_predict_proba(self, X, queue=None): if getattr(self, "clf_prob", None) is None: From b29c1c6c5002cef38e8ebe0c71fc8729c0d5884a Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 09:23:49 +0100 Subject: [PATCH 22/28] forgot about regression --- sklearnex/svm/_common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index aa4bcc8dab..40e0a331c9 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -641,6 +641,11 @@ def _save_attributes(self): self._dualcoef_ = self.dual_coef_ + def _onedal_predict(self, X, queue=None): + xp, _ = get_namespace(X) + res = super()._onedal_predict(X, queue) + return xp.reshape(res, (-1,)) + def _onedal_score(self, X, y, sample_weight=None, queue=None): return r2_score( y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight From 45ee05dc50882bf2dec6e464abb22e981a704e64 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 12:51:42 +0100 Subject: [PATCH 23/28] fix mro issue --- sklearnex/svm/_common.py | 6 +++--- sklearnex/svm/nusvc.py | 2 +- sklearnex/svm/svc.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py index 40e0a331c9..da7ab39927 100644 --- a/sklearnex/svm/_common.py +++ b/sklearnex/svm/_common.py @@ -46,7 +46,7 @@ validate_data = BaseEstimator._validate_data -class BaseSVM(BaseEstimator): +class BaseSVM(object): _onedal_factory = None @@ -271,7 +271,7 @@ def _onedal_predict(self, X, queue=None, xp=None): return self._onedal_estimator.predict(X, queue=queue) -class BaseSVC(BaseSVM, _sklearn_BaseSVC): +class BaseSVC(BaseSVM): @wrap_output_data def predict(self, X): @@ -592,7 +592,7 @@ def _onedal_score(self, X, y, sample_weight=None, queue=None): score.__doc__ = _sklearn_BaseSVC.score.__doc__ -class BaseSVR(BaseSVM, _sklearn_BaseLibSVM): +class BaseSVR(BaseSVM): @wrap_output_data def predict(self, X): check_is_fitted(self) diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index d613bf8ca7..b53cb71f63 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -32,7 +32,7 @@ if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVC._validate_data + validate_data = _sklearn_NuSVC._validate_data @control_n_jobs( diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index 8b9330e887..e87085bd17 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -40,7 +40,7 @@ if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVC._validate_data + validate_data = _sklearn_SVC._validate_data @control_n_jobs( From 470a8c9a8f5e7b91b4072ebc8b929fd83ee0def6 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 13:09:38 +0100 Subject: [PATCH 24/28] refactoring --- sklearnex/svm/nusvr.py | 2 +- sklearnex/svm/svr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py index 3b822b3c6a..557d55bcb5 100644 --- a/sklearnex/svm/nusvr.py +++ b/sklearnex/svm/nusvr.py @@ -32,7 +32,7 @@ if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVR._validate_data + validate_data = _sklearn_NuSVR._validate_data @control_n_jobs(decorated_methods=["fit", "predict", "score"]) diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py index 70965ab219..fc99e191a9 100644 --- a/sklearnex/svm/svr.py +++ b/sklearnex/svm/svr.py @@ -28,7 +28,7 @@ if sklearn_check_version("1.6"): from sklearn.utils.validation import validate_data else: - validate_data = BaseSVR._validate_data + validate_data = _sklearn_SVR._validate_data @control_n_jobs(decorated_methods=["fit", "predict", "score"]) From 5bed19b13a6c73b28721ddac9dea226191a79fc2 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 21:48:03 +0100 Subject: [PATCH 25/28] remove vestigial code --- onedal/svm/svm.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index f7224bdc58..f6732edda0 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -119,6 +119,7 @@ def _fit(self, X, y, sample_weight, module, queue): force_all_finite=True, accept_sparse="csr", ) + # hard work remains on moving validate targets away from onedal y = self._validate_targets(y, X.dtype) if sample_weight is not None and len(sample_weight) > 0: sample_weight = _check_array( @@ -144,24 +145,7 @@ def _fit(self, X, y, sample_weight, module, queue): self._scale_, self._sigma_ = 1.0, 1.0 self.coef0 = 0.0 else: - if isinstance(self.gamma, str): - if self.gamma == "scale": - if sp.issparse(X): - # var = E[X^2] - E[X]^2 - X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2 - else: - X_sc = X.var() - _gamma = 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0 - elif self.gamma == "auto": - _gamma = 1.0 / X.shape[1] - else: - raise ValueError( - "When 'gamma' is a string, it should be either 'scale' or " - "'auto'. Got '{}' instead.".format(self.gamma) - ) - else: - _gamma = self.gamma - self._scale_, self._sigma_ = _gamma, np.sqrt(0.5 / _gamma) + self._scale_, self._sigma_ = self.gamma, np.sqrt(0.5 / self.gamma) policy = _get_policy(queue, *data) data_t = to_table(*_convert_to_supported(policy, *data)) @@ -180,6 +164,7 @@ def _fit(self, X, y, sample_weight, module, queue): self.n_features_in_ = X.shape[1] self.shape_fit_ = X.shape + # _n_support not used in this object, will be moved to sklearnex if getattr(self, "classes_", None) is not None: indices = y.take(self.support_, axis=0) self._n_support = np.array( From 39b4bc125c9d45d2987d3e59770a431a89c36457 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 23:02:25 +0100 Subject: [PATCH 26/28] move initial tests from onedal to sklearnex --- sklearnex/svm/tests/test_svm.py | 81 +++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/sklearnex/svm/tests/test_svm.py b/sklearnex/svm/tests/test_svm.py index f0d561744e..985fc1f44c 100755 --- a/sklearnex/svm/tests/test_svm.py +++ b/sklearnex/svm/tests/test_svm.py @@ -16,13 +16,20 @@ import numpy as np import pytest +import scipy.sparse as sp from numpy.testing import assert_allclose +from sklearn.datasets import load_diabetes, load_iris, make_classification +from onedal.svm.tests.test_csr_svm import check_svm_model_equal from onedal.tests.utils._dataframes_support import ( _as_numpy, _convert_to_dataframe, get_dataframes_and_queues, ) +from onedal.tests.utils._device_selection import ( + get_queues, + pass_if_not_implemented_for_gpu, +) @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) @@ -91,3 +98,77 @@ def test_sklearnex_import_nusvr(dataframe, queue): _as_numpy(svc.dual_coef_), [[-1.0, 0.611111, 1.0, -0.611111]], rtol=1e-3 ) assert_allclose(_as_numpy(svc.support_), [1, 2, 3, 5]) + + +@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") +@pytest.mark.parametrize( + "queue", + get_queues("cpu") + + [ + pytest.param( + get_queues("gpu"), + marks=pytest.mark.xfail( + reason="raises UnknownError for linear and rbf, " + "Unimplemented error with inconsistent error message " + "for poly and sigmoid" + ), + ) + ], +) +@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) +def test_binary_dataset(queue, kernel): + from sklearnex import config_context + from sklearnex.svm import SVC + + X, y = make_classification(n_samples=80, n_features=20, n_classes=2, random_state=0) + sparse_X = sp.csr_matrix(X) + + dataset = sparse_X, y, sparse_X + with config_context(target_offload=queue): + clf0 = SVC(kernel=kernel) + clf1 = SVC(kernel=kernel) + check_svm_model_equal(queue, clf0, clf1, *dataset) + + +@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") +@pytest.mark.parametrize("queue", get_queues()) +@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) +def test_iris(queue, kernel): + from sklearnex import config_context + from sklearnex.svm import SVC + + if kernel == "rbf": + pytest.skip("RBF CSR SVM test failing in 2025.0.") + iris = load_iris() + rng = np.random.RandomState(0) + perm = rng.permutation(iris.target.size) + iris.data = iris.data[perm] + iris.target = iris.target[perm] + sparse_iris_data = sp.csr_matrix(iris.data) + + dataset = sparse_iris_data, iris.target, sparse_iris_data + + with config_context(target_offload=queue): + clf0 = SVC(kernel=kernel) + clf1 = SVC(kernel=kernel) + check_svm_model_equal(queue, clf0, clf1, *dataset, decimal=2) + + +@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") +@pytest.mark.parametrize("queue", get_queues()) +@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) +def test_diabetes(queue, kernel): + from sklearnex import config_context + from sklearnex.svm import SVR + + if kernel == "sigmoid": + pytest.skip("Sparse sigmoid kernel function is buggy.") + diabetes = load_diabetes() + + sparse_diabetes_data = sp.csr_matrix(diabetes.data) + dataset = sparse_diabetes_data, diabetes.target, sparse_diabetes_data + + with config_context(target_offload=queue): + clf0 = SVR(kernel=kernel, C=0.1) + clf1 = SVR(kernel=kernel, C=0.1) + check_svm_model_equal(queue, clf0, clf1, *dataset) From 7448ad4190994a50035303dcb368ab5792474a2c Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 23:02:39 +0100 Subject: [PATCH 27/28] remove vestigial code --- onedal/svm/tests/test_csr_svm.py | 74 -------------------------------- 1 file changed, 74 deletions(-) diff --git a/onedal/svm/tests/test_csr_svm.py b/onedal/svm/tests/test_csr_svm.py index 3720798e8b..04bcd77d75 100644 --- a/onedal/svm/tests/test_csr_svm.py +++ b/onedal/svm/tests/test_csr_svm.py @@ -93,80 +93,6 @@ def test_simple_dataset(queue, kernel): _test_simple_dataset(queue, kernel) -def _test_binary_dataset(queue, kernel): - X, y = make_classification(n_samples=80, n_features=20, n_classes=2, random_state=0) - sparse_X = sp.csr_matrix(X) - - dataset = sparse_X, y, sparse_X - clf0 = SVC(kernel=kernel) - clf1 = SVC(kernel=kernel) - check_svm_model_equal(queue, clf0, clf1, *dataset) - - -@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") -@pytest.mark.parametrize( - "queue", - get_queues("cpu") - + [ - pytest.param( - get_queues("gpu"), - marks=pytest.mark.xfail( - reason="raises UnknownError for linear and rbf, " - "Unimplemented error with inconsistent error message " - "for poly and sigmoid" - ), - ) - ], -) -@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) -def test_binary_dataset(queue, kernel): - _test_binary_dataset(queue, kernel) - - -def _test_iris(queue, kernel): - iris = datasets.load_iris() - rng = np.random.RandomState(0) - perm = rng.permutation(iris.target.size) - iris.data = iris.data[perm] - iris.target = iris.target[perm] - sparse_iris_data = sp.csr_matrix(iris.data) - - dataset = sparse_iris_data, iris.target, sparse_iris_data - - clf0 = SVC(kernel=kernel) - clf1 = SVC(kernel=kernel) - check_svm_model_equal(queue, clf0, clf1, *dataset, decimal=2) - - -@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") -@pytest.mark.parametrize("queue", get_queues()) -@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) -def test_iris(queue, kernel): - if kernel == "rbf": - pytest.skip("RBF CSR SVM test failing in 2025.0.") - _test_iris(queue, kernel) - - -def _test_diabetes(queue, kernel): - diabetes = datasets.load_diabetes() - - sparse_diabetes_data = sp.csr_matrix(diabetes.data) - dataset = sparse_diabetes_data, diabetes.target, sparse_diabetes_data - - clf0 = SVR(kernel=kernel, C=0.1) - clf1 = SVR(kernel=kernel, C=0.1) - check_svm_model_equal(queue, clf0, clf1, *dataset) - - -@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") -@pytest.mark.parametrize("queue", get_queues()) -@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) -def test_diabetes(queue, kernel): - if kernel == "sigmoid": - pytest.skip("Sparse sigmoid kernel function is buggy.") - _test_diabetes(queue, kernel) - - @pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") @pytest.mark.xfail(reason="Failed test. Need investigate") @pytest.mark.parametrize("queue", get_queues()) From b9807345f54af2a6d9c991efd1127265162de9f9 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 5 Dec 2024 23:06:10 +0100 Subject: [PATCH 28/28] switch rbf to linear --- onedal/svm/tests/test_nusvr.py | 2 +- onedal/svm/tests/test_svr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onedal/svm/tests/test_nusvr.py b/onedal/svm/tests/test_nusvr.py index 1bec991961..bfd40c5767 100644 --- a/onedal/svm/tests/test_nusvr.py +++ b/onedal/svm/tests/test_nusvr.py @@ -196,7 +196,7 @@ def test_synth_poly_compare_with_sklearn(queue, params): def test_pickle(queue): diabetes = datasets.load_diabetes() - clf = NuSVR(kernel="rbf", C=10.0) + clf = NuSVR(kernel="linear", C=10.0) clf.fit(diabetes.data, diabetes.target, queue=queue) expected = clf.predict(diabetes.data, queue=queue) diff --git a/onedal/svm/tests/test_svr.py b/onedal/svm/tests/test_svr.py index 591387165a..82ce86a78c 100644 --- a/onedal/svm/tests/test_svr.py +++ b/onedal/svm/tests/test_svr.py @@ -231,7 +231,7 @@ def test_sided_sample_weight(queue): @pytest.mark.parametrize("queue", get_queues()) def test_pickle(queue): diabetes = datasets.load_diabetes() - clf = SVR(kernel="rbf", C=10.0) + clf = SVR(kernel="linear", C=10.0) clf.fit(diabetes.data, diabetes.target, queue=queue) expected = clf.predict(diabetes.data, queue=queue)