
Commit 744c848

cuml-cpu: fix import issues, enable conda import tests (#6400)
Fixes #6403

This project publishes a conda package, `cuml-cpu`, which does what it sounds like... allows the use of cuML on systems without a GPU. This proposes some updates to packaging for `cuml-cpu`:

* fixes importing in CPU-only environments (broken in 25.04, see #6403)
* enables import tests during conda builds, to reduce the risk of such issues going undetected in the future

## Notes for Reviewers

### Why all these changes in Python code?

See some of the challenges I faced documented in #6400 (comment). In short, when cuML is installed via `cuml-cpu`, `import cuml` breaks at import time whenever modules imported with `cuml.internals.safe_imports.gpu_only_import()` are used in any of the following ways:

* type hints
* decorators
* any other module-level direct use

Like this:

```text
cuml.internals.safe_imports.UnavailableError: cudf is not installed in non GPU-enabled installations
```

### How long has this been broken? What's the root cause?

It seems like something changed within 25.04... earlier versions of cuML are not affected by these issues: #6403 (comment)

I don't know what the root cause is. Maybe some changes to `cuml`'s top-level imports in 25.04 are now pulling in the modules with these problems at runtime, when previously they weren't? I'm really not sure.

### Benefits of these Changes

This adds a bit of test coverage in CI, minimally verifying that `cuml-cpu` is installable and that `import cuml` works in an environment without a GPU.

Inspired by:

* similar changes in `cuvs`: rapidsai/cuvs#750
* this conversation I recently had with @betatim: rapidsai/cuvs#743 (comment)

### How I tested this

Saw output like this in `conda-python-build` jobs, confirming that the import tests were running and passing:

```text
BUILD START: ['cuml-cpu-25.04.00a137-py310_250312_g153b21870_137.conda']
...
import: 'cuml'
...
Resource usage statistics from testing cuml-cpu:
...
Time elapsed: 0:00:10.0
...
TEST END: /tmp/conda-bld-output/linux-64/cuml-cpu-25.04.00a137-py310_250312_g153b21870_137.conda
```

Authors:

- James Lamb (https://github.com/jameslamb)

Approvers:

- Gil Forsyth (https://github.com/gforsyth)
- Simon Adorf (https://github.com/csadorf)
- Tim Head (https://github.com/betatim)

URL: #6400
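
To make the failure mode above concrete, here is a minimal, self-contained sketch. It is illustrative only: `gpu_only_import`, `UnavailableError`, and `_UnavailableModule` below are simplified stand-ins for `cuml.internals.safe_imports`, not the actual cuML implementation. It shows why any module-level use of a GPU-only placeholder raises as soon as the module is imported on a CPU-only system, and how the `alt=` fallback used in this PR (e.g. `gpu_only_import('cupy', alt=numpy)`) avoids that.

```python
# Simplified stand-ins for cuml.internals.safe_imports (illustrative only).
import importlib

import numpy


class UnavailableError(Exception):
    pass


class _UnavailableModule:
    """Placeholder returned when a GPU-only module cannot be imported."""

    def __init__(self, name):
        self._name = name

    def __getattr__(self, attr):
        raise UnavailableError(
            f"{self._name} is not installed in non GPU-enabled installations"
        )


def gpu_only_import(name, alt=None):
    try:
        return importlib.import_module(name)
    except ImportError:
        # fall back to `alt` (e.g. numpy standing in for cupy) if one was given
        return alt if alt is not None else _UnavailableModule(name)


cp = gpu_only_import("cupy")             # placeholder on CPU-only systems
np = gpu_only_import("cupy", alt=numpy)  # falls back to numpy

print(np.zeros(3))  # works with or without a GPU

try:
    cp.zeros(3)  # module-level use of the placeholder fails immediately
except UnavailableError as err:
    print(f"raises on CPU-only systems: {err}")
```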
1 parent 0b13f70 commit 744c848

12 files changed (+52, -48 lines)


ci/build_python.sh (-1)

@@ -37,7 +37,6 @@ if [[ ${RAPIDS_CUDA_MAJOR} == "12" ]]; then
 sccache --zero-stats

 RAPIDS_PACKAGE_VERSION=$(head -1 ./VERSION) rapids-conda-retry build \
-  --no-test \
   conda/recipes/cuml-cpu

 sccache --show-adv-stats

conda/recipes/cuml-cpu/meta.yaml (+4, -3)

@@ -39,9 +39,10 @@ requirements:
     - umap-learn=0.5.6
     - nvtx

-tests:      # [linux64]
-  imports:  # [linux64]
-    - cuml  # [linux64]
+test:
+  # test that the package is installable and these modules are importable
+  imports:
+    - cuml

 about:
   home: https://rapids.ai/

python/cuml/cuml/_thirdparty/sklearn/preprocessing/_imputation.py (+1, -1)

@@ -29,7 +29,7 @@

 from cuml.internals.safe_imports import cpu_only_import
 numpy = cpu_only_import('numpy')
-np = gpu_only_import('cupy')
+np = gpu_only_import('cupy', alt=numpy)
 sparse = gpu_only_import_from('cupyx.scipy', 'sparse')

python/cuml/cuml/_thirdparty/sklearn/utils/sparsefuncs.py (+2, -1)

@@ -11,6 +11,7 @@
 # This code is under BSD 3 clause license.
 # Authors mentioned above do not endorse or promote this production.

+import numpy

 from ....thirdparty_adapters.sparsefuncs_fast import (
     csr_mean_variance_axis0 as _csr_mean_var_axis0,
@@ -21,7 +22,7 @@
 from cuml.internals.safe_imports import cpu_only_import_from
 cpu_sp = cpu_only_import_from('scipy', 'sparse')
 gpu_sp = gpu_only_import_from('cupyx.scipy', 'sparse')
-np = gpu_only_import('cupy')
+np = gpu_only_import('cupy', alt=numpy)
 cpu_np = cpu_only_import('numpy')

python/cuml/cuml/dask/common/input_utils.py (+1, -2)

@@ -27,7 +27,6 @@
 from dask_cudf import Series as dcSeries
 from dask.dataframe import Series as daskSeries
 from dask.dataframe import DataFrame as daskDataFrame
-from cudf import Series
 from cuml.internals.safe_imports import gpu_only_import_from
 from collections import OrderedDict
 from cuml.internals.memory_utils import with_cupy_rmm
@@ -197,7 +196,7 @@ def _get_datatype_from_inputs(data):

 @with_cupy_rmm
 def concatenate(objs, axis=0):
-    if isinstance(objs[0], DataFrame) or isinstance(objs[0], Series):
+    if isinstance(objs[0], DataFrame) or isinstance(objs[0], cudf.Series):
         if len(objs) == 1:
             return objs[0]
         else:

python/cuml/cuml/feature_extraction/_vectorizers.py (+2, -3)

@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
 #
 from cuml.internals.safe_imports import cpu_only_import
 import cuml.internals.logger as logger
-from cudf.utils.dtypes import min_signed_type
 from cuml.internals.type_utils import CUPY_SPARSE_DTYPES
 import numbers
 from cuml.internals.safe_imports import gpu_only_import
@@ -256,7 +255,7 @@ def _compute_empty_doc_ids(self, count_df, n_doc):
         of documents.
         """
         remaining_docs = count_df["doc_id"].unique()
-        dtype = min_signed_type(n_doc)
+        dtype = cudf.utils.dtypes.min_signed_type(n_doc)
         doc_ids = cudf.DataFrame(
             data={"all_ids": cp.arange(0, n_doc, dtype=dtype)}, dtype=dtype
         )

python/cuml/cuml/internals/base_return_types.py (+4, -3)

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+# Copyright (c) 2022-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -97,7 +97,8 @@ def _get_base_return_type(class_name, attr):
         if attr.__annotations__["return"].replace("'", "") == class_name:
             return "base"
     except Exception:
-        assert False, "Shouldn't get here"
-        return None
+        raise AssertionError(
+            f"Failed to determine return type for {attr} (class = '${class_name}'). This is a bug in cuML, please report it."
+        )

     return None

python/cuml/cuml/model_selection/_split.py (+1)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from __future__ import annotations

 from typing import Optional, Union, List, Tuple

python/cuml/cuml/preprocessing/LabelEncoder.py (+3, -3)

@@ -207,7 +207,7 @@ def fit(self, y, _classes=None):
         self.dtype = y.dtype if y.dtype != cp.dtype("O") else str
         return self

-    def transform(self, y) -> cudf.Series:
+    def transform(self, y):
         """
         Transform an input into its categorical keys.

@@ -242,7 +242,7 @@ def transform(self, y) -> cudf.Series:

         return encoded

-    def fit_transform(self, y, z=None) -> cudf.Series:
+    def fit_transform(self, y, z=None):
         """
         Simultaneously fit and transform an input

@@ -258,7 +258,7 @@ def fit_transform(self, y, z=None) -> cudf.Series:

         return y.cat.codes

-    def inverse_transform(self, y: cudf.Series) -> cudf.Series:
+    def inverse_transform(self, y: "cudf.Series"):
         """
         Revert ordinal label to original label

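The annotation changes above, together with the `from __future__ import annotations` added to `_split.py`, all rely on the same fact: return annotations are evaluated when a function is defined, i.e. at import time, unless they are strings (PEP 563). Below is a minimal sketch of that timing difference, using a hypothetical `_Placeholder` object instead of cuML's lazily imported `cudf`; it is not the cuML implementation.

```python
from __future__ import annotations  # all annotations become lazy strings (PEP 563)


class _Placeholder:
    """Stand-in for a GPU-only module that is unavailable on CPU-only systems."""

    def __getattr__(self, attr):
        raise RuntimeError("cudf is not installed")


cudf = _Placeholder()


# Without the future import, `-> cudf.Series` would be evaluated while the
# function is being defined, and merely importing this module would raise.
# With the future import (or a quoted annotation like "cudf.Series"), the
# annotation stays a plain string and is only resolved if something asks for it.
def transform(y) -> cudf.Series:
    return y


print(transform.__annotations__)  # {'return': 'cudf.Series'} -- never evaluated
print(transform([1, 2, 3]))       # the function itself works fine
```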
python/cuml/cuml/preprocessing/encoders.py (+14, -13)

@@ -16,7 +16,6 @@
 from typing import Optional

 import cuml.internals.logger as logger
-from cudf import DataFrame, Series
 from cuml import Base
 from cuml.common.doc_utils import generate_docstring
 from cuml.common.exceptions import NotFittedError
@@ -95,7 +94,7 @@ def _check_input(self, X, is_categories=False):
             self._set_input_type("array")
             if is_categories:
                 X = X.transpose()
-            return DataFrame(X)
+            return cudf.DataFrame(X)
         else:
             self._set_input_type("df")
             return X
@@ -346,7 +345,7 @@ def _compute_drop_idx(self):
                 )
             drop_idx = dict()
             for feature in self.drop.keys():
-                self.drop[feature] = Series(self.drop[feature])
+                self.drop[feature] = cudf.Series(self.drop[feature])
                 if len(self.drop[feature]) != 1:
                     msg = (
                         "Trying to drop multiple values for feature {}, "
@@ -361,7 +360,7 @@ def _compute_drop_idx(self):
                         "categories.".format(feature)
                     )
                     raise ValueError(msg)
-                cats = Series(cats)
+                cats = cudf.Series(cats)
                 idx = cats.isin(self.drop[feature])
                 drop_idx[feature] = cp.asarray(cats[idx].index)
             return drop_idx
@@ -517,26 +516,28 @@ def inverse_transform(self, X):
             # if close: `and not cupyx.scipy.sparse.issparsecsc(X)`
             # and change the following line by `X = X.tocsc()`
             X = X.toarray()
-        result = DataFrame(columns=self._encoders.keys())
+        result = cudf.DataFrame(columns=self._encoders.keys())
         j = 0
         for feature in self._encoders.keys():
             feature_enc = self._encoders[feature]
             cats = feature_enc.classes_

             if self.drop is not None:
                 # Remove dropped categories
-                dropped_class_idx = Series(self.drop_idx_[feature])
-                dropped_class_mask = Series(cats).isin(cats[dropped_class_idx])
+                dropped_class_idx = cudf.Series(self.drop_idx_[feature])
+                dropped_class_mask = cudf.Series(cats).isin(
+                    cats[dropped_class_idx]
+                )
                 if len(cats) == 1:
-                    inv = Series(Index([cats[0]]).repeat(X.shape[0]))
+                    inv = cudf.Series(Index([cats[0]]).repeat(X.shape[0]))
                     result[feature] = inv
                     continue
                 cats = cats[~dropped_class_mask]

             enc_size = len(cats)
             x_feature = X[:, j : j + enc_size]
             idx = cp.argmax(x_feature, axis=1)
-            inv = Series(cats.iloc[idx]).reset_index(drop=True)
+            inv = cudf.Series(cats.iloc[idx]).reset_index(drop=True)

             if self.handle_unknown == "ignore":
                 not_null_idx = x_feature.any(axis=1)
@@ -548,7 +549,7 @@ def inverse_transform(self, X):
                 dropped_mask = cp.asarray(x_feature.sum(axis=1) == 0).flatten()
                 if dropped_mask.any():
                     inv[dropped_mask] = feature_enc.inverse_transform(
-                        Series(self.drop_idx_[feature])
+                        cudf.Series(self.drop_idx_[feature])
                     )[0]

             result[feature] = inv
@@ -624,7 +625,7 @@ def _slice_feat(X, i):
 def _get_output(
     output_type: Optional[str],
     input_type: Optional[str],
-    out: DataFrame,
+    out: "cudf.DataFrame",
     dtype,
 ):
     if output_type == "input":
@@ -729,7 +730,7 @@ def transform(self, X):
             col_idx = self._encoders[feature].transform(Xi)
             result[feature] = col_idx

-        r = DataFrame(result)
+        r = cudf.DataFrame(result)
         return _get_output(self.output_type, self.input_type, r, self.dtype)

     @generate_docstring(
@@ -766,7 +767,7 @@ def inverse_transform(self, X):
             inv = self._encoders[feature].inverse_transform(Xi)
             result[feature] = inv

-        r = DataFrame(result)
+        r = cudf.DataFrame(result)
         return _get_output(self.output_type, self.input_type, r, self.dtype)

     @classmethod

python/cuml/cuml/preprocessing/text/stem/porter_stemmer_utils/suffix_utils.py (+17, -15)

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2024, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -84,23 +84,25 @@ def subtract_valid(input_array, valid_bool_array, sub_val):
         input_array[pos] = input_array[pos] - sub_val


-@cudf.core.buffer.acquire_spill_lock()
 def get_stem_series(word_str_ser, suffix_len, can_replace_mask):
     """
     word_str_ser: input string column
     suffix_len: length of suffix to replace
     can_repalce_mask: bool array marking strings where to replace
     """
-    NTHRD = 1024
-    NBLCK = int(np.ceil(float(len(word_str_ser)) / float(NTHRD)))
-
-    start_series = cudf.Series(cp.zeros(len(word_str_ser), dtype=cp.int32))
-    end_ser = word_str_ser.str.len()
-
-    end_ar = end_ser._column.data_array_view(mode="read")
-    can_replace_mask_ar = can_replace_mask._column.data_array_view(mode="read")
-
-    subtract_valid[NBLCK, NTHRD](end_ar, can_replace_mask_ar, suffix_len)
-    return word_str_ser.str.slice_from(
-        starts=start_series, stops=end_ser.fillna(0)
-    )
+    with cudf.core.buffer.acquire_spill_lock():
+        NTHRD = 1024
+        NBLCK = int(np.ceil(float(len(word_str_ser)) / float(NTHRD)))
+
+        start_series = cudf.Series(cp.zeros(len(word_str_ser), dtype=cp.int32))
+        end_ser = word_str_ser.str.len()
+
+        end_ar = end_ser._column.data_array_view(mode="read")
+        can_replace_mask_ar = can_replace_mask._column.data_array_view(
+            mode="read"
+        )
+
+        subtract_valid[NBLCK, NTHRD](end_ar, can_replace_mask_ar, suffix_len)
+        return word_str_ser.str.slice_from(
+            starts=start_series, stops=end_ser.fillna(0)
+        )
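
The suffix_utils.py change above works because a decorator expression is evaluated when the module is imported, while a `with` statement inside the function body only runs when the function is called. Below is a minimal sketch of that timing difference; the `_Placeholder` object is a hypothetical stand-in for cuML's lazily imported `cudf` (so `cudf.core` here is just an attribute access that fails, not the real cudf API).

```python
class _Placeholder:
    """Stand-in for a GPU-only module that is unavailable on CPU-only systems."""

    def __getattr__(self, attr):
        raise RuntimeError("cudf is not installed")


cudf = _Placeholder()

# Decorator form: the expression after `@` is evaluated at import time, so a
# CPU-only install fails as soon as this module is imported.
#
#     @cudf.core.buffer.acquire_spill_lock()
#     def get_stem_series(word_str_ser, suffix_len, can_replace_mask):
#         ...

# Context-manager form: `cudf.core` is only touched when the function runs,
# so importing the module is safe and only the GPU code path fails.
def get_stem_series(word_str_ser):
    with cudf.core.buffer.acquire_spill_lock():
        return word_str_ser


try:
    get_stem_series("example")
except RuntimeError as err:
    print(f"fails only when called: {err}")
```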

python/cuml/cuml/thirdparty_adapters/adapters.py (+3, -3)

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,12 +14,10 @@
 # limitations under the License.
 #

-from cupyx.scipy import sparse as gpu_sparse
 from scipy import sparse as cpu_sparse
 from scipy.sparse import csc_matrix as cpu_coo_matrix
 from scipy.sparse import csc_matrix as cpu_csc_matrix
 from cuml.internals.safe_imports import cpu_only_import_from
-from cupyx.scipy.sparse import csc_matrix as gpu_coo_matrix
 from cuml.internals.safe_imports import gpu_only_import_from
 from cuml.internals.global_settings import GlobalSettings
 from cuml.internals.input_utils import input_to_cupy_array, input_to_host_array
@@ -28,6 +26,8 @@

 np = cpu_only_import("numpy")
 cp = gpu_only_import("cupy")
+gpu_sparse = gpu_only_import("cupyx.scipy.sparse")
+gpu_coo_matrix = gpu_only_import_from("cupyx.scipy.sparse", "coo_matrix")
 gpu_csr_matrix = gpu_only_import_from("cupyx.scipy.sparse", "csr_matrix")
 gpu_csc_matrix = gpu_only_import_from("cupyx.scipy.sparse", "csc_matrix")
 cpu_csr_matrix = cpu_only_import_from("scipy.sparse", "csr_matrix")
