Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fb7a1fa
Add serialization for LogisticRegression
levkk Oct 6, 2022
0930322
Merge pull request #1 from postgresml/levkk-add-ser-for-logistic
levkk Oct 6, 2022
378c3b4
Serialization for multi-class
levkk Oct 6, 2022
ecd48d0
Merge pull request #2 from postgresml/levkk-fix-missing-ser
levkk Oct 6, 2022
fe3ae53
Float type restriction with handwritten bounds
gkobeaga Oct 6, 2022
9392fe6
Merge pull request #3 from gkobeaga/serde-logistic
levkk Oct 6, 2022
3fa43a8
Merge branch 'rust-ml:master' into master
levkk Oct 16, 2022
c44940b
Confusion matrix should use labels from predictions and ground truth
levkk Oct 17, 2022
4057c2d
Merge pull request #4 from postgresml/levkk-f1-division-by-zero
levkk Oct 17, 2022
d91de55
Clippy fixes
levkk Oct 17, 2022
3356d42
Merge pull request #5 from postgresml/levkk-fix-f1-metric
levkk Oct 17, 2022
4ac3ec8
This is the correct test
levkk Oct 17, 2022
3dd71b1
Merge pull request #6 from postgresml/levkk-fix-test-not-sure
levkk Oct 17, 2022
1e8ac38
Merge branch 'rust-ml:master' into master
montanalow Jun 6, 2023
ef0a23a
Merge branch 'rust-ml:master' into master
montanalow Jul 18, 2023
01c8224
Merge branch 'rust-ml:master' into master
levkk Nov 2, 2023
4004fec
Merge branch 'rust-ml:master' into master
montanalow Jan 11, 2025
7dee254
fix warnings
montanalow Jan 11, 2025
e9904a8
remove lifetimes
montanalow Jan 11, 2025
4f8ccef
clippy lints
montanalow Jan 11, 2025
5ec7b2f
fix ownership
montanalow Jan 11, 2025
97d52e7
fix ownership
montanalow Jan 11, 2025
d4a5744
cleanup lints
montanalow Jan 11, 2025
9d615fc
Merge pull request #7 from postgresml/montana/a
montanalow Jan 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ approx = "0.4"

ndarray = { version = "0.15", features = ["approx"] }
ndarray-linalg = { version = "0.16", optional = true }
sprs = { version = "0.11", default-features = false }
sprs = { version = "=0.11.1", default-features = false }

thiserror = "1.0"

Expand Down
2 changes: 1 addition & 1 deletion algorithms/linfa-ftrl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ version = "1.0"
features = ["derive"]

[dependencies]
ndarray = { version = "0.15.4", features = ["serde"] }
ndarray = { version = "0.15", features = ["serde"] }
ndarray-rand = "0.14.0"
argmin = { version = "0.9.0", default-features = false }
argmin-math = { version = "0.3", features = ["ndarray_v0_15-nolinalg"] }
Expand Down
2 changes: 1 addition & 1 deletion algorithms/linfa-kernel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ features = ["std", "derive"]
[dependencies]
ndarray = "0.15"
num-traits = "0.2"
sprs = { version="0.11", default-features = false }
sprs = { version="=0.11.1", default-features = false }

linfa = { version = "0.7.0", path = "../.." }
linfa-nn = { version = "0.7.0", path = "../linfa-nn" }
2 changes: 1 addition & 1 deletion algorithms/linfa-preprocessing/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ndarray-rand = { version = "0.14" }
unicode-normalization = "0.1.8"
regex = "1.4.5"
encoding = "0.2"
sprs = { version = "0.11.0", default-features = false }
sprs = { version = "=0.11.1", default-features = false }

serde_regex = { version = "1.1", optional = true }

