Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/codequality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
needs: codequality
name: coverage
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false && (github.event_name == 'pull_request' || github.ref == 'refs/heads/master')
if: github.event_name == 'pull_request' || github.ref == 'refs/heads/master'

steps:
- name: Checkout sources
Expand Down
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/linear_scaling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
/// Substitutes the records of the dataset with their scaled version.
/// Panics if the shape of the records is not compatible with the shape of the dataset used for fitting.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let feature_names = x.feature_names().to_vec();
let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
.with_target_names(target_names)
}
}

Expand Down Expand Up @@ -575,7 +577,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
let original_feature_names = dataset.feature_names();
let original_feature_names = dataset.feature_names().to_vec();
let transformed = LinearScaler::standard()
.fit(&dataset)
.unwrap()
Expand Down
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/norm_scaling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
{
/// Substitutes the records of the dataset with their scaled versions with unit norm.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let feature_names = x.feature_names().to_vec();
let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
.with_target_names(target_names)
}
}

Expand Down Expand Up @@ -160,7 +162,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
let original_feature_names = dataset.feature_names();
let original_feature_names = dataset.feature_names().to_vec();
let transformed = NormScaler::l2().transform(dataset);
assert_eq!(original_feature_names, transformed.feature_names())
}
Expand Down
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/whitening.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
for FittedWhitener<F>
{
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let feature_names = x.feature_names().to_vec();
let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
.with_target_names(target_names)
}
}

Expand Down Expand Up @@ -334,7 +336,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
let original_feature_names = dataset.feature_names();
let original_feature_names = dataset.feature_names().to_vec();
let transformed = Whitener::cholesky()
.fit(&dataset)
.unwrap()
Expand Down
8 changes: 7 additions & 1 deletion algorithms/linfa-trees/src/decision_trees/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,13 @@
/// a matrix of features `x` and an array of labels `y`.
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
let x = dataset.records();
let feature_names = dataset.feature_names();
let feature_names = if dataset.feature_names().is_empty() {
(0..x.nfeatures())

Check warning on line 527 in algorithms/linfa-trees/src/decision_trees/algorithm.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/algorithm.rs#L527

Added line #L527 was not covered by tests
.map(|idx| format!("feature-{idx}"))
.collect()
} else {
dataset.feature_names().to_vec()

Check warning on line 531 in algorithms/linfa-trees/src/decision_trees/algorithm.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-trees/src/decision_trees/algorithm.rs#L531

Added line #L531 was not covered by tests
};
let all_idxs = RowMask::all(x.nrows());
let sorted_indices: Vec<_> = (0..(x.ncols()))
.map(|feature_idx| {
Expand Down
9 changes: 8 additions & 1 deletion datasets/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,11 @@ pub fn linnerud() -> Dataset<f64, f64> {
let output_array = array_from_gz_csv(&output_data[..], true, b',').unwrap();

let feature_names = vec!["Chins", "Situps", "Jumps"];
let target_names = vec!["Weight", "Waist", "Pulse"];

Dataset::new(input_array, output_array).with_feature_names(feature_names)
Dataset::new(input_array, output_array)
.with_feature_names(feature_names)
.with_target_names(target_names)
}

#[cfg(test)]
Expand Down Expand Up @@ -261,6 +264,10 @@ mod tests {
let feature_names = vec!["Chins", "Situps", "Jumps"];
assert_eq!(ds.feature_names(), feature_names);

// check for target names
let target_names = vec!["Weight", "Waist", "Pulse"];
assert_eq!(ds.target_names(), target_names);

// get the mean per target: Weight, Waist, Pulse
let mean_targets = ds.targets().mean_axis(Axis(0)).unwrap();
assert_abs_diff_eq!(mean_targets, array![178.6, 35.4, 56.1]);
Expand Down
2 changes: 1 addition & 1 deletion src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@
PearsonCorrelation {
pearson_coeffs,
p_values,
feature_names: dataset.feature_names(),
feature_names: dataset.feature_names().to_vec(),

Check warning on line 172 in src/correlation.rs

View check run for this annotation

Codecov / codecov/patch

src/correlation.rs#L172

Added line #L172 was not covered by tests
}
}

Expand Down
70 changes: 51 additions & 19 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
target_names: Vec::new(),

Check warning on line 33 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L33

Added line #L33 was not covered by tests
}
}

