@@ -211,6 +211,11 @@ impl<F: Float> GaussianMixtureModel<F> {
211211 self . means ( )
212212 }
213213
214+ pub fn predict_proba < D : Data < Elem = F > > ( & self , observations : & ArrayBase < D , Ix2 > ) -> Array2 < F > {
215+ let ( _, log_resp) = self . estimate_log_prob_resp ( observations) ;
216+ log_resp. mapv ( F :: exp)
217+ }
218+
214219 #[ allow( clippy:: type_complexity) ]
215220 fn estimate_gaussian_parameters < D : Data < Elem = F > > (
216221 observations : & ArrayBase < D , Ix2 > ,
@@ -483,6 +488,7 @@ impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize
483488#[ cfg( test) ]
484489mod tests {
485490 use super :: * ;
491+ use rand_xoshiro:: Xoshiro256Plus ;
486492 use approx:: { abs_diff_eq, assert_abs_diff_eq} ;
487493 use linfa_datasets:: generate;
488494 use linfa_linalg:: LinalgError ;
@@ -493,6 +499,7 @@ mod tests {
493499 use ndarray_rand:: rand_distr:: Normal ;
494500 use ndarray_rand:: rand_distr:: { Distribution , StandardNormal } ;
495501 use ndarray_rand:: RandomExt ;
502+ use ndarray:: Array ;
496503
497504 #[ test]
498505 fn autotraits ( ) {
@@ -746,4 +753,33 @@ mod tests {
746753 ThreadRng :: default ( ) ,
747754 ) ) ;
748755 }
749- }
756+
/// `predict_proba` must yield a `(samples, components)` matrix in which
/// every row sums to 1.
#[test]
fn test_predict_proba() {
    // Seeded generator so the blob data and the fit are reproducible.
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let centroids = array![[0.0, 1.0], [-10.0, 20.0], [-1.0, 10.0]];
    let n_samples_per_cluster = 1000;
    let dataset = DatasetBase::from(generate::blobs(n_samples_per_cluster, &centroids, &mut rng));

    // One mixture component per centroid row.
    let n_clusters = centroids.nrows();
    let expected_shape = (n_samples_per_cluster * n_clusters, n_clusters);

    // Fit the mixture, handing over the same generator used for the data.
    let gmm = GaussianMixtureModel::params(n_clusters)
        .with_rng(rng)
        .fit(&dataset)
        .expect("Failed to fit GMM");

    let proba = gmm.predict_proba(dataset.records());

    // The output must have one row per sample and one column per component.
    assert_eq!(proba.dim(), expected_shape);

    // Each row must sum to 1 within tolerance.
    proba.axis_iter(Axis(0)).for_each(|row| {
        let row_total: f64 = row.sum();
        assert_abs_diff_eq!(row_total, 1.0, epsilon = 1e-6);
    });
}
785+ }
0 commit comments