Skip to content

Commit a0af77a

Browse files
committed
Fix: Added predict_proba and tests for the function(#372)
1 parent e437f32 commit a0af77a

File tree

1 file changed

+37
-1
lines changed
  • algorithms/linfa-clustering/src/gaussian_mixture

1 file changed

+37
-1
lines changed

algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ impl<F: Float> GaussianMixtureModel<F> {
211211
self.means()
212212
}
213213

214+
pub fn predict_proba<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
215+
let(_, log_resp) = self.estimate_log_prob_resp(observations);
216+
log_resp.mapv(F::exp)
217+
}
218+
214219
#[allow(clippy::type_complexity)]
215220
fn estimate_gaussian_parameters<D: Data<Elem = F>>(
216221
observations: &ArrayBase<D, Ix2>,
@@ -483,6 +488,7 @@ impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize
483488
#[cfg(test)]
484489
mod tests {
485490
use super::*;
491+
use rand_xoshiro::Xoshiro256Plus;
486492
use approx::{abs_diff_eq, assert_abs_diff_eq};
487493
use linfa_datasets::generate;
488494
use linfa_linalg::LinalgError;
@@ -493,6 +499,7 @@ mod tests {
493499
use ndarray_rand::rand_distr::Normal;
494500
use ndarray_rand::rand_distr::{Distribution, StandardNormal};
495501
use ndarray_rand::RandomExt;
502+
use ndarray::Array;
496503

497504
#[test]
498505
fn autotraits() {
@@ -746,4 +753,33 @@ mod tests {
746753
ThreadRng::default(),
747754
));
748755
}
749-
}
756+
757+
#[test]
758+
fn test_predict_proba() {
759+
let mut rng = Xoshiro256Plus::seed_from_u64(42);
760+
let centroids = array![[0.0, 1.0], [-10.0, 20.0], [-1.0, 10.0]];
761+
let n_samples_per_cluster = 1000;
762+
let dataset = DatasetBase::from(generate::blobs(n_samples_per_cluster, &centroids, &mut rng));
763+
let n_clusters = centroids.len_of(Axis(0));
764+
// total samples = (n_samples_per_cluster * n_clusters)
765+
let total_samples = n_samples_per_cluster * n_clusters;
766+
767+
// model fitting
768+
let gmm = GaussianMixtureModel::params(n_clusters)
769+
.with_rng(rng)
770+
.fit(&dataset)
771+
.expect("Failed to fit GMM");
772+
773+
// getting probabilities
774+
let proba = gmm.predict_proba(dataset.records());
775+
776+
// checking output shape is correct or not
777+
assert_eq!(proba.dim(), (total_samples, n_clusters));
778+
779+
// checking for each sample, the sum of probabilities is 1
780+
for sample in proba.outer_iter() {
781+
let sum: f64 = sample.sum();
782+
assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
783+
}
784+
}
785+
}

0 commit comments

Comments
 (0)