Skip to content

Commit 95bd1ea

Browse files
icfaustethanglaser
andauthored
[enhancement] add sklearnex version of validate_data, _check_sample_weight (#2177)
* add finiteness_checker pybind11 bindings * added finiteness checker * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Rename finiteness_checker.cpp to finiteness_checker.cpp * Update finiteness_checker.cpp * add next step * follow conventions * make xtable explicit * remove comment * Update validation.py * Update __init__.py * Update validation.py * Update __init__.py * Update __init__.py * Update validation.py * Update _data_conversion.py * Update _data_conversion.py * Update policy_common.cpp * Update policy_common.cpp * Update _policy.py * Update policy_common.cpp * Rename finiteness_checker.cpp to finiteness_checker.cpp * Create finiteness_checker.py * Update validation.py * Update __init__.py * attempt at fixing circular imports again * fix isort * remove __init__ changes * last move * Update policy_common.cpp * Update policy_common.cpp * Update policy_common.cpp * Update policy_common.cpp * Update validation.py * add testing * isort * attempt to fix module error * add fptype * fix typo * Update validation.py * remove sua_ifcae from to_table * isort and black * Update test_memory_usage.py * format * Update _data_conversion.py * Update _data_conversion.py * Update test_validation.py * remove unnecessary code * make reviewer changes * make dtype check change * add sparse testing * try again * try again * try again * temporary commit * first attempt * missing change? * modify DummyEstimator for testing * generalize DummyEstimator * switch test * further testing changes * add initial validate_data test, will be refactored * fixes for CI * Update validation.py * Update validation.py * Update test_memory_usage.py * Update base.py * Update base.py * improve tests * fix logic * fix logic * fix logic again * rename file * Revert "rename file" This reverts commit 8d47744. * remove duplication * fix imports * Rename test_finite.py to test_validation.py * Revert "Rename test_finite.py to test_validation.py" This reverts commit ee799f6. * updates * Update validation.py * fixes for some test failures * fix text * fixes for some failures * make consistent * fix bad logic * fix in string * attempt tp see if dataframe conversion is causing the issue * fix iter problem * fix testing issues * formatting * revert change * fixes for pandas * there is a slowdown with pandas that needs to be solved * swap to transpose for speed * more clarity * add _check_sample_weight * add more testing' * rename * remove unnecessary imports * fix test slowness * focus get_dataframes_and_queues * put config_context around * Update test_validation.py * Update base.py * Update test_validation.py * generalize regex * add fixes for sklearn 1.0 and input_name * fixes for test failures * Update validation.py * Update test_validation.py * Update validation.py * formattintg * make suggested changes * follow changes made in #2126 * fix future device problem * Update validation.py * minor changes based on #2206, suggestions * remove xp as keyword * only_non_negative -> ensure_non_negative * add commentary * formatting * address changes * Update test_validation.py * Update base.py * Update test_validation.py * Update sklearnex/utils/validation.py Co-authored-by: ethanglaser <[email protected]> --------- Co-authored-by: ethanglaser <[email protected]>
1 parent 5d8d9bb commit 95bd1ea

File tree

7 files changed

+487
-130
lines changed

7 files changed

+487
-130
lines changed

sklearnex/tests/test_memory_usage.py

+7-38
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,14 @@
3535
get_dataframes_and_queues,
3636
)
3737
from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available
38-
from onedal.utils._array_api import _get_sycl_namespace
3938
from onedal.utils._dpep_helpers import dpctl_available, dpnp_available
4039
from sklearnex import config_context
41-
from sklearnex.tests.utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES
40+
from sklearnex.tests.utils import (
41+
PATCHED_FUNCTIONS,
42+
PATCHED_MODELS,
43+
SPECIAL_INSTANCES,
44+
DummyEstimator,
45+
)
4246
from sklearnex.utils._array_api import get_namespace
4347

4448
if dpctl_available:
@@ -131,41 +135,6 @@ def gen_functions(functions):
131135
ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray}
132136

133137

