@@ -46,7 +46,13 @@ def _onedal_supported_format(X, xp=None):
46
46
# _onedal_supported_format is therefore conservative in verifying attributes and
47
47
# does not support array_api. This will block onedal_assert_all_finite from being
48
48
# used for array_api inputs but will allow dpnp ndarrays and dpctl tensors.
49
- return X .dtype in [xp .float32 , xp .float64 ] and hasattr (X , "flags" )
49
+ # only check contiguous arrays to prevent unnecessary copying of data, even if
50
+ # non-contiguous arrays can now be converted to oneDAL tables.
51
+ return (
52
+ X .dtype in [xp .float32 , xp .float64 ]
53
+ and hasattr (X , "flags" )
54
+ and (X .flags ["C_CONTIGUOUS" ] or X .flags ["F_CONTIGUOUS" ])
55
+ )
50
56
51
57
else :
52
58
from daal4py .utils .validation import _assert_all_finite as _onedal_assert_all_finite
@@ -108,14 +114,37 @@ def validate_data(
108
114
y = y ,
109
115
** kwargs ,
110
116
)
117
+
118
+ check_x = not isinstance (X , str ) or X != "no_validation"
119
+ check_y = not (y is None or isinstance (y , str ) and y == "no_validation" )
120
+
111
121
if ensure_all_finite :
112
122
# run local finite check
113
123
allow_nan = ensure_all_finite == "allow-nan"
114
124
arg = iter (out if isinstance (out , tuple ) else (out ,))
115
- if not isinstance ( X , str ) or X != "no_validation" :
125
+ if check_x :
116
126
assert_all_finite (next (arg ), allow_nan = allow_nan , input_name = "X" )
117
- if not ( y is None or isinstance ( y , str ) and y == "no_validation" ) :
127
+ if check_y :
118
128
assert_all_finite (next (arg ), allow_nan = allow_nan , input_name = "y" )
129
+
130
+ if check_y and "dtype" in kwargs :
131
+ # validate_data does not do full dtype conversions, as it uses check_X_y
132
+ # oneDAL can make tables from [int32, float32, float64], requiring
133
+ # a dtype check and conversion. This will query the array_namespace and
134
+ # convert y as necessary. This is done after assert_all_finite, because
135
+ # int y arrays do not need a finite check, and this will lead to a speedup
136
+ # in comparison to sklearn
137
+ dtype = kwargs ["dtype" ]
138
+ if not isinstance (dtype , (tuple , list )):
139
+ dtype = tuple (dtype )
140
+
141
+ outx , outy = out if check_x else (None , out )
142
+ if outy .dtype not in dtype :
143
+ yp , _ = get_namespace (outy )
144
+ # use asarray rather than astype because of numpy support
145
+ outy = yp .asarray (outy , dtype = dtype [0 ])
146
+ out = (outx , outy ) if check_x else outy
147
+
119
148
return out
120
149
121
150
0 commit comments