|
22 | 22 | from onedal import _backend, _is_dpc_backend
|
23 | 23 |
|
24 | 24 | from ..utils import _is_csr
|
25 |
| -from ..utils._dpep_helpers import is_dpctl_available |
26 | 25 |
|
27 |
| -dpctl_available = is_dpctl_available("0.14") |
28 | 26 |
|
29 |
| -if dpctl_available: |
30 |
| - import dpctl |
31 |
| - import dpctl.tensor as dpt |
32 |
| - |
33 |
| - |
34 |
| -def _apply_and_pass(func, *args): |
| 27 | +def _apply_and_pass(func, *args, **kwargs): |
35 | 28 | if len(args) == 1:
|
36 |
| - return func(args[0]) |
37 |
| - return tuple(map(func, args)) |
| 29 | + return func(args[0], **kwargs) |
| 30 | + return tuple(map(lambda arg: func(arg, **kwargs), args)) |
38 | 31 |
|
39 | 32 |
|
40 |
| -def from_table(*args): |
41 |
| - return _apply_and_pass(_backend.from_table, *args) |
| 33 | +if _is_dpc_backend: |
42 | 34 |
|
| 35 | + from ..utils._dpep_helpers import dpctl_available, dpnp_available |
43 | 36 |
|
44 |
| -def convert_one_to_table(arg): |
45 | 37 | if dpctl_available:
|
46 |
| - if isinstance(arg, dpt.usm_ndarray): |
47 |
| - return _backend.dpctl_to_table(arg) |
48 |
| - |
49 |
| - if not _is_csr(arg): |
50 |
| - arg = make2d(arg) |
51 |
| - return _backend.to_table(arg) |
52 |
| - |
| 38 | + import dpctl.tensor as dpt |
53 | 39 |
|
54 |
| -def to_table(*args): |
55 |
| - return _apply_and_pass(convert_one_to_table, *args) |
| 40 | + if dpnp_available: |
| 41 | + import dpnp |
56 | 42 |
|
57 |
| - |
58 |
| -if _is_dpc_backend: |
59 | 43 | from ..common._policy import _HostInteropPolicy
|
60 | 44 |
|
61 | 45 | def _convert_to_supported(policy, *data):
|
@@ -85,10 +69,78 @@ def convert_or_pass(x):
|
85 | 69 |
|
86 | 70 | return _apply_and_pass(func, *data)
|
87 | 71 |
|
| 72 | + def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None): |
| 73 | + # Currently only `__sycl_usm_array_interface__` protocol used to |
| 74 | + # convert into dpnp/dpctl tensors. |
| 75 | + if sua_iface: |
| 76 | + if ( |
| 77 | + sycl_queue |
| 78 | + and sycl_queue.sycl_device.is_cpu |
| 79 | + and table.__sycl_usm_array_interface__["syclobj"] is None |
| 80 | + ): |
| 81 | + # oneDAL returns tables with None sycl queue for CPU sycl queue inputs. |
| 82 | + # This workaround is necessary for the functional preservation |
| 83 | + # of the compute-follows-data execution. |
| 84 | + # Host tables first converted into numpy.narrays and then to array from xp |
| 85 | + # namespace. |
| 86 | + return xp.asarray( |
| 87 | + _backend.from_table(table), usm_type="device", sycl_queue=sycl_queue |
| 88 | + ) |
| 89 | + else: |
| 90 | + xp_name = xp.__name__ |
| 91 | + if dpnp_available and xp_name == "dpnp": |
| 92 | + # By default DPNP ndarray created with a copy. |
| 93 | + # TODO: |
| 94 | + # investigate why dpnp.array(table, copy=False) doesn't work. |
| 95 | + # Work around with using dpctl.tensor.asarray. |
| 96 | + return dpnp.array(dpt.asarray(table), copy=False) |
| 97 | + else: |
| 98 | + return xp.asarray(table) |
| 99 | + return _backend.from_table(table) |
| 100 | + |
| 101 | + def convert_one_to_table(arg, sua_iface=None): |
| 102 | + # Note: currently only oneDAL homogen tables are supported and the |
| 103 | + # contiuginity of the input array should be checked in advance. |
| 104 | + if sua_iface: |
| 105 | + return _backend.sua_iface_to_table(arg) |
| 106 | + |
| 107 | + if not _is_csr(arg): |
| 108 | + arg = make2d(arg) |
| 109 | + return _backend.to_table(arg) |
| 110 | + |
88 | 111 | else:
|
89 | 112 |
|
90 | 113 | def _convert_to_supported(policy, *data):
|
91 | 114 | def func(x):
|
92 | 115 | return x
|
93 | 116 |
|
94 | 117 | return _apply_and_pass(func, *data)
|
| 118 | + |
| 119 | + def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None): |
| 120 | + # Currently only `__sycl_usm_array_interface__` protocol used to |
| 121 | + # convert into dpnp/dpctl tensors. |
| 122 | + if sua_iface: |
| 123 | + raise RuntimeError( |
| 124 | + "SYCL usm array conversion from table requires the DPC backend" |
| 125 | + ) |
| 126 | + return _backend.from_table(table) |
| 127 | + |
| 128 | + def convert_one_to_table(arg, sua_iface=None): |
| 129 | + if sua_iface: |
| 130 | + raise RuntimeError( |
| 131 | + "SYCL usm array conversion to table requires the DPC backend" |
| 132 | + ) |
| 133 | + |
| 134 | + if not _is_csr(arg): |
| 135 | + arg = make2d(arg) |
| 136 | + return _backend.to_table(arg) |
| 137 | + |
| 138 | + |
| 139 | +def from_table(*args, sycl_queue=None, sua_iface=None, xp=None): |
| 140 | + return _apply_and_pass( |
| 141 | + convert_one_from_table, *args, sycl_queue=sycl_queue, sua_iface=sua_iface, xp=xp |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +def to_table(*args, sua_iface=None): |
| 146 | + return _apply_and_pass(convert_one_to_table, *args, sua_iface=sua_iface) |
0 commit comments