diff --git a/deselected_tests.yaml b/deselected_tests.yaml index d35524c142..ea5c9af3a0 100755 --- a/deselected_tests.yaml +++ b/deselected_tests.yaml @@ -290,14 +290,6 @@ deselected_tests: - tests/test_common.py::test_estimators[LogisticRegression()-check_sample_weights_invariance(kind=zeros)] >=1.4 - tests/test_multioutput.py::test_classifier_chain_fit_and_predict_with_sparse_data >=1.4 - # Deselected tests for incremental algorithms - # Need to rework getting policy to correctly obtain it for method without data (finalize_fit) - # and avoid keeping it in class attribute, also need to investigate how to implement - # partial result serialization - - tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle] - - tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle(readonly_memmap=True)] - - tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle] - - tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle(readonly_memmap=True)] # There are not enough data to run onedal backend - tests/test_common.py::test_estimators[IncrementalRidge()-check_fit2d_1sample] diff --git a/onedal/datatypes/data_conversion.cpp b/onedal/datatypes/data_conversion.cpp index 569ad277c8..3b60518ad4 100644 --- a/onedal/datatypes/data_conversion.cpp +++ b/onedal/datatypes/data_conversion.cpp @@ -398,8 +398,7 @@ PyObject *convert_to_pyobject(const dal::table &input) { } if (input.get_kind() == dal::homogen_table::kind()) { const auto &homogen_input = static_cast(input); - if (homogen_input.get_data_layout() == dal::data_layout::row_major) { - const dal::data_type dtype = homogen_input.get_metadata().get_data_type(0); + const dal::data_type dtype = homogen_input.get_metadata().get_data_type(0); #define MAKE_NYMPY_FROM_HOMOGEN(NpType) \ { \ @@ -408,16 +407,10 @@ PyObject *convert_to_pyobject(const dal::table &input) { homogen_input.get_row_count(), \ 
homogen_input.get_column_count()); \ } - SET_CTYPE_NPY_FROM_DAL_TYPE( - dtype, - MAKE_NYMPY_FROM_HOMOGEN, - throw std::invalid_argument("Not avalible to convert a numpy")); + SET_CTYPE_NPY_FROM_DAL_TYPE(dtype, + MAKE_NYMPY_FROM_HOMOGEN, + throw std::invalid_argument("Unable to convert numpy object")); #undef MAKE_NYMPY_FROM_HOMOGEN - } - else { - throw std::invalid_argument( - "Output oneDAL table doesn't have row major format for homogen table"); - } } else if (input.get_kind() == csr_table_t::kind()) { const auto &csr_input = static_cast(input); @@ -427,7 +420,7 @@ PyObject *convert_to_pyobject(const dal::table &input) { SET_CTYPES_NPY_FROM_DAL_TYPE( dtype, MAKE_PY_FROM_CSR, - throw std::invalid_argument("Not avalible to convert a scipy.csr")); + throw std::invalid_argument("Unable to convert scipy csr object")); #undef MAKE_PY_FROM_CSR } else { diff --git a/onedal/linear_model/incremental_linear_model.py b/onedal/linear_model/incremental_linear_model.py index bc48d59077..539a6bf2b7 100644 --- a/onedal/linear_model/incremental_linear_model.py +++ b/onedal/linear_model/incremental_linear_model.py @@ -47,10 +47,22 @@ def __init__(self, fit_intercept=True, copy_X=False, algorithm="norm_eq"): self._reset() def _reset(self): + self._need_to_finalize = False self._partial_result = self._get_backend( "linear_model", "regression", "partial_train_result" ) + def __getstate__(self): + # Since finalize_fit can't be dispatched without directly provided queue + # and the dispatching policy can't be serialized, the computation is finalized + # here and the policy is not saved in serialized data. 
+ + self.finalize_fit() + data = self.__dict__.copy() + data.pop("_queue", None) + + return data + def partial_fit(self, X, y, queue=None): """ Computes partial data for linear regression @@ -105,6 +117,9 @@ def partial_fit(self, X, y, queue=None): policy, self._params, self._partial_result, X_table, y_table ) + self._need_to_finalize = True + return self + def finalize_fit(self, queue=None): """ Finalizes linear regression computation and obtains coefficients @@ -121,27 +136,30 @@ def finalize_fit(self, queue=None): Returns the instance itself. """ - if queue is not None: - policy = self._get_policy(queue) - else: - policy = self._get_policy(self._queue) - - module = self._get_backend("linear_model", "regression") - hparams = get_hyperparameters("linear_regression", "train") - if hparams is not None and not hparams.is_default: - result = module.finalize_train( - policy, self._params, hparams.backend, self._partial_result + if self._need_to_finalize: + if queue is not None: + policy = self._get_policy(queue) + else: + policy = self._get_policy(self._queue) + + module = self._get_backend("linear_model", "regression") + hparams = get_hyperparameters("linear_regression", "train") + if hparams is not None and not hparams.is_default: + result = module.finalize_train( + policy, self._params, hparams.backend, self._partial_result + ) + else: + result = module.finalize_train(policy, self._params, self._partial_result) + + self._onedal_model = result.model + + packed_coefficients = from_table(result.model.packed_coefficients) + self.coef_, self.intercept_ = ( + packed_coefficients[:, 1:].squeeze(), + packed_coefficients[:, 0].squeeze(), ) - else: - result = module.finalize_train(policy, self._params, self._partial_result) - self._onedal_model = result.model - - packed_coefficients = from_table(result.model.packed_coefficients) - self.coef_, self.intercept_ = ( - packed_coefficients[:, 1:].squeeze(), - packed_coefficients[:, 0].squeeze(), - ) + self._need_to_finalize = False 
return self @@ -170,15 +188,26 @@ class IncrementalRidge(BaseLinearRegression): """ def __init__(self, alpha=1.0, fit_intercept=True, copy_X=False, algorithm="norm_eq"): - module = self._get_backend("linear_model", "regression") super().__init__( fit_intercept=fit_intercept, alpha=alpha, copy_X=copy_X, algorithm=algorithm ) - self._partial_result = module.partial_train_result() + self._reset() def _reset(self): module = self._get_backend("linear_model", "regression") self._partial_result = module.partial_train_result() + self._need_to_finalize = False + + def __getstate__(self): + # Since finalize_fit can't be dispatched without directly provided queue + # and the dispatching policy can't be serialized, the computation is finalized + # here and the policy is not saved in serialized data. + + self.finalize_fit() + data = self.__dict__.copy() + data.pop("_queue", None) + + return data def partial_fit(self, X, y, queue=None): """ @@ -223,6 +252,9 @@ def partial_fit(self, X, y, queue=None): policy, self._params, self._partial_result, X_table, y_table ) + self._need_to_finalize = True + return self + def finalize_fit(self, queue=None): """ Finalizes ridge regression computation and obtains coefficients @@ -238,19 +270,23 @@ def finalize_fit(self, queue=None): self : object Returns the instance itself. 
""" - module = self._get_backend("linear_model", "regression") - if queue is not None: - policy = self._get_policy(queue) - else: - policy = self._get_policy(self._queue) - result = module.finalize_train(policy, self._params, self._partial_result) - self._onedal_model = result.model + if self._need_to_finalize: + module = self._get_backend("linear_model", "regression") + if queue is not None: + policy = self._get_policy(queue) + else: + policy = self._get_policy(self._queue) + result = module.finalize_train(policy, self._params, self._partial_result) - packed_coefficients = from_table(result.model.packed_coefficients) - self.coef_, self.intercept_ = ( - packed_coefficients[:, 1:].squeeze(), - packed_coefficients[:, 0].squeeze(), - ) + self._onedal_model = result.model + + packed_coefficients = from_table(result.model.packed_coefficients) + self.coef_, self.intercept_ = ( + packed_coefficients[:, 1:].squeeze(), + packed_coefficients[:, 0].squeeze(), + ) + + self._need_to_finalize = False return self diff --git a/onedal/linear_model/linear_model.cpp b/onedal/linear_model/linear_model.cpp index 54e0972bae..450ef905ea 100644 --- a/onedal/linear_model/linear_model.cpp +++ b/onedal/linear_model/linear_model.cpp @@ -19,6 +19,9 @@ #include "onedal/common.hpp" #include "onedal/version.hpp" +#define NO_IMPORT_ARRAY // import_array called in table.cpp +#include "onedal/datatypes/data_conversion.hpp" + #include namespace py = pybind11; @@ -237,7 +240,23 @@ void init_partial_train_result(py::module_& m) { py::class_(m, "partial_train_result") .def(py::init()) .DEF_ONEDAL_PY_PROPERTY(partial_xtx, result_t) - .DEF_ONEDAL_PY_PROPERTY(partial_xty, result_t); + .DEF_ONEDAL_PY_PROPERTY(partial_xty, result_t) + .def(py::pickle( + [](const result_t& res) { + return py::make_tuple( + py::cast(convert_to_pyobject(res.get_partial_xtx())), + py::cast(convert_to_pyobject(res.get_partial_xty()))); + }, + [](py::tuple t) { + if (t.size() != 2) + throw std::runtime_error("Invalid state!"); + 
result_t res; + if (py::cast(t[0].attr("size")) != 0) + res.set_partial_xtx(convert_to_table(t[0])); + if (py::cast(t[1].attr("size")) != 0) + res.set_partial_xty(convert_to_table(t[1])); + return res; + })); } template diff --git a/onedal/linear_model/tests/test_incremental_linear_regression.py b/onedal/linear_model/tests/test_incremental_linear_regression.py index b707ceeada..4226205ffe 100644 --- a/onedal/linear_model/tests/test_incremental_linear_regression.py +++ b/onedal/linear_model/tests/test_incremental_linear_regression.py @@ -16,11 +16,12 @@ import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import assert_allclose from sklearn.datasets import load_diabetes from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split +from onedal.datatypes import from_table from onedal.linear_model import IncrementalLinearRegression from onedal.tests.utils._device_selection import get_queues @@ -43,29 +44,6 @@ def test_diabetes(queue, dtype): assert mean_squared_error(y_test, y_pred) < 2396 -@pytest.mark.parametrize("queue", get_queues()) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.skip(reason="pickling not implemented for oneDAL entities") -def test_pickle(queue, dtype): - # TODO Implement pickling for oneDAL entities - X, y = load_diabetes(return_X_y=True) - X, y = X.astype(dtype), y.astype(dtype) - model = IncrementalLinearRegression(fit_intercept=True) - model.partial_fit(X, y, queue=queue) - model.finalize_fit() - expected = model.predict(X, queue=queue) - - import pickle - - dump = pickle.dumps(model) - model2 = pickle.loads(dump) - - assert isinstance(model2, model.__class__) - result = model2.predict(X, queue=queue) - - assert_array_equal(expected, result) - - @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("num_blocks", [1, 2, 10]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -167,3 
+145,69 @@ def test_reconstruct_model(queue, dtype): tol = 1e-5 if res.dtype == np.float32 else 1e-7 assert_allclose(gtr, res, rtol=tol) + + +@pytest.mark.parametrize("queue", get_queues()) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_incremental_estimator_pickle(queue, dtype): + import pickle + + from onedal.linear_model import IncrementalLinearRegression + + inclr = IncrementalLinearRegression() + + # Check that estimator can be serialized without any data. + dump = pickle.dumps(inclr) + inclr_loaded = pickle.loads(dump) + seed = 77 + gen = np.random.default_rng(seed) + X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10)) + X = X.astype(dtype) + coef = gen.random(size=(1, 10), dtype=dtype).T + y = X @ coef + X_split = np.array_split(X, 2) + y_split = np.array_split(y, 2) + inclr.partial_fit(X_split[0], y_split[0], queue=queue) + inclr_loaded.partial_fit(X_split[0], y_split[0], queue=queue) + + # inclr.finalize_fit() + + assert inclr._need_to_finalize == True + assert inclr_loaded._need_to_finalize == True + + # Check that estimator can be serialized after partial_fit call. + dump = pickle.dumps(inclr) + inclr_loaded = pickle.loads(dump) + + partial_xtx = from_table(inclr._partial_result.partial_xtx) + partial_xtx_loaded = from_table(inclr_loaded._partial_result.partial_xtx) + assert_allclose(partial_xtx, partial_xtx_loaded) + + partial_xty = from_table(inclr._partial_result.partial_xty) + partial_xty_loaded = from_table(inclr_loaded._partial_result.partial_xty) + assert_allclose(partial_xty, partial_xty_loaded) + + assert inclr._need_to_finalize == False + # Finalize is called during serialization to make sure partial results are finalized correctly. 
+ assert inclr_loaded._need_to_finalize == False + + inclr.partial_fit(X_split[1], y_split[1], queue=queue) + inclr_loaded.partial_fit(X_split[1], y_split[1], queue=queue) + assert inclr._need_to_finalize == True + assert inclr_loaded._need_to_finalize == True + + dump = pickle.dumps(inclr_loaded) + inclr_loaded = pickle.loads(dump) + + assert inclr._need_to_finalize == True + assert inclr_loaded._need_to_finalize == False + + inclr.finalize_fit() + inclr_loaded.finalize_fit() + + # Check that finalized estimator can be serialized. + dump = pickle.dumps(inclr_loaded) + inclr_loaded = pickle.loads(dump) + + assert_allclose(inclr.coef_, inclr_loaded.coef_, atol=1e-6) + assert_allclose(inclr.intercept_, inclr_loaded.intercept_, atol=1e-6) diff --git a/onedal/linear_model/tests/test_incremental_ridge_regression.py b/onedal/linear_model/tests/test_incremental_ridge_regression.py index 471f46e4f6..259a51d4cc 100644 --- a/onedal/linear_model/tests/test_incremental_ridge_regression.py +++ b/onedal/linear_model/tests/test_incremental_ridge_regression.py @@ -24,6 +24,7 @@ from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split + from onedal.datatypes import from_table from onedal.linear_model import IncrementalRidge from onedal.tests.utils._device_selection import get_queues @@ -105,3 +106,66 @@ def test_no_intercept_results(queue, num_blocks, dtype): tol = 2e-4 if res.dtype == np.float32 else 1e-7 assert_allclose(gtr, res, rtol=tol) + + @pytest.mark.parametrize("queue", get_queues()) + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_incremental_estimator_pickle(queue, dtype): + import pickle + + model = IncrementalRidge() + + # Check that estimator can be serialized without any data. 
+ dump = pickle.dumps(model) + model_loaded = pickle.loads(dump) + seed = 77 + gen = np.random.default_rng(seed) + X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10)) + X = X.astype(dtype) + coef = gen.random(size=(1, 10), dtype=dtype).T + y = X @ coef + X_split = np.array_split(X, 2) + y_split = np.array_split(y, 2) + model.partial_fit(X_split[0], y_split[0], queue=queue) + model_loaded.partial_fit(X_split[0], y_split[0], queue=queue) + + # model.finalize_fit() + + assert model._need_to_finalize == True + assert model_loaded._need_to_finalize == True + + # Check that estimator can be serialized after partial_fit call. + dump = pickle.dumps(model) + model_loaded = pickle.loads(dump) + + partial_xtx = from_table(model._partial_result.partial_xtx) + partial_xtx_loaded = from_table(model_loaded._partial_result.partial_xtx) + assert_allclose(partial_xtx, partial_xtx_loaded) + + partial_xty = from_table(model._partial_result.partial_xty) + partial_xty_loaded = from_table(model_loaded._partial_result.partial_xty) + assert_allclose(partial_xty, partial_xty_loaded) + + assert model._need_to_finalize == False + # Finalize is called during serialization to make sure partial results are finalized correctly. + assert model_loaded._need_to_finalize == False + + model.partial_fit(X_split[1], y_split[1], queue=queue) + model_loaded.partial_fit(X_split[1], y_split[1], queue=queue) + assert model._need_to_finalize == True + assert model_loaded._need_to_finalize == True + + dump = pickle.dumps(model_loaded) + model_loaded = pickle.loads(dump) + + assert model._need_to_finalize == True + assert model_loaded._need_to_finalize == False + + model.finalize_fit() + model_loaded.finalize_fit() + + # Check that finalized estimator can be serialized. 
+ dump = pickle.dumps(model_loaded) + model_loaded = pickle.loads(dump) + + assert_allclose(model.coef_, model_loaded.coef_, atol=1e-6) + assert_allclose(model.intercept_, model_loaded.intercept_, atol=1e-6) diff --git a/onedal/spmd/linear_model/incremental_linear_model.py b/onedal/spmd/linear_model/incremental_linear_model.py index d3846bc82a..267d17172a 100644 --- a/onedal/spmd/linear_model/incremental_linear_model.py +++ b/onedal/spmd/linear_model/incremental_linear_model.py @@ -35,6 +35,7 @@ class IncrementalLinearRegression(BaseEstimatorSPMD, base_IncrementalLinearRegre """ def _reset(self): + self._need_to_finalize = False self._partial_result = super(base_IncrementalLinearRegression, self)._get_backend( "linear_model", "regression", "partial_train_result" ) @@ -95,3 +96,6 @@ def partial_fit(self, X, y, queue=None): self._partial_result = module.partial_train( policy, self._params, self._partial_result, X_table, y_table ) + + self._need_to_finalize = True + return self diff --git a/sklearnex/linear_model/incremental_linear.py b/sklearnex/linear_model/incremental_linear.py index db2d6549c0..622b15ef6c 100644 --- a/sklearnex/linear_model/incremental_linear.py +++ b/sklearnex/linear_model/incremental_linear.py @@ -103,6 +103,13 @@ class IncrementalLinearRegression( n_features_in_ : int Number of features seen during ``fit`` or ``partial_fit``. + Note + ---- + Serializing instances of this class will trigger a forced finalization of calculations. + Since finalize_fit can't be dispatched without a directly provided queue + and the dispatching policy can't be serialized, the computation is finalized + during the serialization call and the policy is not saved in the serialized data. 
+ Examples -------- >>> import numpy as np diff --git a/sklearnex/linear_model/incremental_ridge.py b/sklearnex/linear_model/incremental_ridge.py index e750491ef9..39097d3e8d 100644 --- a/sklearnex/linear_model/incremental_ridge.py +++ b/sklearnex/linear_model/incremental_ridge.py @@ -96,6 +96,13 @@ class IncrementalRidge(MultiOutputMixin, RegressorMixin, BaseEstimator): batch_size_ : int Inferred batch size from ``batch_size``. + + Note + ---- + Serializing instances of this class will trigger a forced finalization of calculations. + Since finalize_fit can't be dispatched without a directly provided queue + and the dispatching policy can't be serialized, the computation is finalized + during the serialization call and the policy is not saved in the serialized data. """ _onedal_incremental_ridge = staticmethod(onedal_IncrementalRidge) diff --git a/sklearnex/linear_model/tests/test_incremental_linear.py b/sklearnex/linear_model/tests/test_incremental_linear.py index e4ab649daf..e27e620ce6 100644 --- a/sklearnex/linear_model/tests/test_incremental_linear.py +++ b/sklearnex/linear_model/tests/test_incremental_linear.py @@ -205,3 +205,63 @@ def test_sklearnex_partial_fit_on_random_data( y_pred = inclin.predict(X_test_df) assert_allclose(expected_y_pred, _as_numpy(y_pred), atol=tol) + + +@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) +@pytest.mark.parametrize("fit_intercept", [True, False]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_sklearnex_incremental_estimatior_pickle(dataframe, queue, fit_intercept, dtype): + import pickle + + from sklearnex.linear_model import IncrementalLinearRegression + + inclin = IncrementalLinearRegression() + + # Check that estimator can be serialized without any data. 
+ dump = pickle.dumps(inclin) + inclin_loaded = pickle.loads(dump) + + seed = 77 + gen = np.random.default_rng(seed) + intercept = gen.random(size=1, dtype=dtype) + coef = gen.random(size=(1, 10), dtype=dtype).T + X = gen.uniform(low=-0.3, high=+0.7, size=(30, 10)) + X = X.astype(dtype) + if fit_intercept: + y = X @ coef + intercept[np.newaxis, :] + else: + y = X @ coef + X_split = np.array_split(X, 2) + y_split = np.array_split(y, 2) + X_split_df = _convert_to_dataframe(X_split[0], sycl_queue=queue, target_df=dataframe) + y_split_df = _convert_to_dataframe(y_split[0], sycl_queue=queue, target_df=dataframe) + inclin.partial_fit(X_split_df, y_split_df) + inclin_loaded.partial_fit(X_split_df, y_split_df) + + # Check that estimator can be serialized after partial_fit call. + dump = pickle.dumps(inclin_loaded) + inclin_loaded = pickle.loads(dump) + + assert inclin.batch_size == inclin_loaded.batch_size + assert inclin.n_features_in_ == inclin_loaded.n_features_in_ + assert inclin.n_samples_seen_ == inclin_loaded.n_samples_seen_ + if hasattr(inclin, "_parameter_constraints"): + assert inclin._parameter_constraints == inclin_loaded._parameter_constraints + assert inclin.n_jobs == inclin_loaded.n_jobs + + X_split_df = _convert_to_dataframe(X_split[1], sycl_queue=queue, target_df=dataframe) + y_split_df = _convert_to_dataframe(y_split[1], sycl_queue=queue, target_df=dataframe) + inclin.partial_fit(X_split_df, y_split_df) + inclin_loaded.partial_fit(X_split_df, y_split_df) + dump = pickle.dumps(inclin) + inclin_loaded = pickle.loads(dump) + + assert_allclose(inclin.coef_, inclin_loaded.coef_, atol=1e-6) + assert_allclose(inclin.intercept_, inclin_loaded.intercept_, atol=1e-6) + + # Check that finalized estimator can be serialized. 
+ dump = pickle.dumps(inclin_loaded) + inclin_loaded = pickle.loads(dump) + + assert_allclose(inclin.coef_, inclin_loaded.coef_, atol=1e-6) + assert_allclose(inclin.intercept_, inclin_loaded.intercept_, atol=1e-6) diff --git a/sklearnex/linear_model/tests/test_incremental_ridge.py b/sklearnex/linear_model/tests/test_incremental_ridge.py index adcd5349ed..6829a5303d 100644 --- a/sklearnex/linear_model/tests/test_incremental_ridge.py +++ b/sklearnex/linear_model/tests/test_incremental_ridge.py @@ -151,3 +151,64 @@ def test_inc_ridge_predict_after_fit(dataframe, queue, fit_intercept): y_pred_manual += intercept_manual assert_allclose(_as_numpy(y_pred), y_pred_manual, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) +@pytest.mark.parametrize("fit_intercept", [True, False]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_sklearnex_incremental_estimatior_pickle(dataframe, queue, fit_intercept, dtype): + import pickle + + from sklearnex.linear_model import IncrementalRidge + + inc_ridge = IncrementalRidge() + + # Check that estimator can be serialized without any data. + dump = pickle.dumps(inc_ridge) + inc_ridge_loaded = pickle.loads(dump) + + seed = 77 + gen = np.random.default_rng(seed) + intercept = gen.random(size=1, dtype=dtype) + coef = gen.random(size=(1, 10), dtype=dtype).T + X = gen.uniform(low=-0.3, high=+0.7, size=(30, 10)) + X = X.astype(dtype) + if fit_intercept: + y = X @ coef + intercept[np.newaxis, :] + else: + y = X @ coef + X_split = np.array_split(X, 2) + y_split = np.array_split(y, 2) + X_split_df = _convert_to_dataframe(X_split[0], sycl_queue=queue, target_df=dataframe) + y_split_df = _convert_to_dataframe(y_split[0], sycl_queue=queue, target_df=dataframe) + inc_ridge.partial_fit(X_split_df, y_split_df) + inc_ridge_loaded.partial_fit(X_split_df, y_split_df) + + # Check that estimator can be serialized after partial_fit call. 
+ dump = pickle.dumps(inc_ridge_loaded) + inc_ridge_loaded = pickle.loads(dump) + + assert inc_ridge.batch_size == inc_ridge_loaded.batch_size + assert inc_ridge.n_features_in_ == inc_ridge_loaded.n_features_in_ + assert inc_ridge.n_samples_seen_ == inc_ridge_loaded.n_samples_seen_ + assert inc_ridge.alpha == inc_ridge_loaded.alpha + if hasattr(inc_ridge, "_parameter_constraints"): + assert inc_ridge._parameter_constraints == inc_ridge_loaded._parameter_constraints + assert inc_ridge.n_jobs == inc_ridge_loaded.n_jobs + + X_split_df = _convert_to_dataframe(X_split[1], sycl_queue=queue, target_df=dataframe) + y_split_df = _convert_to_dataframe(y_split[1], sycl_queue=queue, target_df=dataframe) + inc_ridge.partial_fit(X_split_df, y_split_df) + inc_ridge_loaded.partial_fit(X_split_df, y_split_df) + dump = pickle.dumps(inc_ridge) + inc_ridge_loaded = pickle.loads(dump) + + assert_allclose(inc_ridge.coef_, inc_ridge_loaded.coef_, atol=1e-6) + assert_allclose(inc_ridge.intercept_, inc_ridge_loaded.intercept_, atol=1e-6) + + # Check that finalized estimator can be serialized. + dump = pickle.dumps(inc_ridge_loaded) + inc_ridge_loaded = pickle.loads(dump) + + assert_allclose(inc_ridge.coef_, inc_ridge_loaded.coef_, atol=1e-6) + assert_allclose(inc_ridge.intercept_, inc_ridge_loaded.intercept_, atol=1e-6)