
Commit c0eb5ad

FIX: update functional support fallback logic for DPNP/DPCTL ndarray inputs (uxlfoundation#2113)
* FIX: update functional support fallback logic a little bit
* Host NumPy copies of the input data will be used for the fallback cases, since stock scikit-learn doesn't support DPCTL usm_ndarray and DPNP ndarray
* Added a clarifying comment
* Enhanced patch message for data transfer
1 parent 0809a3e commit c0eb5ad
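For context, a minimal illustration of the fallback behavior described above (a sketch, not the project code; it assumes dpnp is installed, and the helper name _to_host_numpy is invented for this example):

import dpnp
import numpy as np

def _to_host_numpy(x):
    # USM-backed inputs (dpnp.ndarray, dpctl usm_ndarray) expose
    # __sycl_usm_array_interface__; stock scikit-learn cannot consume them,
    # so the fallback path works on a host NumPy copy instead.
    if hasattr(x, "__sycl_usm_array_interface__"):
        return dpnp.asnumpy(x)
    return x

x_device = dpnp.arange(6, dtype=dpnp.float32).reshape(3, 2)
x_host = _to_host_numpy(x_device)  # host numpy.ndarray copy
assert isinstance(x_host, np.ndarray)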

3 files changed: +27 -12 lines

onedal/_device_offload.py

+3 -3
@@ -133,7 +133,7 @@ def _transfer_to_host(queue, *data):
             raise RuntimeError("Input data shall be located on single target device")

         host_data.append(item)
-    return queue, host_data
+    return has_usm_data, queue, host_data


 def _get_global_queue():
@@ -150,8 +150,8 @@ def _get_global_queue():

 def _get_host_inputs(*args, **kwargs):
     q = _get_global_queue()
-    q, hostargs = _transfer_to_host(q, *args)
-    q, hostvalues = _transfer_to_host(q, *kwargs.values())
+    _, q, hostargs = _transfer_to_host(q, *args)
+    _, q, hostvalues = _transfer_to_host(q, *kwargs.values())
     hostkwargs = dict(zip(kwargs.keys(), hostvalues))
     return q, hostargs, hostkwargs
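To make the new return contract concrete, here is a toy stand-in for _transfer_to_host (the detection logic is simplified; only the (has_usm_data, queue, host_data) return shape mirrors the change above):

def _transfer_to_host(queue, *data):
    # Toy stand-in: report whether any input is USM-backed and pass data through.
    has_usm_data = any(
        hasattr(item, "__sycl_usm_array_interface__") for item in data
    )
    return has_usm_data, queue, list(data)

# Callers such as _get_host_inputs that do not need the flag simply discard it:
_, q, hostargs = _transfer_to_host(None, [1.0, 2.0], [3.0, 4.0])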

sklearnex/_device_offload.py

+14 -5
@@ -63,25 +63,34 @@ def _get_backend(obj, queue, method_name, *data):

 def dispatch(obj, method_name, branches, *args, **kwargs):
     q = _get_global_queue()
-    q, hostargs = _transfer_to_host(q, *args)
-    q, hostvalues = _transfer_to_host(q, *kwargs.values())
+    has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
+    has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
     hostkwargs = dict(zip(kwargs.keys(), hostvalues))

     backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)
-
+    has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
     if backend == "onedal":
-        patching_status.write_log(queue=q)
+        # Host args are only used before the onedal backend call.
+        # Data will be offloaded to the device when the onedal backend is called.
+        patching_status.write_log(queue=q, transferred_to_host=False)
         return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
     if backend == "sklearn":
         if (
             "array_api_dispatch" in get_config()
             and get_config()["array_api_dispatch"]
             and "array_api_support" in obj._get_tags()
             and obj._get_tags()["array_api_support"]
+            and not has_usm_data
         ):
+            # USM ndarrays are also excluded from the array API fallback. Currently,
+            # DPNP.ndarray is not compliant with the array API standard, and DPCTL
+            # usm_ndarray is compliant except for the linalg module, so there is no
+            # guarantee that stock scikit-learn will work with such input data. This
+            # condition will be updated once DPNP.ndarray and DPCTL usm_ndarray are
+            # enabled for conformance testing and supported in the fallback cases.
             # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
             # then raw inputs are used for the fallback.
-            patching_status.write_log()
+            patching_status.write_log(transferred_to_host=False)
             return branches[backend](obj, *args, **kwargs)
         else:
             patching_status.write_log()
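Condensed, the new fallback decision looks like this (an illustrative sketch; the config and tags dictionaries stand in for scikit-learn's get_config() and obj._get_tags()):

def _use_raw_inputs_for_sklearn_fallback(config, tags, has_usm_data):
    # Raw inputs go to stock scikit-learn only when array_api_dispatch is on,
    # the estimator advertises array API support, and no input is USM-backed
    # (dpnp/dpctl); otherwise host NumPy copies are used for the fallback.
    return (
        config.get("array_api_dispatch", False)
        and tags.get("array_api_support", False)
        and not has_usm_data
    )

print(_use_raw_inputs_for_sklearn_fallback(
    {"array_api_dispatch": True}, {"array_api_support": True}, has_usm_data=True
))  # False: USM inputs force the host-copy fallback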

sklearnex/_utils.py

+10 -4
@@ -29,10 +29,10 @@ class PatchingConditionsChain(daal4py_PatchingConditionsChain):
     def get_status(self):
         return self.patching_is_enabled

-    def write_log(self, queue=None):
+    def write_log(self, queue=None, transferred_to_host=True):
         if self.patching_is_enabled:
             self.logger.info(
-                f"{self.scope_name}: {get_patch_message('onedal', queue=queue)}"
+                f"{self.scope_name}: {get_patch_message('onedal', queue=queue, transferred_to_host=transferred_to_host)}"
             )
         else:
             self.logger.debug(
@@ -43,7 +43,9 @@ def write_log(self, queue=None):
                 self.logger.debug(
                     f"{self.scope_name}: patching failed with cause - {message}"
                 )
-            self.logger.info(f"{self.scope_name}: {get_patch_message('sklearn')}")
+            self.logger.info(
+                f"{self.scope_name}: {get_patch_message('sklearn', transferred_to_host=transferred_to_host)}"
+            )


 def set_sklearn_ex_verbose():
@@ -66,7 +68,7 @@ def set_sklearn_ex_verbose():
         )


-def get_patch_message(s, queue=None):
+def get_patch_message(s, queue=None, transferred_to_host=True):
     if s == "onedal":
         message = "running accelerated version on "
         if queue is not None:
@@ -87,6 +89,10 @@ def get_patch_message(s, queue=None):
             f"Invalid input - expected one of 'onedal','sklearn',"
             f" 'sklearn_after_onedal', got {s}"
         )
+    if transferred_to_host:
+        message += (
+            ". All input data transferred to host for further backend computations."
+        )
     return message

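A small usage sketch of the extended logging (a simplified stand-in for get_patch_message; the message wording here is illustrative, only the transferred_to_host suffix mirrors the change above):

def get_patch_message(s, queue=None, transferred_to_host=True):
    # Simplified stand-in: only the suffix behavior added in this commit is modeled.
    message = f"running {s} backend"
    if transferred_to_host:
        message += (
            ". All input data transferred to host for further backend computations."
        )
    return message

print(get_patch_message("sklearn"))                             # host copies were made
print(get_patch_message("sklearn", transferred_to_host=False))  # raw inputs kept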