diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py index 2d52a545cf..163c23ff16 100644 --- a/sklearnex/tests/test_memory_usage.py +++ b/sklearnex/tests/test_memory_usage.py @@ -32,6 +32,8 @@ from onedal import _is_dpc_backend from onedal.tests.utils._dataframes_support import ( _convert_to_dataframe, + dpctl_available, + dpnp_available, get_dataframes_and_queues, ) from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available @@ -167,10 +169,13 @@ def take(x, index, axis=0, queue=None): return xp.take( x, xp.asarray(index, usm_type="device", sycl_queue=x.sycl_queue), axis=axis ) + # TODO: + # re-impl _is_numpy_namespace + # elif array_api and not isinstance(x, np.ndarray): elif array_api: return xp.take(x, xp.asarray(index, device=x.device), axis=axis) else: - return x.take(index, axis=axis) + return xp.take(x, xp.asarray(index), axis=axis) def split_train_inference(kf, x, y, estimator, queue=None): @@ -288,7 +293,7 @@ def _kfold_function_template( @pytest.mark.parametrize("order", ["F", "C"]) @pytest.mark.parametrize( - "dataframe,queue", get_dataframes_and_queues("numpy,pandas,dpctl", "cpu") + "dataframe,queue", get_dataframes_and_queues("numpy,pandas,dpctl,array_api", "cpu") ) @pytest.mark.parametrize("estimator", CPU_ESTIMATORS.keys()) @pytest.mark.parametrize("data_shape", data_shapes) diff --git a/sklearnex/utils/_array_api.py b/sklearnex/utils/_array_api.py index bc30be5021..8e731bc425 100644 --- a/sklearnex/utils/_array_api.py +++ b/sklearnex/utils/_array_api.py @@ -16,11 +16,20 @@ """Tools to support array_api.""" +try: + import array_api_compat + + array_api_available = True +except ImportError: + array_api_available = False + import numpy as np from daal4py.sklearn._utils import sklearn_check_version from onedal.utils._array_api import _get_sycl_namespace +from .._config import get_config + if sklearn_check_version("1.2"): from sklearn.utils._array_api import get_namespace as sklearn_get_namespace @@ -76,7 +85,19 @@ def get_namespace(*arrays): if sycl_type: return xp, is_array_api_compliant + # TODO: + # correct condition. + # scikit-learn's get_namespace require config_contex(array_api_dispatch=True). + # elif sklearn_check_version("1.2") and get_config()["array_api_dispatch"]: elif sklearn_check_version("1.2"): - return sklearn_get_namespace(*arrays) + namespace, is_array_api_compliant = sklearn_get_namespace(*arrays) + return namespace, is_array_api_compliant + # TODO: + # on PR 2096. + # elif array_api_available: + # namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True + # return namespace, is_array_api_compliant + # TODO: + # should be removed. else: return np, False