11use linfa:: prelude:: { Fit , Predict , ToConfusionMatrix } ;
2- use linfa_ensemble:: EnsembleLearnerParams ;
2+ use linfa_ensemble:: { EnsembleLearnerParams , RandomForestParams } ;
33use linfa_trees:: DecisionTree ;
44use ndarray_rand:: rand:: SeedableRng ;
55use rand:: rngs:: SmallRng ;
66
7- fn ensemble_learner (
8- ensemble_size : usize ,
9- bootstrap_proportion : f64 ,
10- feature_proportion : f64 ,
11- ) -> ( ) {
7+ fn ensemble_learner ( ensemble_size : usize , bootstrap_proportion : f64 ) -> ( ) {
128 // Load dataset
139 let mut rng = SmallRng :: seed_from_u64 ( 42 ) ;
1410 let ( train, test) = linfa_datasets:: iris ( )
@@ -17,6 +13,30 @@ fn ensemble_learner(
1713
1814 // Train ensemble learner model
1915 let model = EnsembleLearnerParams :: new_fixed_rng ( DecisionTree :: params ( ) , rng)
16+ . ensemble_size ( ensemble_size)
17+ . bootstrap_proportion ( bootstrap_proportion)
18+ . fit ( & train)
19+ . unwrap ( ) ;
20+
21+ // Return highest ranking predictions
22+ let final_predictions_ensemble = model. predict ( & test) ;
23+ println ! ( "Final Predictions: \n {final_predictions_ensemble:?}" ) ;
24+
25+ let cm = final_predictions_ensemble. confusion_matrix ( & test) . unwrap ( ) ;
26+
27+ println ! ( "{cm:?}" ) ;
28+ println ! ( "Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {ensemble_size},\n Bootstrap Proportion: {bootstrap_proportion}" ,
29+ 100.0 * cm. accuracy( ) ) ;
30+ }
31+
32+ fn random_forest ( ensemble_size : usize , bootstrap_proportion : f64 , feature_proportion : f64 ) {
33+ let mut rng = SmallRng :: seed_from_u64 ( 42 ) ;
34+ let ( train, test) = linfa_datasets:: iris ( )
35+ . shuffle ( & mut rng)
36+ . split_with_ratio ( 0.8 ) ;
37+
38+ // Train ensemble learner model
39+ let model = RandomForestParams :: new_fixed_rng ( DecisionTree :: params ( ) , rng)
2040 . ensemble_size ( ensemble_size)
2141 . bootstrap_proportion ( bootstrap_proportion)
2242 . feature_proportion ( feature_proportion)
@@ -37,9 +57,9 @@ fn ensemble_learner(
3757fn main ( ) {
3858 // This is an example bagging with decision tree
3959 println ! ( "An example using Bagging with Decision Tree on Iris Dataset" ) ;
40- ensemble_learner ( 100 , 0.7 , 1.0 ) ;
60+ ensemble_learner ( 100 , 0.7 ) ;
4161 println ! ( "" ) ;
4262 // This is basically a Random Forest ensemble
4363 println ! ( "An example using a Random Forest on Iris Dataset" ) ;
44- ensemble_learner ( 100 , 0.7 , 0.2 ) ;
64+ random_forest ( 100 , 0.7 , 0.2 ) ;
4565}
0 commit comments