77//!
88//! [`fit`] computes cluster centroids for a dataset, [`predict`] assigns points
99//! to clusters, and [`cluster_cost`] reports the inertia. All inputs and outputs
10- //! reside in device memory and are passed through the
11- //! [`IntoDlTensor `] /
12- //! [`IntoDlTensorMut `] traits; see the
10+ //! reside in device memory and are borrowed through the
11+ //! [`AsDlTensor `] /
12+ //! [`AsDlTensorMut `] traits; see the
1313//! [`dlpack`](crate::dlpack) module for the tensor model.
1414
1515mod params;
1616
1717pub use params:: Params ;
1818
19- use crate :: dlpack:: { IntoDlTensor , IntoDlTensorMut } ;
19+ use crate :: dlpack:: { AsDlTensor , AsDlTensorMut } ;
2020use crate :: error:: { Result , check_cuvs} ;
2121use crate :: resources:: Resources ;
2222
@@ -25,18 +25,23 @@ use crate::resources::Resources;
2525/// `x` (shape `m × k`) is the input matrix and `centroids` (shape
2626/// `n_clusters × k`) receives the fitted centroids; `sample_weight` is an
2727/// optional per-sample weight. All reside in device memory and implement
28- /// [`IntoDlTensor `] /
29- /// [`IntoDlTensorMut `].
30- pub fn fit < ' a > (
28+ /// [`AsDlTensor `] /
29+ /// [`AsDlTensorMut `].
30+ pub fn fit < X , W , C > (
3131 res : & Resources ,
3232 params : & Params ,
33- x : impl IntoDlTensor < ' a > ,
34- sample_weight : Option < impl IntoDlTensor < ' a > > ,
35- centroids : impl IntoDlTensorMut < ' a > ,
36- ) -> Result < ( f64 , i32 ) > {
37- let x = x. into_dl_tensor ( ) ?;
38- let sample_weight = sample_weight. map ( |w| w. into_dl_tensor ( ) ) . transpose ( ) ?;
39- let centroids = centroids. into_dl_tensor_mut ( ) ?;
33+ x : & X ,
34+ sample_weight : Option < & W > ,
35+ centroids : & mut C ,
36+ ) -> Result < ( f64 , i32 ) >
37+ where
38+ X : AsDlTensor + ?Sized ,
39+ W : AsDlTensor + ?Sized ,
40+ C : AsDlTensorMut + ?Sized ,
41+ {
42+ let x = x. as_dl_tensor ( ) ?;
43+ let sample_weight = sample_weight. map ( |w| w. as_dl_tensor ( ) ) . transpose ( ) ?;
44+ let centroids = centroids. as_dl_tensor_mut ( ) ?;
4045 let mut inertia: f64 = 0.0 ;
4146 let mut niter: i32 = 0 ;
4247 let mut sample_weight_c = sample_weight. as_ref ( ) . map ( |w| w. to_c ( ) ) ;
@@ -62,22 +67,28 @@ pub fn fit<'a>(
6267///
6368/// `x` (shape `m × k`), `centroids` (shape `n_clusters × k`), the optional
6469/// `sample_weight`, and `labels` (shape `m × 1`) reside in device memory and
65- /// implement [`IntoDlTensor `] /
66- /// [`IntoDlTensorMut `]. `normalize_weight` selects
70+ /// implement [`AsDlTensor `] /
71+ /// [`AsDlTensorMut `]. `normalize_weight` selects
6772/// whether the sample weights are normalized.
68- pub fn predict < ' a > (
73+ pub fn predict < X , W , C , L > (
6974 res : & Resources ,
7075 params : & Params ,
71- x : impl IntoDlTensor < ' a > ,
72- sample_weight : Option < impl IntoDlTensor < ' a > > ,
73- centroids : impl IntoDlTensor < ' a > ,
74- labels : impl IntoDlTensorMut < ' a > ,
76+ x : & X ,
77+ sample_weight : Option < & W > ,
78+ centroids : & C ,
79+ labels : & mut L ,
7580 normalize_weight : bool ,
76- ) -> Result < f64 > {
77- let x = x. into_dl_tensor ( ) ?;
78- let sample_weight = sample_weight. map ( |w| w. into_dl_tensor ( ) ) . transpose ( ) ?;
79- let centroids = centroids. into_dl_tensor ( ) ?;
80- let labels = labels. into_dl_tensor_mut ( ) ?;
81+ ) -> Result < f64 >
82+ where
83+ X : AsDlTensor + ?Sized ,
84+ W : AsDlTensor + ?Sized ,
85+ C : AsDlTensor + ?Sized ,
86+ L : AsDlTensorMut + ?Sized ,
87+ {
88+ let x = x. as_dl_tensor ( ) ?;
89+ let sample_weight = sample_weight. map ( |w| w. as_dl_tensor ( ) ) . transpose ( ) ?;
90+ let centroids = centroids. as_dl_tensor ( ) ?;
91+ let labels = labels. as_dl_tensor_mut ( ) ?;
8192 let mut inertia: f64 = 0.0 ;
8293 let mut sample_weight_c = sample_weight. as_ref ( ) . map ( |w| w. to_c ( ) ) ;
8394 let sample_weight_ptr =
@@ -101,14 +112,14 @@ pub fn predict<'a>(
101112/// Computes the k-means cost (inertia) of `x` against existing `centroids`.
102113///
103114/// `x` (shape `m × k`) and `centroids` (shape `n_clusters × k`) reside in device
104- /// memory and implement [`IntoDlTensor `].
105- pub fn cluster_cost < ' a > (
106- res : & Resources ,
107- x : impl IntoDlTensor < ' a > ,
108- centroids : impl IntoDlTensor < ' a > ,
109- ) -> Result < f64 > {
110- let x = x. into_dl_tensor ( ) ?;
111- let centroids = centroids. into_dl_tensor ( ) ?;
115+ /// memory and implement [`AsDlTensor `].
116+ pub fn cluster_cost < X , C > ( res : & Resources , x : & X , centroids : & C ) -> Result < f64 >
117+ where
118+ X : AsDlTensor + ? Sized ,
119+ C : AsDlTensor + ? Sized ,
120+ {
121+ let x = x. as_dl_tensor ( ) ?;
122+ let centroids = centroids. as_dl_tensor ( ) ?;
112123 let mut inertia: f64 = 0.0 ;
113124
114125 unsafe {
0 commit comments