134-
if _is_dpc_backend:
135-
136-
from sklearn.utils.validation import check_is_fitted
137-
138-
from onedal.datatypes import from_table, to_table
139-
140-
class DummyEstimatorWithTableConversions(BaseEstimator):
141-
142-
def fit(self, X, y=None):
143-
sua_iface, xp, _ = _get_sycl_namespace(X)
144-
X_table = to_table(X)
145-
y_table = to_table(y)
146-
# The presence of the fitted attributes (ending with a trailing
147-
# underscore) is required for the correct check. The cleanup of
148-
# the memory will occur at the estimator instance deletion.
149-
self.x_attr_ = from_table(
150-
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
151-
)
152-
self.y_attr_ = from_table(
153-
y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
154-
)
155-
return self
156-
157-
def predict(self, X):
158-
# Checks if the estimator is fitted by verifying the presence of
159-
# fitted attributes (ending with a trailing underscore).
160-
check_is_fitted(self)
161-
sua_iface, xp, _ = _get_sycl_namespace(X)
162-
X_table = to_table(X)
163-
returned_X = from_table(
164-
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
165-
)
166-
return returned_X
167-
168-
169138
def gen_clsf_data(n_samples, n_features, dtype=None):
170139
data, label = make_classification(
171140
n_classes=2, n_samples=n_samples, n_features=n_features, random_state=777
@@ -369,7 +338,7 @@ def test_table_conversions_memory_leaks(dataframe, queue, order, data_shape, dty
369338
pytest.skip("SYCL device memory leak check requires the level zero sysman")
370339

371340
_kfold_function_template(
372-
DummyEstimatorWithTableConversions,
341+
DummyEstimator,
373342
dataframe,
374343
data_shape,
375344
queue,

sklearnex/tests/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
SPECIAL_INSTANCES,
2222
UNPATCHED_FUNCTIONS,
2323
UNPATCHED_MODELS,
24+
DummyEstimator,
2425
_get_processor_info,
2526
call_method,
2627
gen_dataset,
@@ -39,6 +40,7 @@
3940
"gen_models_info",
4041
"gen_dataset",
4142
"sklearn_clone_dict",
43+
"DummyEstimator",
4244
]
4345

4446
_IS_INTEL = "GenuineIntel" in _get_processor_info()

sklearnex/tests/utils/base.py

+44
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@
3232
)
3333
from sklearn.datasets import load_diabetes, load_iris
3434
from sklearn.neighbors._base import KNeighborsMixin
35+
from sklearn.utils.validation import check_is_fitted
3536

37+
from onedal.datatypes import from_table, to_table
3638
from onedal.tests.utils._dataframes_support import _convert_to_dataframe
39+
from onedal.utils._array_api import _get_sycl_namespace
3740
from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
3841
from sklearnex.basic_statistics import BasicStatistics, IncrementalBasicStatistics
3942
from sklearnex.linear_model import LogisticRegression
@@ -369,3 +372,44 @@ def _get_processor_info():
369372
)
370373

371374
return proc
375+
376+
377+
class DummyEstimator(BaseEstimator):
378+
379+
def fit(self, X, y=None):
380+
sua_iface, xp, _ = _get_sycl_namespace(X)
381+
X_table = to_table(X)
382+
y_table = to_table(y)
383+
# The presence of the fitted attributes (ending with a trailing
384+
# underscore) is required for the correct check. The cleanup of
385+
# the memory will occur at the estimator instance deletion.
386+
if sua_iface:
387+
self.x_attr_ = from_table(
388+
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
389+
)
390+
self.y_attr_ = from_table(
391+
y_table,
392+
sua_iface=sua_iface,
393+
sycl_queue=X.sycl_queue if y is None else y.sycl_queue,
394+
xp=xp,
395+
)
396+
else:
397+
self.x_attr = from_table(X_table)
398+
self.y_attr = from_table(y_table)
399+
400+
return self
401+
402+
def predict(self, X):
403+
# Checks if the estimator is fitted by verifying the presence of
404+
# fitted attributes (ending with a trailing underscore).
405+
check_is_fitted(self)
406+
sua_iface, xp, _ = _get_sycl_namespace(X)
407+
X_table = to_table(X)
408+
if sua_iface:
409+
returned_X = from_table(
410+
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
411+
)
412+
else:
413+
returned_X = from_table(X_table)
414+
415+
return returned_X

sklearnex/utils/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17-
from .validation import _assert_all_finite
17+
from .validation import assert_all_finite
1818

19-
__all__ = ["_assert_all_finite"]
19+
__all__ = ["assert_all_finite"]

sklearnex/utils/tests/test_finite.py

-89
This file was deleted.

0 commit comments

Comments
 (0)