Expand Down Expand Up @@ -60,14 +61,8 @@
/// A feature name gives a human-readable string describing the purpose of a single feature.
/// This allows the reader to understand its purpose while analysing results, for example
/// correlation analysis or feature importance.
pub fn feature_names(&self) -> Vec<String> {
if !self.feature_names.is_empty() {
self.feature_names.clone()
} else {
(0..self.records.nfeatures())
.map(|idx| format!("feature-{idx}"))
.collect()
}
pub fn feature_names(&self) -> &[String] {
&self.feature_names

Check warning on line 65 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L64-L65

Added lines #L64 - L65 were not covered by tests
}

/// Return records of a dataset
Expand All @@ -81,13 +76,14 @@
/// Updates the records of a dataset
///
/// This function overwrites the records in a dataset. It also invalidates the weights and
/// feature names.
/// feature/target names.
pub fn with_records<T: Records>(self, records: T) -> DatasetBase<T, S> {
DatasetBase {
records,
targets: self.targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
target_names: Vec::new(),

Check warning on line 86 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L86

Added line #L86 was not covered by tests
}
}

Expand All @@ -100,6 +96,7 @@
targets,
weights: self.weights,
feature_names: self.feature_names,
target_names: self.target_names,

Check warning on line 99 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L99

Added line #L99 was not covered by tests
}
}

Expand All @@ -111,11 +108,14 @@
}

/// Updates the feature names of a dataset
///
/// **Panics** if the given names are not empty and their number does not equal the number of features
pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
let feature_names = names.into_iter().map(|x| x.into()).collect();

self.feature_names = feature_names;

assert!(
names.is_empty() || names.len() == self.nfeatures(),
"Wrong number of feature names"

Check warning on line 116 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L115-L116

Added lines #L115 - L116 were not covered by tests
);
self.feature_names = names.into_iter().map(|x| x.into()).collect();
self
}
}
Expand All @@ -131,6 +131,18 @@
}

impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
/// Updates the target names of a dataset
///
/// **Panics** if the given names are not empty and their number does not equal the number of targets
pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
assert!(
names.is_empty() || names.len() == self.ntargets(),
"Wrong number of target names"

Check warning on line 140 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L139-L140

Added lines #L139 - L140 were not covered by tests
);
self.target_names = names.into_iter().map(|x| x.into()).collect();
self
}

/// Map targets with a function `f`
///
/// # Example
Expand All @@ -153,6 +165,7 @@
targets,
weights,
feature_names,
target_names,
..
} = self;

Expand All @@ -163,9 +176,17 @@
targets: targets.map(fnc),
weights,
feature_names,
target_names,
}
}

/// Returns target names
///
/// A target name gives a human-readable string describing the purpose of a single target.
pub fn target_names(&self) -> &[String] {
&self.target_names

Check warning on line 187 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L186-L187

Added lines #L186 - L187 were not covered by tests
}

/// Return the number of targets in the dataset
///
/// # Example
Expand Down Expand Up @@ -217,6 +238,7 @@
where
D: Data<Elem = F>,
T: AsTargets<Elem = L> + FromTargetArray<'a>,
T::View: AsTargets<Elem = L>,
{
/// Creates a view of a dataset
pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
Expand All @@ -226,6 +248,7 @@
DatasetBase::new(records, targets)
.with_feature_names(self.feature_names.clone())
.with_weights(self.weights.clone())
.with_target_names(self.target_names.clone())
}

/// Iterate over features
Expand Down Expand Up @@ -268,6 +291,7 @@
impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
where
T: AsTargets<Elem = L> + FromTargetArray<'a>,
T::View: AsTargets<Elem = L>,
{
/// Split dataset into two disjoint chunks
///
Expand Down Expand Up @@ -299,11 +323,13 @@
};
let dataset1 = DatasetBase::new(records_first, targets_first)
.with_weights(first_weights)
.with_feature_names(self.feature_names.clone());
.with_feature_names(self.feature_names.clone())
.with_target_names(self.target_names.clone());