Expand Down
3 changes: 1 addition & 2 deletions src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<F: Float> PearsonCorrelation<F> {
///
/// * `dataset`: Data for the correlation analysis
/// * `num_iter`: optionally number of iterations of the p-value test, if none then no p-value
/// are calculate
/// are calculated
///
/// # Example
///
Expand All @@ -153,7 +153,6 @@ impl<F: Float> PearsonCorrelation<F> {
/// lamotrigine +0.47 (0.14)
/// blood sugar level
/// ```

pub fn from_dataset<D: Data<Elem = F>, T>(
dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
num_iter: Option<usize>,
Expand Down
20 changes: 10 additions & 10 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ where
/// println!("{} => {}", x, y);
/// }
/// ```
pub fn sample_iter(&'a self) -> Iter<'a, '_, F, T::Elem, T::Ix> {
pub fn sample_iter(&'a self) -> Iter<'a, 'a, F, T::Elem, T::Ix> {
Iter::new(self.records.view(), self.targets.as_targets())
}
}
Expand All @@ -232,7 +232,7 @@ where
///
/// This iterator produces dataset views with only a single feature, while the set of targets remain
/// complete. It can be useful to compare each feature individual to all targets.
pub fn feature_iter(&'a self) -> DatasetIter<'a, '_, ArrayBase<D, Ix2>, T> {
pub fn feature_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
DatasetIter::new(self, true)
}

Expand All @@ -241,7 +241,7 @@ where
/// This functions creates an iterator which produces dataset views complete records, but only
/// a single target each. Useful to train multiple single target models for a multi-target
/// dataset.
pub fn target_iter(&'a self) -> DatasetIter<'a, '_, ArrayBase<D, Ix2>, T> {
pub fn target_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
DatasetIter::new(self, false)
}
}
Expand Down Expand Up @@ -318,7 +318,7 @@ impl<L: Label, T: Labels<Elem = L>, R: Records> Labels for DatasetBase<R, T> {
}

#[allow(clippy::type_complexity)]
impl<'a, 'b: 'a, F, L: Label, T, D> DatasetBase<ArrayBase<D, Ix2>, T>
impl<F, L: Label, T, D> DatasetBase<ArrayBase<D, Ix2>, T>
where
D: Data<Elem = F>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
Expand Down Expand Up @@ -680,8 +680,8 @@ where
/// - `k`: the number of folds to apply to the dataset
/// - `params`: the desired parameters for the fittable algorithm at hand
/// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model`
/// that will be used to produce the trained model for each fold. The training data given in input
/// won't outlive the closure.
/// that will be used to produce the trained model for each fold. The training data given in
/// input won't outlive the closure.
///
/// ## Returns
///
Expand Down Expand Up @@ -732,7 +732,7 @@ where
&'a mut self,
k: usize,
fit_closure: C,
) -> impl Iterator<Item = (O, DatasetBase<ArrayView2<F>, ArrayView<E, I>>)> {
) -> impl Iterator<Item = (O, DatasetBase<ArrayView2<'a, F>, ArrayView<'a, E, I>>)> {
assert!(k > 0);
assert!(k <= self.nsamples());
let samples_count = self.nsamples();
Expand Down Expand Up @@ -794,9 +794,9 @@ where
/// - `k`: the number of folds to apply
/// - `parameters`: a list of models to compare
/// - `eval`: closure used to evaluate the performance of each trained model. This closure is
/// called on the model output and validation targets of each fold and outputs the performance
/// score for each target. For single-target dataset the signature is `(Array1, Array1) ->
/// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`.
/// called on the model output and validation targets of each fold and outputs the performance
/// score for each target. For single-target dataset the signature is `(Array1, Array1) ->
/// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`.
///
/// ### Returns
///
Expand Down
16 changes: 14 additions & 2 deletions src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl Deref for Pr {
/// # Fields
///
/// * `records`: a two-dimensional matrix with dimensionality (nsamples, nfeatures), in case of
/// kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse
/// kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse
/// * `targets`: a two-/one-dimension matrix with dimensionality (nsamples, ntargets)
/// * `weights`: optional weights for each sample with dimensionality (nsamples)
/// * `feature_names`: optional descriptive feature names with dimensionality (nfeatures)
Expand All @@ -170,7 +170,7 @@ impl Deref for Pr {
///
/// * `R: Records`: generic over feature matrices or kernel matrices
/// * `T`: generic over any `ndarray` matrix which can be used as targets. The `AsTargets` trait
/// bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs`
/// bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs`
#[derive(Debug, Clone, PartialEq)]
pub struct DatasetBase<R, T>
where
Expand Down Expand Up @@ -324,6 +324,18 @@ pub trait Labels {
fn labels(&self) -> Vec<Self::Elem> {
self.label_set().into_iter().flatten().collect()
Copy link
Collaborator

@YuhanLiin YuhanLiin Oct 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason this method doesn't dedup the final vector. It should do something like union all HashSet together. Or we can just change the return type to HashSet, but that might be too invasive.

}

fn combined_labels(&self, other: Vec<Self::Elem>) -> Vec<Self::Elem> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to have this method take &impl Labels or &Self as input. Then you can call label_set on both self and the input and union all the hashsets before converting it into a Vec.

let mut combined = self.labels();
combined.extend(other);

combined
.iter()
.collect::<HashSet<_>>()
.into_iter()
.cloned()
.collect()
}
}

#[cfg(test)]
Expand Down
13 changes: 12 additions & 1 deletion src/metrics_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ where
return Err(Error::MismatchedShapes(targets.len(), ground_truth.len()));
}

let classes = self.labels();
let classes = self.combined_labels(ground_truth.labels());

let indices = map_prediction_to_idx(
targets.as_slice().unwrap(),
Expand Down Expand Up @@ -636,6 +636,17 @@ mod tests {
);
}

#[test]
fn test_division_by_zero_cm() {
let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]);
let predicted = Array1::from(vec![0, 0, 0, 0, 0, 0]);

let x = ground_truth.confusion_matrix(predicted).unwrap();
let f1 = x.f1_score();

assert_eq!(f1, 0.5);
}

#[test]
fn test_roc_curve() {
let predicted = ArrayView1::from(&[0.1, 0.3, 0.5, 0.7, 0.8, 0.9]).mapv(Pr::new);
Expand Down
3 changes: 1 addition & 2 deletions src/metrics_clustering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ impl<F: Float> DistanceCount<F> {
}

impl<
'a,
F: Float,
L: 'a + Label,
L: Label,
D: Data<Elem = F>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
> SilhouetteScore<F> for DatasetBase<ArrayBase<D, Ix2>, T>
Expand Down
Loading