Skip to content

Commit 35b425b

Browse files
cargo fmt
1 parent 684fc49 commit 35b425b

File tree

4 files changed

+10
-20
lines changed

4 files changed

+10
-20
lines changed

algorithms/linfa-trees/examples/iris_random_forest.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
1010
let mut rng = thread_rng();
1111

1212
// 1. Load, shuffle, and split the Iris dataset (80% train, 20% valid)
13-
let (train, valid) = iris()
14-
.shuffle(&mut rng)
15-
.split_with_ratio(0.8);
13+
let (train, valid) = iris().shuffle(&mut rng).split_with_ratio(0.8);
1614

1715
// 2. Single‐tree baseline
1816
let dt_model = DecisionTree::params()
19-
.max_depth(None) // no depth limit
17+
.max_depth(None) // no depth limit
2018
.fit(&train)?;
2119
let dt_preds = dt_model.predict(valid.records.clone());
2220
let dt_cm = dt_preds.confusion_matrix(&valid)?;
@@ -26,7 +24,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
2624
let rf_model = RandomForestParams::new(50)
2725
.max_depth(Some(5))
2826
.feature_subsample(0.7)
29-
.seed(42) // fix RNG seed for reproducibility
27+
.seed(42) // fix RNG seed for reproducibility
3028
.fit(&train)?;
3129
let rf_preds = rf_model.predict(valid.records.clone());
3230
let rf_cm = rf_preds.confusion_matrix(&valid)?;

algorithms/linfa-trees/src/decision_trees/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ pub use algorithm::*;
77
pub use hyperparams::*;
88
pub use iter::*;
99
pub use tikz::*;
10-
pub mod random_forest;
10+
pub mod random_forest;

algorithms/linfa-trees/src/decision_trees/random_forest.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,10 @@ fn bootstrap<F: Float>(
9898
Dataset::new(rec, tgt)
9999
}
100100

101-
impl<F: Float + Send + Sync> Fit<Array2<F>, Array1<usize>, Error>
102-
for RandomForestValidParams<F>
103-
{
101+
impl<F: Float + Send + Sync> Fit<Array2<F>, Array1<usize>, Error> for RandomForestValidParams<F> {
104102
type Object = RandomForestClassifier<F>;
105103

106-
fn fit(
107-
&self,
108-
dataset: &DatasetBase<Array2<F>, Array1<usize>>,
109-
) -> Result<Self::Object, Error> {
104+
fn fit(&self, dataset: &DatasetBase<Array2<F>, Array1<usize>>) -> Result<Self::Object, Error> {
110105
let mut rng = StdRng::seed_from_u64(self.seed);
111106
let mut trees = Vec::with_capacity(self.n_trees);
112107
let mut feats_list = Vec::with_capacity(self.n_trees);
@@ -141,9 +136,7 @@ impl<F: Float + Send + Sync> Fit<Array2<F>, Array1<usize>, Error>
141136
}
142137
}
143138

144-
impl<F: Float> Predict<Array2<F>, Array1<usize>>
145-
for RandomForestClassifier<F>
146-
{
139+
impl<F: Float> Predict<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
147140
fn predict(&self, x: Array2<F>) -> Array1<usize> {
148141
let n = x.nrows();
149142
// adjust 100 to the expected number of classes if known
@@ -162,7 +155,8 @@ impl<F: Float> Predict<Array2<F>, Array1<usize>>
162155
Array1::from(
163156
(0..n)
164157
.map(|i| {
165-
votes.iter()
158+
votes
159+
.iter()
166160
.enumerate()
167161
.max_by_key(|(_, v)| v[i])
168162
.map(|(lbl, _)| lbl)

algorithms/linfa-trees/tests/random_forest.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ use rand::SeedableRng;
1010
fn iris_random_forest_high_accuracy() {
1111
// reproducible split
1212
let mut rng = StdRng::seed_from_u64(42);
13-
let (train, valid) = iris()
14-
.shuffle(&mut rng)
15-
.split_with_ratio(0.8);
13+
let (train, valid) = iris().shuffle(&mut rng).split_with_ratio(0.8);
1614

1715
let model = RandomForestParams::new(100)
1816
.max_depth(Some(10))

0 commit comments

Comments
 (0)