let dataset2 = DatasetBase::new(records_second, targets_second)
.with_weights(second_weights)
.with_feature_names(self.feature_names.clone());
.with_feature_names(self.feature_names.clone())
.with_target_names(self.target_names.clone());

(dataset1, dataset2)
}
Expand Down Expand Up @@ -349,7 +375,8 @@
label,
DatasetBase::new(self.records().view(), targets)
.with_feature_names(self.feature_names.clone())
.with_weights(self.weights.clone()),
.with_weights(self.weights.clone())
.with_target_names(self.target_names.clone()),
)
})
.collect())
Expand Down Expand Up @@ -405,6 +432,7 @@
targets: empty_targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
target_names: Vec::new(),

Check warning on line 435 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L435

Added line #L435 was not covered by tests
}
}
}
Expand All @@ -421,6 +449,7 @@
targets: rec_tar.1,
weights: Array1::zeros(0),
feature_names: Vec::new(),
target_names: Vec::new(),

Check warning on line 452 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L452

Added line #L452 was not covered by tests
}
}
}
Expand Down Expand Up @@ -957,7 +986,8 @@
let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
let n2 = self.nsamples() - n1;

let feature_names = self.feature_names();
let feature_names = self.feature_names().to_vec();
let target_names = self.target_names().to_vec();

Check warning on line 990 in src/dataset/impl_dataset.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/impl_dataset.rs#L990

Added line #L990 was not covered by tests

// split records into two disjoint arrays
let mut array_buf = self.records.into_raw_vec();
Expand Down Expand Up @@ -990,10 +1020,12 @@
// create new datasets with attached weights
let dataset1 = Dataset::new(first, first_targets)
.with_weights(self.weights)
.with_feature_names(feature_names.clone());
.with_feature_names(feature_names.clone())
.with_target_names(target_names.clone());
let dataset2 = Dataset::new(second, second_targets)
.with_weights(second_weights)
.with_feature_names(feature_names);
.with_feature_names(feature_names.clone())
.with_target_names(target_names.clone());

(dataset1, dataset2)
}
Expand Down
3 changes: 2 additions & 1 deletion src/dataset/impl_targets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets<L, T>
where
T: FromTargetArray<'a, Elem = L>,
T::Owned: Labels<Elem = L>,
T::View: Labels<Elem = L>,
T::View: Labels<Elem = L> + AsTargets,
{
type Owned = CountedTargets<L, T::Owned>;
type View = CountedTargets<L, T::View>;
Expand Down Expand Up @@ -231,6 +231,7 @@ where
weights: Array1::from(weights),
targets,
feature_names: self.feature_names.clone(),
target_names: self.target_names.clone(),
}
}
}
9 changes: 8 additions & 1 deletion src/dataset/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,24 @@
if self.target_or_feature && self.dataset.nfeatures() <= self.idx {
return None;
}

let mut records = self.dataset.records.view();
let mut targets = self.dataset.targets.as_targets();
let feature_names;
let target_names;

Check warning on line 83 in src/dataset/iter.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/iter.rs#L83

Added line #L83 was not covered by tests
let weights = self.dataset.weights.clone();

if !self.target_or_feature {
// This branch should only run for 2D targets
targets.collapse_axis(Axis(1), self.idx);
feature_names = self.dataset.feature_names.clone();
if self.dataset.target_names.is_empty() {
target_names = Vec::new();

Check warning on line 91 in src/dataset/iter.rs

View check run for this annotation

Codecov / codecov/patch

src/dataset/iter.rs#L91

Added line #L91 was not covered by tests
} else {
target_names = vec![self.dataset.target_names[self.idx].clone()];
}
} else {
records.collapse_axis(Axis(1), self.idx);
target_names = self.dataset.target_names.clone();
if self.dataset.feature_names.len() == records.len_of(Axis(1)) {
feature_names = vec![self.dataset.feature_names[self.idx].clone()];
} else {
Expand All @@ -103,6 +109,7 @@
targets,
weights,
feature_names,
target_names,
};

Some(dataset_view)
Expand Down
Loading
Loading