Commit 590843f

feat(classical-ml): add random forest algorithm (#241)
* feat(classical-ml): add random forest algorithm
* fix: clippy
1 parent 026011f commit 590843f
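
In short, this commit adds a bagging ensemble over the existing DecisionTreeClassifier, exposed through the same Algorithm trait as the other models. Below is a minimal sketch of the intended call pattern; it mirrors the new test at the bottom of the diff, but the toy arrays and the demo function are illustrative only, and the crate-internal `crate::classical_ml` paths are simply the ones the tests use.

use ndarray::{array, Array1, Array2};

use crate::classical_ml::{Algorithm, algorithms::RandomForest, losses::CrossEntropy};

fn demo() {
    // Tiny, linearly separable toy data: class 0 for small feature values, class 1 for large.
    let x_train: Array2<f64> = array![[0.0, 0.1], [0.2, 0.0], [0.9, 1.0], [1.0, 0.8]];
    let y_train: Array1<f64> = array![0.0, 0.0, 1.0, 1.0];

    // learning_rate and epochs are part of the shared Algorithm signature but unused by the forest.
    let mut model = RandomForest::new(CrossEntropy);
    model.fit(&x_train, &y_train, 0.1, 100);

    let x_test: Array2<f64> = array![[0.1, 0.0], [0.95, 0.9]];
    let predictions = model.predict(&x_test);
    println!("{:?}", predictions); // expected to be roughly [0.0, 1.0]
}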

File tree

2 files changed (+170 -17 lines)


delta/src/classical_ml/algorithms.rs

+167 -16
@@ -32,8 +32,9 @@ use std::fmt::Debug;
 use std::{collections::HashSet, ops::SubAssign};
 
 use libm::log2;
-use ndarray::{Array1, Array2, ScalarOperand};
+use ndarray::{Array1, Array2, Axis, ScalarOperand};
 use num_traits::{Float, FromPrimitive};
+use rand::Rng;
 
 use super::{Algorithm, batch_gradient_descent, logistic_gradient_descent, losses::Loss};
 
@@ -268,6 +269,115 @@ where
     }
 }
 
+/// A struct for performing Random Forest classification.
+///
+/// This implementation builds an ensemble of decision trees (each built via bootstrap sampling)
+/// and uses majority voting for prediction.
+pub struct RandomForest<T, L>
+where
+    T: Float,
+    L: Loss<T> + Clone, // clone required here to pass loss_function to each decision tree
+{
+    trees: Vec<TreeNode<T, usize>>,
+    n_trees: usize,
+    sample_size: usize,
+    max_depth: usize,
+    min_loss: f64,
+    loss_function: L,
+}
+
+impl<T, L> RandomForest<T, L>
+where
+    T: Float + FromPrimitive,
+    L: Loss<T> + Clone,
+{
+    fn predict_tree(&self, row: &Array1<T>, node: &TreeNode<T, usize>) -> usize {
+        match node {
+            TreeNode::Leaf { prediction, .. } => *prediction,
+            TreeNode::Internal { feature, threshold, left, right } => {
+                let f = feature.expect("Internal node must have a feature");
+                let th = threshold.expect("Internal node must have a threshold");
+                if row[f] < th {
+                    self.predict_tree(row, left.as_ref().expect("Left node missing"))
+                } else {
+                    self.predict_tree(row, right.as_ref().expect("Right node missing"))
+                }
+            }
+        }
+    }
+
+    fn majority_vote(&self, votes: &[usize]) -> usize {
+        let mut counts = HashMap::new();
+        for &v in votes {
+            *counts.entry(v).or_insert(0) += 1;
+        }
+        counts.into_iter().max_by_key(|&(_, count)| count).map(|(pred, _)| pred).unwrap()
+    }
+}
+
+impl<T, L> Algorithm<T, L> for RandomForest<T, L>
+where
+    T: Debug + Float + FromPrimitive + SubAssign,
+    L: Loss<T> + Clone,
+{
+    /// Creates a new RandomForest instance.
+    fn new(loss_function: L) -> Self {
+        RandomForest {
+            trees: Vec::new(),
+            n_trees: 10,
+            sample_size: 0,
+            max_depth: 10,
+            min_loss: 1e-6,
+            loss_function,
+        }
+    }
+
+    /// Fits the RandomForest by training each decision tree on a bootstrap sample.
+    /// If `sample_size` is 0 then we use all available samples.
+    fn fit(&mut self, x: &Array2<T>, y: &Array1<T>, _learning_rate: T, _epochs: usize) {
+        if self.sample_size == 0 {
+            self.sample_size = x.shape()[0];
+        }
+
+        self.trees.clear();
+        let n_samples = x.shape()[0];
+        let mut rng = rand::thread_rng();
+
+        for _ in 0..self.n_trees {
+            let indices: Vec<usize> =
+                (0..self.sample_size).map(|_| rng.gen_range(0..n_samples)).collect();
+
+            let x_bootstrap = x.select(Axis(0), &indices);
+            let y_bootstrap = y
+                .select(Axis(0), &indices)
+                .mapv(|val| val.to_usize().expect("Failed to convert label to usize"));
+
+            let mut dtc = DecisionTreeClassifier::new(
+                self.max_depth,
+                self.min_loss,
+                self.loss_function.clone(),
+            );
+
+            dtc.fit(&x_bootstrap, &y_bootstrap.mapv(|v| T::from_usize(v).unwrap()), T::zero(), 1);
+
+            if let Some(root) = dtc.root {
+                self.trees.push(root);
+            }
+        }
+    }
+
+    fn predict(&self, x: &Array2<T>) -> Array1<T> {
+        let mut predictions = Vec::with_capacity(x.shape()[0]);
+        for row in x.outer_iter() {
+            let votes: Vec<usize> =
+                self.trees.iter().map(|tree| self.predict_tree(&row.to_owned(), tree)).collect();
+            let majority = self.majority_vote(&votes);
+            predictions.push(T::from_usize(majority).unwrap());
+        }
+        Array1::from(predictions)
+    }
+}
+
 /// Represents a node in a decision tree, which can be either an `Internal` node or a `Leaf` node at any given moment.
 ///
 /// This enum is generic over two type parameters:
@@ -320,7 +430,7 @@ where
     /// # Arguments
     /// - `max_depth`: The maximum depth of the tree.
    /// - `loss_function`: The loss function to use.
-    ///S
+    ///
     /// # Returns
     /// A new instance of `DecisionTree`.
    pub fn new(max_depth: usize, min_loss: f64, loss_function: L) -> Self {
@@ -336,18 +446,17 @@ where
 
     /// Recursively splits the data based on the best feature and threshold.
     fn build_tree(&mut self, node: &mut TreeNode<T, usize>, indices: Array1<usize>, depth: usize) {
-        // println!("TOP: depth: {}, node {:?}", depth, node);
-        // println!("indices len : {}", indices.len());
+        if indices.is_empty() {
+            *node = TreeNode::Leaf { prediction: 0, indices };
+            return;
+        }
+
         if depth >= self.max_depth || indices.shape()[0] <= 1 {
             let prediction = self.calculate_leaf_prediction(&indices).unwrap();
-
-            // Update this node to Leaf Node
             *node = TreeNode::Leaf { prediction, indices };
-            // println!("BOTTOM: depth: {}, node {:?}", depth, node);
             return;
         }
 
-        // Check If Pure If yes then assign this node as leaf
         let data_y_ref = self.data_y.as_ref().unwrap();
         let classes: HashSet<_> = indices.iter().map(|&idx| data_y_ref[idx]).collect();
         let mut class_counts: HashMap<usize, usize> =
@@ -360,14 +469,10 @@ where
 
         if loss <= self.min_loss {
             let prediction = self.calculate_leaf_prediction(&indices).unwrap();
-
-            // Update this node to Leaf Node
             *node = TreeNode::Leaf { prediction, indices };
-            // println!("BOTTOM: depth: {}, node {:?}", depth, node);
             return;
         }
 
-        // Main Decision Tree Algorithm
         let (best_feature, best_threshold) = self.find_best_split(&indices);
         let (index_left, index_right) = self.split_data(indices, best_feature, best_threshold);
 
@@ -386,8 +491,6 @@ where
             left: Some(left_node),
             right: Some(right_node),
         };
-
-        // println!("BOTTOM: depth: {}, node {:?}", depth, node);
     }
 
     fn calculate_leaf_prediction(&self, indices: &Array1<usize>) -> Option<usize> {
@@ -472,7 +575,7 @@ where
     fn calculate_entropy(class_counts: &HashMap<usize, usize>) -> f64 {
         let subset_size = class_counts.values().sum::<usize>() as f64;
         let entropy = class_counts
-            .into_iter()
+            .iter()
             .map(|(_, &count)| {
                 if count == 0_usize {
                     0 as f64
@@ -566,7 +669,7 @@ mod tests {
 
     use crate::classical_ml::{
         Algorithm,
-        algorithms::{LinearRegression, LogisticRegression},
+        algorithms::{LinearRegression, LogisticRegression, RandomForest},
         losses::{CrossEntropy, MSE},
     };
 
@@ -692,4 +795,52 @@ mod tests {
         let accuracy = correct_predictions as f64 / y_test.len() as f64;
         println!("Test accuracy: {:.2}%", accuracy * 100.0);
     }
+
+    #[test]
+    fn test_random_forest_fit_and_predict() {
+        let (train, test) = linfa_datasets::iris().split_with_ratio(0.8);
+
+        // Convert train data to ndarray format
+        let x_train = Array2::from_shape_vec(
+            (train.records().nrows(), train.records().ncols()),
+            train.records().to_owned().into_raw_vec(),
+        )
+        .unwrap();
+
+        let y_train = Array1::from_shape_vec(
+            train.targets().len(),
+            train.targets().iter().map(|&x| x as f64).collect(),
+        )
+        .unwrap();
+
+        // Convert test data to ndarray format
+        let x_test = Array2::from_shape_vec(
+            (test.records().nrows(), test.records().ncols()),
+            test.records().to_owned().into_raw_vec(),
+        )
+        .unwrap();
+
+        let y_test = Array1::from_shape_vec(
+            test.targets().len(),
+            test.targets().iter().map(|&x| x as f64).collect(),
+        )
+        .unwrap();
+
+        let mut model = RandomForest::new(CrossEntropy);
+
+        println!("Fitting the model...");
+        model.fit(&x_train, &y_train, 0.1, 100);
+        println!("Model fitted.");
+
+        let predictions = model.predict(&x_test);
+
+        // Calculate and print the accuracy
+        let correct_predictions = predictions
+            .iter()
+            .zip(y_test.iter())
+            .filter(|(&pred, &actual)| (pred - actual).abs() < 1e-6)
+            .count();
+        let accuracy = correct_predictions as f64 / y_test.len() as f64;
+        println!("Test accuracy: {:.2}%", accuracy * 100.0);
+    }
 }
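
For readers skimming the diff above: the new RandomForest is plain bagging. fit draws a bootstrap sample of row indices (sampling with replacement) for each of the n_trees decision trees, and predict takes a majority vote over the per-tree predictions. The following standalone sketch of those two building blocks uses only `rand` and the standard library; it mirrors the index sampling in `fit` and the private `majority_vote` helper, but it is not part of the commit and the free-function names are illustrative.

use rand::Rng;
use std::collections::HashMap;

/// Draw `sample_size` row indices uniformly, with replacement, from `0..n_samples`.
fn bootstrap_indices(n_samples: usize, sample_size: usize) -> Vec<usize> {
    let mut rng = rand::thread_rng();
    (0..sample_size).map(|_| rng.gen_range(0..n_samples)).collect()
}

/// Return the label that occurs most often among the per-tree votes.
fn majority_vote(votes: &[usize]) -> Option<usize> {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    counts.into_iter().max_by_key(|&(_, count)| count).map(|(label, _)| label)
}

fn main() {
    // A bootstrap sample has the requested length, but usually contains duplicate rows.
    let indices = bootstrap_indices(150, 150);
    assert_eq!(indices.len(), 150);
    assert!(indices.iter().all(|&i| i < 150));

    // Three trees vote 1, one tree votes 0: the ensemble predicts 1.
    assert_eq!(majority_vote(&[1, 0, 1, 1]), Some(1));
}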

delta/src/classical_ml/losses.rs

+3 -1
@@ -33,15 +33,17 @@ use ndarray::{Array1, ScalarOperand};
 use num_traits::Float;
 
 /// A struct representing the Mean Squared Error (MSE) loss function.
+#[derive(Clone)]
 pub struct MSE;
 
 /// A struct representing the Cross-Entropy loss function.
+#[derive(Clone)]
 pub struct CrossEntropy;
 
 /// A struct representing the Entropy Loss for Decision Treee
+#[derive(Clone)]
 pub struct Entropy;
 
-
 /// A trait for loss functions, which calculates the error between predictions and actual values.
 pub trait Loss<T> {
     /// Calculates the loss value given the predicted values and the actual values.
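
The only change in this file is deriving Clone for the three loss structs. That is what makes the `L: Loss<T> + Clone` bound on RandomForest satisfiable: the forest stores a single `loss_function`, but hands each of its decision trees its own copy via `self.loss_function.clone()`. A minimal sketch of that pattern follows; the names `Forest` and `build_trees` are illustrative, not the crate's API.

// Unit structs like the crate's MSE, CrossEntropy and Entropy carry no data,
// so deriving Clone costs nothing.
#[derive(Clone)]
struct CrossEntropy;

struct Forest<L: Clone> {
    loss_function: L,
}

impl<L: Clone> Forest<L> {
    /// Each "tree" receives its own clone of the shared loss function.
    fn build_trees(&self, n_trees: usize) -> Vec<L> {
        (0..n_trees).map(|_| self.loss_function.clone()).collect()
    }
}

fn main() {
    let forest = Forest { loss_function: CrossEntropy };
    assert_eq!(forest.build_trees(10).len(), 10);
}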
