
Commit a974e5a

📜 Improve linfa_ensemble documentation
1 parent 91cbc5e commit a974e5a

File tree

3 files changed: +78 −6 lines changed

algorithms/linfa-ensemble/src/algorithm.rs

Lines changed: 55 additions & 0 deletions
@@ -10,8 +10,63 @@ use ndarray::{Array2, Axis, Zip};
 use rand::Rng;
 use std::{cmp::Eq, collections::HashMap, hash::Hash};

+/// A fitted ensemble of [Decision Trees](DecisionTree) trained on a random subset of features.
+///
+/// See the [EnsembleLearner] documentation for more information on the [RandomForest] interface.
 pub type RandomForest<F, L> = EnsembleLearner<DecisionTree<F, L>>;

+/// A fitted ensemble of learners for classification.
+///
+/// ## Structure
+///
+/// An ensemble learner is composed of a collection of fitted models of type `M`.
+///
+/// ## Fitting Algorithm
+///
+/// Given a [DatasetBase](DatasetBase) denoted as `D`,
+/// 1. Create as many distinct bootstrapped subsets of the original dataset `D` as there are
+///    distinct models to fit.
+/// 2. Fit each distinct model on a distinct bootstrapped subset of `D`.
+///
+/// Note that the subset size, as well as the subset of features to use in each training subset,
+/// can be specified in the [parameters](crate::EnsembleLearnerParams).
+///
+/// ## Prediction Algorithm
+///
+/// The prediction result is the result of majority voting across the fitted learners.
+///
+/// ## Example
+///
+/// This example shows how to train a bagging model using 100 decision trees,
+/// each trained on 70% of the training data (bootstrap sampling).
+/// ```no_run
+/// use linfa::prelude::{Fit, Predict};
+/// use linfa_ensemble::EnsembleLearnerParams;
+/// use linfa_trees::DecisionTree;
+/// use ndarray_rand::rand::SeedableRng;
+/// use rand::rngs::SmallRng;
+///
+/// // Load the Iris dataset
+/// let mut rng = SmallRng::seed_from_u64(42);
+/// let (train, test) = linfa_datasets::iris()
+///     .shuffle(&mut rng)
+///     .split_with_ratio(0.8);
+///
+/// // Train the model on the Iris dataset
+/// let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
+///     .ensemble_size(100) // Number of decision trees to fit
+///     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
+///     .fit(&train)
+///     .unwrap();
+///
+/// // Make predictions on the test set
+/// let predictions = bagging_model.predict(&test);
+/// ```
+///
+/// ## References
+///
+/// * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
+/// * [An Introduction to Statistical Learning](https://www.statlearning.com/)
 pub struct EnsembleLearner<M> {
     pub models: Vec<M>,
     pub model_features: Vec<Vec<usize>>,
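
To make the fitting and prediction steps documented above concrete, here is a minimal, self-contained sketch of bootstrap index sampling and majority voting. It illustrates the behavior the doc comment describes, not the crate's actual implementation (which operates on [DatasetBase] and ndarray types); both helper functions below are hypothetical names.

```rust
use rand::Rng;
use std::collections::HashMap;

// Fitting, step 1: draw `subset_size` indices with replacement from
// `0..n_samples`, i.e. one bootstrapped subset of the dataset `D`.
fn bootstrap_indices<R: Rng>(rng: &mut R, n_samples: usize, subset_size: usize) -> Vec<usize> {
    (0..subset_size).map(|_| rng.gen_range(0..n_samples)).collect()
}

// Prediction: majority voting over the per-learner predictions for one
// sample. Ties are broken arbitrarily here.
fn majority_vote(votes: &[usize]) -> Option<usize> {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &label in votes {
        *counts.entry(label).or_insert(0) += 1;
    }
    counts.into_iter().max_by_key(|&(_, count)| count).map(|(label, _)| label)
}

fn main() {
    // A 0.7 bootstrap proportion on 150 samples yields subsets of 105 indices.
    let mut rng = rand::thread_rng();
    assert_eq!(bootstrap_indices(&mut rng, 150, 105).len(), 105);

    // Three learners vote for class 1, two for class 0: the ensemble predicts 1.
    assert_eq!(majority_vote(&[1, 0, 1, 1, 0]), Some(1));
}
```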

algorithms/linfa-ensemble/src/hyperparams.rs

Lines changed: 14 additions & 1 deletion
@@ -6,22 +6,26 @@ use linfa_trees::DecisionTreeParams;
 use rand::rngs::ThreadRng;
 use rand::Rng;

+/// The set of valid hyper-parameters that can be specified for the fitting procedure of the
+/// [Ensemble Learner](crate::EnsembleLearner).
 #[derive(Clone, Copy, Debug, PartialEq)]
 pub struct EnsembleLearnerValidParams<P, R> {
     /// The number of models in the ensemble
     pub ensemble_size: usize,
     /// The proportion of the total number of training samples that should be given to each model for training
     pub bootstrap_proportion: f64,
-    /// The proportion of the total number of training feature that should be given to each model for training
+    /// The proportion of the total number of training features that should be given to each model for training
     pub feature_proportion: f64,
     /// The model parameters for the base model
     pub model_params: P,
     pub rng: R,
 }

+/// A helper struct for building a set of [Ensemble Learner](crate::EnsembleLearner) hyper-parameters.
 #[derive(Clone, Copy, Debug, PartialEq)]
 pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);

+/// A helper struct for building a set of [Random Forest](crate::RandomForest) hyper-parameters.
 pub type RandomForestParams<F, L, R> = EnsembleLearnerParams<DecisionTreeParams<F, L>, R>;

 impl<P> EnsembleLearnerParams<P, ThreadRng> {
@@ -41,16 +45,25 @@ impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
         })
     }

+    /// Specifies the number of models to fit in the ensemble.
     pub fn ensemble_size(mut self, size: usize) -> Self {
         self.0.ensemble_size = size;
         self
     }

+    /// Sets the proportion of the total number of training samples that should be given to each model for training.
+    ///
+    /// Note that the `proportion` should be in the interval (0, 1] in order to pass the
+    /// parameter validation check.
     pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
         self.0.bootstrap_proportion = proportion;
         self
     }

+    /// Sets the proportion of the total number of training features that should be given to each model for training.
+    ///
+    /// Note that the `proportion` should be in the interval (0, 1] in order to pass the
+    /// parameter validation check.
     pub fn feature_proportion(mut self, proportion: f64) -> Self {
         self.0.feature_proportion = proportion;
         self
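
Both proportion setters above document a (0, 1] range requirement that is enforced at parameter validation time. As a rough sketch of that check, under the assumption that validation simply rejects non-positive or greater-than-one values (the validation code itself is not part of this diff, and `proportion_in_range` is a hypothetical name):

```rust
// Hypothetical illustration of the (0, 1] check the doc comments describe.
fn proportion_in_range(proportion: f64) -> bool {
    proportion > 0.0 && proportion <= 1.0
}

fn main() {
    assert!(proportion_in_range(0.7));   // used by the bagging example
    assert!(proportion_in_range(1.0));   // the interval is closed at 1
    assert!(!proportion_in_range(0.0));  // the interval is open at 0
    assert!(!proportion_in_range(1.5));  // out of range
}
```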

algorithms/linfa-ensemble/src/lib.rs

Lines changed: 9 additions & 5 deletions
@@ -3,14 +3,18 @@
 //! Ensemble methods combine the predictions of several base estimators built with a given
 //! learning algorithm in order to improve generalizability / robustness over a single estimator.
 //!
+//! This crate (`linfa-ensemble`) provides pure Rust implementations of popular ensemble techniques, such as
+//! * [Bootstrap Aggregation](EnsembleLearner)
+//! * [Random Forest](RandomForest)
+//!
 //! ## Bootstrap Aggregation (aka Bagging)
 //!
 //! A typical example of an ensemble method is Bootstrap Aggregation, which combines the predictions of
-//! several decision trees (see `linfa-trees`) trained on different samples subset of the training dataset.
+//! several decision trees (see [`linfa-trees`](linfa_trees)) trained on different sample subsets of the training dataset.
 //!
 //! ## Random Forest
 //!
-//! A special case of Bootstrap Aggregation using decision trees (see `linfa-trees`) with random feature
+//! A special case of Bootstrap Aggregation using decision trees (see [`linfa-trees`](linfa_trees)) with random feature
 //! selection. A typical number of random features to select is $\sqrt{p}$, with $p$ being
 //! the number of available features.
 //!
@@ -48,7 +52,7 @@
 //! let predictions = bagging_model.predict(&test);
 //! ```
 //!
-//! This example shows how to train a Random Forest model using 100 decision trees,
+//! This example shows how to train a [Random Forest](RandomForest) model using 100 decision trees,
 //! each trained on 70% of the training data (bootstrap sampling) and using only
 //! 30% of the available features.
 //!
@@ -66,15 +70,15 @@
 //!     .split_with_ratio(0.8);
 //!
 //! // Train the model on the Iris dataset
-//! let bagging_model = RandomForestParams::new(DecisionTree::params())
+//! let random_forest = RandomForestParams::new(DecisionTree::params())
 //!     .ensemble_size(100) // Number of decision trees to fit
 //!     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
 //!     .feature_proportion(0.3) // Select only 30% of the features
 //!     .fit(&train)
 //!     .unwrap();
 //!
 //! // Make predictions on the test set
-//! let predictions = bagging_model.predict(&test);
+//! let predictions = random_forest.predict(&test);
 //! ```

 mod algorithm;
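
The module docs above cite the $\sqrt{p}$ heuristic for random feature selection, while the API takes a proportion. A small sketch of translating the heuristic into a `feature_proportion` value (the helper name is hypothetical; the only assumed fact is that Iris has 4 features):

```rust
// Convert the sqrt(p) feature-selection heuristic into a proportion
// suitable for `feature_proportion`.
fn sqrt_feature_proportion(num_features: usize) -> f64 {
    (num_features as f64).sqrt() / num_features as f64
}

fn main() {
    // Iris has p = 4 features, so sqrt(p) = 2 of them, i.e. a proportion of 0.5.
    assert!((sqrt_feature_proportion(4) - 0.5).abs() < 1e-12);
}
```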
