[REVIEW] Add KMeans.fit_predict to Python API
#1956
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
||
| from .kmeans import KMeansParams, cluster_cost, fit, predict | ||
| from .kmeans import KMeansParams, cluster_cost, fit, fit_predict, predict | ||
|
|
||
| __all__ = ["KMeansParams", "cluster_cost", "fit", "predict"] | ||
| __all__ = ["KMeansParams", "cluster_cost", "fit", "fit_predict", "predict"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| # | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # cython: language_level=3 | ||
|
|
@@ -156,6 +156,10 @@ cdef class KMeansParams: | |
|
|
||
| FitOutput = namedtuple("FitOutput", "centroids inertia n_iter") | ||
|
|
||
| FitPredictOutput = namedtuple( | ||
| "FitPredictOutput", "labels centroids inertia n_iter" | ||
| ) | ||
|
|
||
|
|
||
| @auto_sync_resources | ||
| @auto_convert_output | ||
|
|
@@ -239,6 +243,86 @@ def fit( | |
| return FitOutput(centroids, inertia, n_iter) | ||
|
|
||
|
|
||
| @auto_sync_resources | ||
| @auto_convert_output | ||
| def fit_predict( | ||
| KMeansParams params, | ||
| X, | ||
| centroids=None, | ||
| sample_weights=None, | ||
| labels=None, | ||
| normalize_weight=True, | ||
| resources=None, | ||
| ): | ||
| """ | ||
| Fit k-means on ``X`` and return cluster labels for the same data. | ||
|
|
||
| This is equivalent to calling :func:`fit` followed by :func:`predict` on | ||
| ``X`` with the fitted centroids. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| params : KMeansParams | ||
| Parameters to use to fit the KMeans model | ||
| X : Input CUDA array interface compliant matrix shape (m, k) | ||
| centroids : Optional writable CUDA array interface compliant matrix | ||
| shape (n_clusters, k) | ||
| sample_weights : Optional input CUDA array interface compliant matrix shape | ||
| (m, 1), one weight per row of ``X``; default: None | ||
| labels : Optional preallocated CUDA array interface matrix shape (m, 1) | ||
| to hold the output labels | ||
| normalize_weight : bool | ||
| Passed to :func:`predict`; True if the weights should be normalized | ||
| {resources_docstring} | ||
|
|
||
| Returns | ||
| ------- | ||
| labels : raft.device_ndarray | ||
| Cluster index for each row of ``X`` | ||
| centroids : raft.device_ndarray | ||
| The fitted cluster centroids | ||
| inertia : float | ||
| Sum of squared distances of samples to their closest cluster center | ||
| (from the prediction step) | ||
| n_iter : int | ||
| Number of iterations used in :func:`fit` | ||
|
|
||
| Examples | ||
| -------- | ||
|
|
||
| >>> import cupy as cp | ||
| >>> | ||
| >>> from cuvs.cluster.kmeans import fit_predict, KMeansParams | ||
| >>> | ||
| >>> n_samples = 5000 | ||
| >>> n_features = 50 | ||
| >>> n_clusters = 3 | ||
| >>> | ||
| >>> X = cp.random.random_sample((n_samples, n_features), | ||
| ... dtype=cp.float32) | ||
| >>> | ||
| >>> params = KMeansParams(n_clusters=n_clusters) | ||
| >>> labels, centroids, inertia, n_iter = fit_predict(params, X) | ||
| """ | ||
| centroids_out, _, n_iter = fit( | ||
| params, | ||
| X, | ||
| centroids=centroids, | ||
| sample_weights=sample_weights, | ||
| resources=resources, | ||
| ) | ||
| labels_out, inertia_pred = predict( | ||
| params, | ||
| X, | ||
| centroids_out, | ||
| sample_weights=sample_weights, | ||
| labels=labels, | ||
| normalize_weight=normalize_weight, | ||
| resources=resources, | ||
| ) | ||
| return FitPredictOutput(labels_out, centroids_out, inertia_pred, n_iter) | ||
|
Comment on lines +307 to +323
Contributor
This would need to call the missing C function "cuvsKMeansFitPredict" because #1939 is adding some improvements to the
Contributor
@jrbourbeau, given that we're likely to merge @lowener's PR first, could you base your branch on his and build from his changes? Then we can merge yours shortly after.
Contributor
FYI, the latest update on #1939 is that the C++ fit_predict() function will be removed and the labels will be returned by fit() as an optional output.
Contributor
@lowener, is this still the case after your benchmarking, or are we planning to keep the explicit fit_predict? |
||
|
|
||
|
|
||
| PredictOutput = namedtuple("PredictOutput", "labels inertia") | ||
|
|
||
|
|
||
|
|
||
This is going away in favor of the Fern-based docs, which we can generate with codex. Just FYI. The process is about to get a lot easier.
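For anyone following the thread, here is a minimal CPU sketch of the composition the `fit_predict` docstring describes (fit, then predict on the same `X` with the fitted centroids). It is not cuVS code: it uses NumPy, a plain Lloyd's loop, and illustrative signatures (`n_clusters`, `max_iter`, `seed`) that are assumptions made for this example rather than the `KMeansParams` API. Only the namedtuple field layouts mirror the diff.

```python
from collections import namedtuple

import numpy as np

# Field layouts mirror the namedtuples in the diff; everything else is illustrative.
FitOutput = namedtuple("FitOutput", "centroids inertia n_iter")
PredictOutput = namedtuple("PredictOutput", "labels inertia")
FitPredictOutput = namedtuple(
    "FitPredictOutput", "labels centroids inertia n_iter"
)


def predict(X, centroids):
    # Squared Euclidean distance from every row of X to every centroid.
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    inertia = float(d2[np.arange(X.shape[0]), labels].sum())
    return PredictOutput(labels, inertia)


def fit(X, n_clusters, max_iter=100, seed=0):
    # Hypothetical signature: plain Lloyd's iterations with random initial rows.
    rng = np.random.default_rng(seed)
    start = rng.choice(X.shape[0], size=n_clusters, replace=False)
    centroids = X[start].copy()
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        labels, _ = predict(X, centroids)
        new_centroids = centroids.copy()
        for k in range(n_clusters):
            members = X[labels == k]
            if len(members):  # keep the old centroid if a cluster is empty
                new_centroids[k] = members.mean(axis=0)
        converged = np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if converged:
            break
    return FitOutput(centroids, predict(X, centroids).inertia, n_iter)


def fit_predict(X, n_clusters, **kwargs):
    # Same composition as in the diff: the fit inertia is discarded and the
    # reported inertia comes from the prediction step.
    centroids, _, n_iter = fit(X, n_clusters, **kwargs)
    labels, inertia = predict(X, centroids)
    return FitPredictOutput(labels, centroids, inertia, n_iter)


rng = np.random.default_rng(42)
X = rng.random((500, 8)).astype(np.float32)

out = fit_predict(X, n_clusters=3)

# Reference: the two-step call sequence the docstring says is equivalent.
ref_centroids, _, ref_n_iter = fit(X, n_clusters=3)
ref_labels, ref_inertia = predict(X, ref_centroids)
```

In this sketch the one-call and two-call paths agree because `fit` is deterministic for a fixed seed. Whether the GPU path keeps an explicit `fit_predict` or returns labels from `fit()` is the open question in the comments above and is not settled here.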