Prepare for n_init=auto in KMeans #6142

Merged (24 commits, Feb 25, 2025)
34 changes: 30 additions & 4 deletions python/cuml/cuml/cluster/kmeans.pyx
@@ -16,6 +16,8 @@

# distutils: language = c++

import warnings

from cuml.internals.safe_imports import cpu_only_import
np = cpu_only_import('numpy')
from cuml.internals.safe_imports import gpu_only_import
@@ -49,6 +51,7 @@ from cuml.internals.mixins import CMajorInputTagMixin
from cuml.common import input_to_cuml_array
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
from cuml.internals.global_settings import GlobalSettings

# from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
_openmp_effective_n_threads = safe_import_from(
@@ -95,7 +98,7 @@ class KMeans(UniversalBase,
3 4.0 3.0
>>>
>>> # Calling fit
>>> kmeans_float = KMeans(n_clusters=2)
>>> kmeans_float = KMeans(n_clusters=2, n_init="auto")
>>> kmeans_float.fit(b)
KMeans()
>>>
@@ -143,10 +146,17 @@ class KMeans(UniversalBase,
- If an ndarray is passed, it should be of
shape (`n_clusters`, `n_features`) and gives the initial centers.

n_init: int (default = 1)
n_init: 'auto' or int (default = 1)
Number of times the k-means algorithm will be run with
different centroid seeds. The final result is taken from the
run that produces the lowest inertia out of `n_init` runs.

.. versionadded:: 25.02
Added 'auto' option for `n_init`.

.. versionchanged:: 25.04
The default value of `n_init` will change from 1 to `'auto'` in version 25.04.

oversampling_factor : float64 (default = 2.0)
The number of points to sample
in scalable k-means++ initialization for potential centroids.
@@ -216,15 +226,31 @@
params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501
params.batch_samples = <int>self.max_samples_per_batch
params.oversampling_factor = <double>self.oversampling_factor
params.n_init = <int>self.n_init
n_init = self.n_init
if n_init == "warn":
if not GlobalSettings().accelerator_active:
warnings.warn(
"The default value of `n_init` will change from"
" 1 to 'auto' in 25.04. Set the value of `n_init`"
" explicitly to suppress this warning.",
FutureWarning,
)
n_init = 1
if n_init == "auto":
if self.init in ("k-means||", "scalable-k-means++"):
params.n_init = 1
else:
params.n_init = 10
else:
params.n_init = <int>n_init
return <size_t>params
ELSE:
return None

@device_interop_preparation
def __init__(self, *, handle=None, n_clusters=8, max_iter=300, tol=1e-4,
verbose=False, random_state=1,
init='scalable-k-means++', n_init=1, oversampling_factor=2.0,
init='scalable-k-means++', n_init="warn", oversampling_factor=2.0,
max_samples_per_batch=1<<15, convert_dtype=True,
output_type=None):
super().__init__(handle=handle,
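Taken together, the kmeans.pyx changes above resolve `n_init` when the C++ params struct is built: the temporary `"warn"` default behaves like the old value of 1 but emits a FutureWarning (suppressed when the accelerator is active), while `"auto"` maps to 1 run for the `"k-means||"` / `"scalable-k-means++"` initializations and 10 runs otherwise. A minimal pure-Python sketch of that resolution logic, using a hypothetical `resolve_n_init` helper that is not part of cuML:

import warnings


def resolve_n_init(n_init, init, accelerator_active=False):
    """Sketch of the n_init resolution added in kmeans.pyx (illustrative only)."""
    if n_init == "warn":
        if not accelerator_active:
            warnings.warn(
                "The default value of `n_init` will change from"
                " 1 to 'auto' in 25.04. Set the value of `n_init`"
                " explicitly to suppress this warning.",
                FutureWarning,
            )
        n_init = 1  # keep the historical default for now
    if n_init == "auto":
        # k-means|| style initialization already oversamples candidate
        # centroids, so a single run suffices; other inits get 10 restarts.
        return 1 if init in ("k-means||", "scalable-k-means++") else 10
    return int(n_init)


# Expected mappings per the diff above:
assert resolve_n_init("auto", "scalable-k-means++") == 1
assert resolve_n_init("auto", "random") == 10
assert resolve_n_init(5, "k-means||") == 5
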
7 changes: 5 additions & 2 deletions python/cuml/cuml/explainer/sampling.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -91,7 +91,10 @@ def kmeans_sampling(X, k, round_values=True, detailed=False, random_state=0):
X = imp.fit_transform(X)

kmeans = KMeans(
n_clusters=k, random_state=random_state, output_type=_output_dtype_str
n_clusters=k,
random_state=random_state,
output_type=_output_dtype_str,
n_init="auto",
).fit(X)

if round_values:
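Because `kmeans_sampling` leaves `init` at the KMeans default, passing `n_init="auto"` here keeps the previous single-run behaviour while avoiding the new FutureWarning. A hedged usage sketch, assuming the import path matches the file location above and a GPU is available (the exact return structure depends on `round_values` and `detailed`):

import numpy as np

from cuml.explainer.sampling import kmeans_sampling

# Summarize 200 rows into 8 weighted representatives for SHAP-style explainers.
X = np.random.RandomState(0).rand(200, 5).astype(np.float32)
summary = kmeans_sampling(X, k=8, random_state=0)
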
25 changes: 19 additions & 6 deletions python/cuml/cuml/tests/dask/test_dask_kmeans.py
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,7 +63,10 @@ def test_end_to_end(
X_train, y_train = X, y

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

cumlModel.fit(X_train)
@@ -120,7 +123,7 @@ def test_large_data_no_overflow(nrows_per_part, ncols, nclusters, client):
X.compute_chunk_sizes().persist()

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||", n_clusters=nclusters, random_state=10, n_init="auto"
)

cumlModel.fit(X_train)
@@ -171,7 +174,11 @@ def test_weighted_kmeans(nrows, ncols, nclusters, n_parts, client):
wt[cp.argmax(cp.array(y.compute()) == i).item()] = 5000.0

cumlModel = cumlKMeans(
verbose=0, init="k-means||", n_clusters=nclusters, random_state=10
verbose=0,
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

chunk_parts = int(nrows / n_parts)
@@ -237,7 +244,10 @@ def test_transform(nrows, ncols, nclusters, n_parts, input_type, client):
labels = cp.squeeze(y_train.compute())

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

cumlModel.fit(X_train)
@@ -302,7 +312,10 @@ def test_score(nrows, ncols, nclusters, n_parts, input_type, client):
X_train, y_train = X, y

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

cumlModel.fit(X_train)
4 changes: 3 additions & 1 deletion python/cuml/cuml/tests/test_api.py
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -237,6 +237,8 @@ def test_fit_function(dataset, model_name):
model = models[model_name](np.random.normal(0.0, 1.0, (10,)))
elif model_name in ["RandomForestClassifier", "RandomForestRegressor"]:
model = models[model_name](n_bins=32)
elif model_name == "KMeans":
model = models[model_name](n_init="auto")
else:
if n_pos_args_constr == 1:
model = models[model_name]()
4 changes: 4 additions & 0 deletions python/cuml/cuml/tests/test_common.py
@@ -19,6 +19,10 @@
import cuml
from cuml.datasets import make_blobs

pytestmark = pytest.mark.filterwarnings(
"ignore:The default value of `n_init` will change from 1 to 'auto' in 25.04"
)


@pytest.mark.parametrize(
"Estimator",
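Downstream code that cannot pass `n_init` explicitly yet can scope the same message-based filter used in test_common.py outside pytest; a minimal sketch with the standard `warnings` module:

import warnings

import cuml
from cuml.datasets import make_blobs

X, _ = make_blobs(n_samples=100, n_features=5, random_state=0)

with warnings.catch_warnings():
    # Ignore only the n_init FutureWarning raised during fit().
    warnings.filterwarnings(
        "ignore",
        message="The default value of `n_init` will change from",
        category=FutureWarning,
    )
    cuml.KMeans(n_clusters=3).fit(X)
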
3 changes: 2 additions & 1 deletion python/cuml/cuml/tests/test_device_selection.py
@@ -1034,7 +1034,8 @@ def test_kmeans_methods(train_device, infer_device):
ref_model.fit(X_train_blob)
ref_output = ref_model.predict(X_test_blob)

model = KMeans(n_clusters=n_clusters, random_state=42)
model = KMeans(n_clusters=n_clusters, random_state=42, n_init="auto")

with using_device_type(train_device):
model.fit(X_train_blob)
with using_device_type(infer_device):
7 changes: 5 additions & 2 deletions python/cuml/cuml/tests/test_input_estimators.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -115,7 +115,10 @@ def test_estimators_all_dtypes(model_name, dtype):

X_train, y_train, X_test = make_dataset(dtype, nrows, ncols, ninfo)
print(model_name)
model = models[model_name]()
if model_name == "KMeans":
model = models[model_name](n_init="auto")
else:
model = models[model_name]()
sign = inspect.signature(model.fit)
if "y" in sign.parameters:
model.fit(X=X_train, y=y_train)
29 changes: 27 additions & 2 deletions python/cuml/cuml/tests/test_kmeans.py
@@ -67,6 +67,24 @@ def random_state():
return random_state


def test_n_init_deprecation():
X, y = make_blobs(
random_state=0,
)

# Warn about default changing
kmeans = cuml.KMeans()
with pytest.warns(
FutureWarning, match="The default value of `n_init` will change from"
):
kmeans.fit(X)

# No warning when explicitly set to integer or 'auto'
for n_init in ("auto", 2):
kmeans = cuml.KMeans(n_init=n_init)
kmeans.fit(X)


@pytest.mark.xfail
def test_n_init_cluster_consistency(random_state):

@@ -127,7 +145,9 @@ def test_traditional_kmeans_plus_plus_init(
cuml_kmeans.fit(X)
cu_score = cuml_kmeans.score(X)

kmeans = cluster.KMeans(random_state=random_state, n_clusters=nclusters)
kmeans = cluster.KMeans(
random_state=random_state, n_clusters=nclusters, n_init=1
)
kmeans.fit(cp.asnumpy(X))
sk_score = kmeans.score(cp.asnumpy(X))

@@ -167,7 +187,9 @@ def test_weighted_kmeans(nrows, ncols, nclusters, max_weight, random_state):
cuml_kmeans.fit(X, sample_weight=wt)
cu_score = cuml_kmeans.score(X)

sk_kmeans = cluster.KMeans(random_state=random_state, n_clusters=nclusters)
sk_kmeans = cluster.KMeans(
random_state=random_state, n_clusters=nclusters, n_init=1
)
sk_kmeans.fit(cp.asnumpy(X), sample_weight=wt)
sk_score = sk_kmeans.score(cp.asnumpy(X))

@@ -200,6 +222,7 @@ def test_kmeans_clusters_blobs(
n_clusters=nclusters,
random_state=random_state,
output_type="numpy",
n_init=1,
)

preds = cuml_kmeans.fit_predict(X)
@@ -331,6 +354,7 @@ def test_all_kmeans_params(
oversampling_factor=oversampling_factor,
max_samples_per_batch=max_samples_per_batch,
output_type="cupy",
n_init=1,
)

cuml_kmeans.fit_predict(X)
@@ -359,6 +383,7 @@ def test_score(nrows, ncols, nclusters, random_state):
n_clusters=nclusters,
random_state=random_state,
output_type="numpy",
n_init=1,
)

cuml_kmeans.fit(X)
2 changes: 1 addition & 1 deletion python/cuml/cuml/tests/test_metrics.py
@@ -265,7 +265,7 @@ def test_rand_index_score(name, nrows):
params = default_base.copy()
params.update(pat[1])

cuml_kmeans = cuml.KMeans(n_clusters=params["n_clusters"])
cuml_kmeans = cuml.KMeans(n_clusters=params["n_clusters"], n_init="auto")

X, y = pat[0]

5 changes: 4 additions & 1 deletion python/cuml/cuml/tests/test_pickle.py
@@ -299,7 +299,10 @@ def test_cluster_pickle(tmpdir, datatype, keys, data_size):
def create_mod():
nrows, ncols, n_info = data_size
X_train, y_train, X_test = make_dataset(datatype, nrows, ncols, n_info)
model = cluster_models[keys]()
if keys == "KMeans":
model = cluster_models[keys](n_init="auto")
else:
model = cluster_models[keys]()
model.fit(X_train)
result["cluster"] = model.predict(X_test)
return model, X_test
11 changes: 10 additions & 1 deletion python/cuml/cuml/tests/test_public_methods_attributes.py
@@ -1,5 +1,5 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,15 @@
from cuml.internals.global_settings import GlobalSettings
from sklearn.datasets import make_classification, make_regression

pytestmark = [
pytest.mark.filterwarnings(
"ignore:Starting from version 22.04, the default method of TSNE is 'fft'."
),
pytest.mark.filterwarnings(
"ignore:The default value of `n_init` will change from 1 to 'auto' in 25.04"
),
]


estimators = [
"KMeans",
6 changes: 6 additions & 0 deletions python/cuml/cuml/tests/test_sklearn_import_export.py
@@ -162,6 +162,9 @@ def test_basic_roundtrip():
assert ckm.n_clusters == 13


@pytest.mark.filterwarnings(
"ignore:The default value of `n_init` will change from 1 to 'auto' in 25.04"
)
def test_kmeans(random_state):
# Using sklearn directly for demonstration
X, _ = make_blobs(
@@ -235,6 +238,9 @@ def test_lasso(random_state):
assert_estimator_roundtrip(original, SkLasso, X, y)


@pytest.mark.filterwarnings(
"ignore:Starting from version 22.04, the default method of TSNE is 'fft'."
)
def test_tsne(random_state):
# TSNE is a bit tricky as it is non-deterministic. For test simplicity:
X = np.random.RandomState(random_state).rand(50, 5)