Skip to content

Commit 164435d

Browse files
committed
minor changes based on uxlfoundation#2206, suggestions
1 parent 8fca003 commit 164435d

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

sklearnex/utils/tests/test_validation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ==============================================================================
2-
# Copyright 2024 Intel Corporation
2+
# Copyright 2024 UXL Foundation Contributors
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.

sklearnex/utils/validation.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ def _onedal_supported_format(X, xp=None):
4646
# _onedal_supported_format is therefore conservative in verifying attributes and
4747
# does not support array_api. This will block onedal_assert_all_finite from being
4848
# used for array_api inputs but will allow dpnp ndarrays and dpctl tensors.
49-
return X.dtype in [xp.float32, xp.float64] and hasattr(X, "flags")
49+
# only check contiguous arrays to prevent unnecessary copying of data, even if
50+
# non-contiguous arrays can now be converted to oneDAL tables.
51+
return (
52+
X.dtype in [xp.float32, xp.float64]
53+
and hasattr(X, "flags")
54+
and (X.flags["C_CONTIGUOUS"] or X.flags["F_CONTIGUOUS"])
55+
)
5056

5157
else:
5258
from daal4py.utils.validation import _assert_all_finite as _onedal_assert_all_finite
@@ -108,14 +114,37 @@ def validate_data(
108114
y=y,
109115
**kwargs,
110116
)
117+
118+
check_x = not isinstance(X, str) or X != "no_validation"
119+
check_y = not (y is None or isinstance(y, str) and y == "no_validation")
120+
111121
if ensure_all_finite:
112122
# run local finite check
113123
allow_nan = ensure_all_finite == "allow-nan"
114124
arg = iter(out if isinstance(out, tuple) else (out,))
115-
if not isinstance(X, str) or X != "no_validation":
125+
if check_x:
116126
assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
117-
if not (y is None or isinstance(y, str) and y == "no_validation"):
127+
if check_y:
118128
assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")
129+
130+
if check_y and "dtype" in kwargs:
131+
# validate_data does not do full dtype conversions, as it uses check_X_y
132+
# oneDAL can make tables from [int32, float32, float64], requiring
133+
# a dtype check and conversion. This will query the array_namespace and
134+
# convert y as necessary. This is done after assert_all_finite, because
135+
# int y arrays do not need a finite check, and this will lead to a speedup
136+
# in comparison to sklearn
137+
dtype = kwargs["dtype"]
138+
if not isinstance(dtype, (tuple, list)):
139+
dtype = tuple(dtype)
140+
141+
outx, outy = out if check_x else (None, out)
142+
if outy.dtype not in dtype:
143+
yp, _ = get_namespace(outy)
144+
# use asarray rather than astype because of numpy support
145+
outy = yp.asarray(outy, dtype=dtype[0])
146+
out = (outx, outy) if check_x else outy
147+
119148
return out
120149

121150

0 commit comments

Comments
 (0)