Skip to content

Commit 76ed2bf

Browse files
authored
docs(classical-ml): added random forest example (#242)
1 parent 590843f commit 76ed2bf

File tree

5 files changed

+74
-5
lines changed

5 files changed

+74
-5
lines changed

.mailmap

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
#
77

88
Marcus Cvjeticanin <[email protected]>
9-
Chase Willden <[email protected]>
9+
Chase Willden <[email protected]>
10+
Joseph "Jojo" Sutton <[email protected]>

Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ members = [
77
"examples/deep_learning/mnist",
88
"examples/deep_learning/imagenet_v2",
99
"examples/classical_ml/linear_regression",
10-
"examples/classical_ml/logistic_regression",
10+
"examples/classical_ml/logistic_regression", "examples/classical_ml/random_forest",
1111
]
1212
resolver = "2"
1313

@@ -33,4 +33,4 @@ codegen-units = 1
3333

3434
[profile.bench]
3535
opt-level = 3
36-
debug = false
36+
debug = false

delta/src/classical_ml/algorithms.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ where
313313
}
314314
counts.into_iter().max_by_key(|&(_, count)| count).map(|(pred, _)| pred).unwrap()
315315
}
316+
317+
pub fn calculate_loss(&self, predictions: &Array1<T>, actuals: &Array1<T>) -> T {
318+
self.loss_function.calculate(predictions, actuals)
319+
}
316320
}
317321

318322
impl<T, L> Algorithm<T, L> for RandomForest<T, L>
@@ -828,9 +832,7 @@ mod tests {
828832

829833
let mut model = RandomForest::new(CrossEntropy);
830834

831-
println!("Fitting the model...");
832835
model.fit(&x_train, &y_train, 0.1, 100);
833-
println!("Model fitted.");
834836

835837
let predictions = model.predict(&x_test);
836838

@@ -840,7 +842,13 @@ mod tests {
840842
.zip(y_test.iter())
841843
.filter(|(&pred, &actual)| (pred - actual).abs() < 1e-6)
842844
.count();
845+
843846
let accuracy = correct_predictions as f64 / y_test.len() as f64;
847+
let loss = model.calculate_loss(&predictions, &y_test);
848+
844849
println!("Test accuracy: {:.2}%", accuracy * 100.0);
850+
851+
assert!(accuracy > 0.0, "Accuracy should be positive, got: {}", accuracy);
852+
assert!(loss < 0.0, "Loss should be less than 0, got: {}", loss);
845853
}
846854
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "random_forest"
3+
version = "0.1.0"
4+
edition = "2024"
5+
publish = false
6+
7+
[dependencies]
8+
deltaml = { path = "../../../delta", features = ["classical_ml"] }
9+
tokio = { workspace = true, features = ["full"] }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use deltaml::{
2+
classical_ml::{Algorithm, algorithms::RandomForest, losses::MSE},
3+
ndarray::{Array1, Array2},
4+
};
5+
6+
#[tokio::main]
7+
async fn main() {
8+
// create feature data
9+
let x_data = Array2::from_shape_vec(
10+
(4, 4),
11+
vec![
12+
25.0, // 25 years old
13+
30.0, // 30 years old
14+
22.0, // 22 years old
15+
40.0, // 40 years old
16+
0.0, // male
17+
1.0, // female
18+
0.0, // male
19+
1.0, // female
20+
1.0, // engineer
21+
2.0, // doctor
22+
3.0, // student
23+
4.0, // manager
24+
60000.0, // 60000.0 salary
25+
80000.0, // 80000.0 salary
26+
20000.0, // 20000.0 salary
27+
100000.0, // 100000.0 salary
28+
],
29+
)
30+
.unwrap();
31+
32+
let y_data = Array1::from_vec(vec![1.0, 1.0, 0.0, 1.0]); // high, high, low, high
33+
34+
// Instantiate the model
35+
let mut model = RandomForest::new(MSE);
36+
37+
// Train the model
38+
let learning_rate = 0.01;
39+
let epochs = 1000;
40+
model.fit(&x_data, &y_data, learning_rate, epochs);
41+
42+
// Make predictions with the trained model
43+
let new_data = Array2::from_shape_vec((4, 1), vec![28.0, 1.0, 1.0, 70000.0]).unwrap();
44+
let predictions = model.predict(&new_data);
45+
46+
println!("Predictions for new data: {:?}", predictions);
47+
48+
// Calculate accuracy or loss for the test data for demonstration
49+
let test_loss = model.calculate_loss(&model.predict(&x_data), &y_data);
50+
println!("Test Loss after training: {:.6}", test_loss);
51+
}

0 commit comments

Comments
 (0)