Skip to content

Commit e7974ae

Browse files
committed
Add type alias for Random Forest as an EnsembleLearner with model type DecisionTree.
1 parent ddba346 commit e7974ae

File tree

3 files changed

+38
-12
lines changed

3 files changed

+38
-12
lines changed

algorithms/linfa-ensemble/examples/ensemble_iris.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
2-
use linfa_ensemble::EnsembleLearnerParams;
2+
use linfa_ensemble::{EnsembleLearnerParams, RandomForestParams};
33
use linfa_trees::DecisionTree;
44
use ndarray_rand::rand::SeedableRng;
55
use rand::rngs::SmallRng;
66

7-
fn ensemble_learner(
8-
ensemble_size: usize,
9-
bootstrap_proportion: f64,
10-
feature_proportion: f64,
11-
) -> () {
7+
fn ensemble_learner(ensemble_size: usize, bootstrap_proportion: f64) -> () {
128
// Load dataset
139
let mut rng = SmallRng::seed_from_u64(42);
1410
let (train, test) = linfa_datasets::iris()
@@ -17,6 +13,30 @@ fn ensemble_learner(
1713

1814
// Train ensemble learner model
1915
let model = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), rng)
16+
.ensemble_size(ensemble_size)
17+
.bootstrap_proportion(bootstrap_proportion)
18+
.fit(&train)
19+
.unwrap();
20+
21+
// Return highest ranking predictions
22+
let final_predictions_ensemble = model.predict(&test);
23+
println!("Final Predictions: \n{final_predictions_ensemble:?}");
24+
25+
let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap();
26+
27+
println!("{cm:?}");
28+
println!("Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {ensemble_size},\n Bootstrap Proportion: {bootstrap_proportion}",
29+
100.0 * cm.accuracy());
30+
}
31+
32+
fn random_forest(ensemble_size: usize, bootstrap_proportion: f64, feature_proportion: f64) {
33+
let mut rng = SmallRng::seed_from_u64(42);
34+
let (train, test) = linfa_datasets::iris()
35+
.shuffle(&mut rng)
36+
.split_with_ratio(0.8);
37+
38+
// Train ensemble learner model
39+
let model = RandomForestParams::new_fixed_rng(DecisionTree::params(), rng)
2040
.ensemble_size(ensemble_size)
2141
.bootstrap_proportion(bootstrap_proportion)
2242
.feature_proportion(feature_proportion)
@@ -37,9 +57,9 @@ fn ensemble_learner(
3757
fn main() {
3858
// This is an example bagging with decision tree
3959
println!("An example using Bagging with Decision Tree on Iris Dataset");
40-
ensemble_learner(100, 0.7, 1.0);
60+
ensemble_learner(100, 0.7);
4161
println!("");
4262
// This is basically a Random Forest ensemble
4363
println!("An example using a Random Forest on Iris Dataset");
44-
ensemble_learner(100, 0.7, 0.2);
64+
random_forest(100, 0.7, 0.2);
4565
}

algorithms/linfa-ensemble/src/algorithm.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ use linfa::{
55
traits::*,
66
DatasetBase,
77
};
8+
use linfa_trees::DecisionTree;
89
use ndarray::{Array2, Axis, Zip};
910
use rand::Rng;
1011
use std::{cmp::Eq, collections::HashMap, hash::Hash};
1112

13+
pub type RandomForest<F, L> = EnsembleLearner<DecisionTree<F, L>>;
14+
1215
pub struct EnsembleLearner<M> {
1316
pub models: Vec<M>,
1417
pub model_features: Vec<Vec<usize>>,
@@ -18,14 +21,14 @@ impl<M> EnsembleLearner<M> {
1821
// Generates prediction iterator returning predictions from each model
1922
pub fn generate_predictions<'b, R: Records, T>(
2023
&'b self,
21-
x: &'b Vec<R>,
24+
x: &'b [R],
2225
) -> impl Iterator<Item = T> + 'b
2326
where
2427
M: Predict<&'b R, T>,
2528
{
2629
self.models
2730
.iter()
28-
.zip(x.into_iter())
31+
.zip(x.iter())
2932
.map(move |(m, sub_data)| m.predict(sub_data))
3033
}
3134
}
@@ -112,8 +115,8 @@ where
112115
}
113116

114117
Ok(EnsembleLearner {
115-
models: models,
116-
model_features: model_features,
118+
models,
119+
model_features,
117120
})
118121
}
119122
}

algorithms/linfa-ensemble/src/hyperparams.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use linfa::{
22
error::{Error, Result},
33
ParamGuard,
44
};
5+
use linfa_trees::DecisionTreeParams;
56
use rand::rngs::ThreadRng;
67
use rand::Rng;
78

@@ -21,6 +22,8 @@ pub struct EnsembleLearnerValidParams<P, R> {
2122
#[derive(Clone, Copy, Debug, PartialEq)]
2223
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
2324

25+
pub type RandomForestParams<F, L, R> = EnsembleLearnerParams<DecisionTreeParams<F, L>, R>;
26+
2427
impl<P> EnsembleLearnerParams<P, ThreadRng> {
2528
pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
2629
Self::new_fixed_rng(model_params, rand::thread_rng())

0 commit comments

Comments
 (0)