@@ -32,8 +32,9 @@ use std::fmt::Debug;
 use std::{collections::HashSet, ops::SubAssign};
 
 use libm::log2;
-use ndarray::{Array1, Array2, ScalarOperand};
+use ndarray::{Array1, Array2, Axis, ScalarOperand};
 use num_traits::{Float, FromPrimitive};
+use rand::Rng;
 
 use super::{Algorithm, batch_gradient_descent, logistic_gradient_descent, losses::Loss};
 
@@ -268,6 +269,115 @@ where
     }
 }
 
+/// A struct for performing Random Forest classification.
+///
+/// This implementation builds an ensemble of decision trees (each built via bootstrap sampling)
+/// and uses majority voting for prediction.
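+///
+/// # Example
+///
+/// A minimal usage sketch (mirrors the integration test below; not compiled as a doctest):
+///
+/// ```ignore
+/// let mut model = RandomForest::new(CrossEntropy);
+/// model.fit(&x_train, &y_train, 0.1, 100); // learning rate and epochs are unused
+/// let predictions = model.predict(&x_test);
+/// ```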
+pub struct RandomForest<T, L>
+where
+    T: Float,
+    L: Loss<T> + Clone, // Clone is required to pass loss_function to each decision tree
+{
+    trees: Vec<TreeNode<T, usize>>,
+    n_trees: usize,
+    sample_size: usize,
+    max_depth: usize,
+    min_loss: f64,
+    loss_function: L,
+}
+
+impl<T, L> RandomForest<T, L>
+where
+    T: Float + FromPrimitive,
+    L: Loss<T> + Clone,
+{
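+    /// Walks a single tree from the root, following the learned feature/threshold
+    /// splits until a leaf is reached, and returns that leaf's class.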
+    fn predict_tree(&self, row: &Array1<T>, node: &TreeNode<T, usize>) -> usize {
+        match node {
+            TreeNode::Leaf { prediction, .. } => *prediction,
+            TreeNode::Internal { feature, threshold, left, right } => {
+                let f = feature.expect("Internal node must have a feature");
+                let th = threshold.expect("Internal node must have a threshold");
+                if row[f] < th {
+                    self.predict_tree(row, left.as_ref().expect("Left node missing"))
+                } else {
+                    self.predict_tree(row, right.as_ref().expect("Right node missing"))
+                }
+            }
+        }
+    }
+
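+    /// Returns the most frequent label among the trees' votes; ties are broken
+    /// arbitrarily, since `HashMap` iteration order is unspecified.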
+    fn majority_vote(&self, votes: &[usize]) -> usize {
+        let mut counts = HashMap::new();
+        for &v in votes {
+            *counts.entry(v).or_insert(0) += 1;
+        }
+        counts.into_iter().max_by_key(|&(_, count)| count).map(|(pred, _)| pred).unwrap()
+    }
+}
+
+impl<T, L> Algorithm<T, L> for RandomForest<T, L>
+where
+    T: Debug + Float + FromPrimitive + SubAssign,
+    L: Loss<T> + Clone,
+{
+    /// Creates a new RandomForest instance.
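+    /// Defaults: 10 trees, max depth 10, min loss 1e-6, and `sample_size == 0`,
+    /// which makes `fit` bootstrap with as many rows as the training set.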
+    fn new(loss_function: L) -> Self {
+        RandomForest {
+            trees: Vec::new(),
+            n_trees: 10,
+            sample_size: 0,
+            max_depth: 10,
+            min_loss: 1e-6,
+            loss_function,
+        }
+    }
+
+    /// Fits the RandomForest by training each decision tree on a bootstrap sample.
+    /// If `sample_size` is 0 then we use all available samples.
+    fn fit(&mut self, x: &Array2<T>, y: &Array1<T>, _learning_rate: T, _epochs: usize) {
+        if self.sample_size == 0 {
+            self.sample_size = x.shape()[0];
+        }
+
+        self.trees.clear();
+        let n_samples = x.shape()[0];
+        let mut rng = rand::thread_rng();
+
+        for _ in 0..self.n_trees {
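+            // Bootstrap: draw `sample_size` row indices uniformly, with replacement.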
+            let indices: Vec<usize> =
+                (0..self.sample_size).map(|_| rng.gen_range(0..n_samples)).collect();
+
+            let x_bootstrap = x.select(Axis(0), &indices);
+            let y_bootstrap = y
+                .select(Axis(0), &indices)
+                .mapv(|val| val.to_usize().expect("Failed to convert label to usize"));
+
+            let mut dtc = DecisionTreeClassifier::new(
+                self.max_depth,
+                self.min_loss,
+                self.loss_function.clone(),
+            );
+
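+            // Labels round-trip through usize (for bootstrap selection) and back to T,
+            // since DecisionTreeClassifier::fit expects targets of type T; the learning
+            // rate and epoch arguments are placeholders here (T::zero(), 1).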
+            dtc.fit(&x_bootstrap, &y_bootstrap.mapv(|v| T::from_usize(v).unwrap()), T::zero(), 1);
+
+            if let Some(root) = dtc.root {
+                self.trees.push(root);
+            }
+        }
+    }
+
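+    /// Collects one vote per tree for each row and returns the majority class,
+    /// converted back into T.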
+    fn predict(&self, x: &Array2<T>) -> Array1<T> {
+        let mut predictions = Vec::with_capacity(x.shape()[0]);
+        for row in x.outer_iter() {
+            let row = row.to_owned();
+            let votes: Vec<usize> =
+                self.trees.iter().map(|tree| self.predict_tree(&row, tree)).collect();
+            let majority = self.majority_vote(&votes);
+            predictions.push(T::from_usize(majority).unwrap());
+        }
+        Array1::from(predictions)
+    }
+}
+
 /// Represents a node in a decision tree, which can be either an `Internal` node or a `Leaf` node at any given moment.
 ///
 /// This enum is generic over two type parameters:
@@ -320,7 +430,7 @@ where
     /// # Arguments
     /// - `max_depth`: The maximum depth of the tree.
     /// - `loss_function`: The loss function to use.
-    ///S
+    ///
     /// # Returns
     /// A new instance of `DecisionTree`.
     pub fn new(max_depth: usize, min_loss: f64, loss_function: L) -> Self {
@@ -336,18 +446,17 @@ where
 
     /// Recursively splits the data based on the best feature and threshold.
     fn build_tree(&mut self, node: &mut TreeNode<T, usize>, indices: Array1<usize>, depth: usize) {
-        // println!("TOP: depth: {}, node {:?}", depth, node);
-        // println!("indices len : {}", indices.len());
+        if indices.is_empty() {
+            *node = TreeNode::Leaf { prediction: 0, indices };
+            return;
+        }
+
         if depth >= self.max_depth || indices.shape()[0] <= 1 {
             let prediction = self.calculate_leaf_prediction(&indices).unwrap();
-
-            // Update this node to Leaf Node
             *node = TreeNode::Leaf { prediction, indices };
-            // println!("BOTTOM: depth: {}, node {:?}", depth, node);
             return;
         }
 
-        // Check If Pure If yes then assign this node as leaf
         let data_y_ref = self.data_y.as_ref().unwrap();
         let classes: HashSet<_> = indices.iter().map(|&idx| data_y_ref[idx]).collect();
         let mut class_counts: HashMap<usize, usize> =
@@ -360,14 +469,10 @@ where
 
         if loss <= self.min_loss {
             let prediction = self.calculate_leaf_prediction(&indices).unwrap();
-
-            // Update this node to Leaf Node
             *node = TreeNode::Leaf { prediction, indices };
-            // println!("BOTTOM: depth: {}, node {:?}", depth, node);
             return;
         }
 
-        // Main Decision Tree Algorithm
         let (best_feature, best_threshold) = self.find_best_split(&indices);
         let (index_left, index_right) = self.split_data(indices, best_feature, best_threshold);
 
@@ -386,8 +491,6 @@ where
             left: Some(left_node),
             right: Some(right_node),
         };
-
-        // println!("BOTTOM: depth: {}, node {:?}", depth, node);
     }
 
     fn calculate_leaf_prediction(&self, indices: &Array1<usize>) -> Option<usize> {
@@ -472,7 +575,7 @@ where
     fn calculate_entropy(class_counts: &HashMap<usize, usize>) -> f64 {
         let subset_size = class_counts.values().sum::<usize>() as f64;
         let entropy = class_counts
-            .into_iter()
+            .iter()
             .map(|(_, &count)| {
                 if count == 0_usize {
                     0 as f64
@@ -566,7 +669,7 @@ mod tests {
 
     use crate::classical_ml::{
         Algorithm,
-        algorithms::{LinearRegression, LogisticRegression},
+        algorithms::{LinearRegression, LogisticRegression, RandomForest},
         losses::{CrossEntropy, MSE},
     };
 
@@ -692,4 +795,52 @@ mod tests {
         let accuracy = correct_predictions as f64 / y_test.len() as f64;
         println!("Test accuracy: {:.2}%", accuracy * 100.0);
     }
+
+    #[test]
+    fn test_random_forest_fit_and_predict() {
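+        // Train and evaluate on the iris dataset from linfa_datasets, split 80/20.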
+        let (train, test) = linfa_datasets::iris().split_with_ratio(0.8);
+
+        // Convert train data to ndarray format
+        let x_train = Array2::from_shape_vec(
+            (train.records().nrows(), train.records().ncols()),
+            train.records().to_owned().into_raw_vec(),
+        )
+        .unwrap();
+
+        let y_train = Array1::from_shape_vec(
+            train.targets().len(),
+            train.targets().iter().map(|&x| x as f64).collect(),
+        )
+        .unwrap();
+
+        // Convert test data to ndarray format
+        let x_test = Array2::from_shape_vec(
+            (test.records().nrows(), test.records().ncols()),
+            test.records().to_owned().into_raw_vec(),
+        )
+        .unwrap();
+
+        let y_test = Array1::from_shape_vec(
+            test.targets().len(),
+            test.targets().iter().map(|&x| x as f64).collect(),
+        )
+        .unwrap();
+
+        let mut model = RandomForest::new(CrossEntropy);
+
+        println!("Fitting the model...");
+        model.fit(&x_train, &y_train, 0.1, 100);
+        println!("Model fitted.");
+
+        let predictions = model.predict(&x_test);
+
+        // Calculate and print the accuracy
+        let correct_predictions = predictions
+            .iter()
+            .zip(y_test.iter())
+            .filter(|(&pred, &actual)| (pred - actual).abs() < 1e-6)
+            .count();
+        let accuracy = correct_predictions as f64 / y_test.len() as f64;
+        println!("Test accuracy: {:.2}%", accuracy * 100.0);
+    }
 }