Commit fd4a214

Add max_features and tokenizer to CountVectorizer (#376)
1 parent 6ab89bf commit fd4a214

9 files changed, 188 insertions(+), 57 deletions(-)

algorithms/linfa-preprocessing/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@ encoding = "0.2"
 sprs = { version = "=0.11.1", default-features = false }
 
 serde_regex = { version = "1.1", optional = true }
+itertools = "0.14.0"
 
 [dependencies.serde_crate]
 package = "serde"
@@ -44,6 +45,7 @@ features = ["std", "derive"]
 linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [
     "diabetes",
     "winequality",
+    "generate"
 ] }
 linfa-bayes = { version = "0.7.1", path = "../linfa-bayes" }
 iai = "0.1"

algorithms/linfa-preprocessing/benches/vectorizer_bench.rs

Lines changed: 4 additions & 2 deletions
@@ -118,7 +118,8 @@ fn fit_transform_vectorizer(file_names: &[std::path::PathBuf]) {
             file_names,
             encoding::all::ISO_8859_1,
             encoding::DecoderTrap::Strict,
-        );
+        )
+        .unwrap();
 }
 fn fit_transform_tf_idf(file_names: &[std::path::PathBuf]) {
     TfIdfVectorizer::default()
@@ -134,7 +135,8 @@ fn fit_transform_tf_idf(file_names: &[std::path::PathBuf]) {
             file_names,
             encoding::all::ISO_8859_1,
             encoding::DecoderTrap::Strict,
-        );
+        )
+        .unwrap();
 }
 
 fn bench(c: &mut Criterion) {
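
The added .unwrap() calls reflect that fitting and transforming from files is now fallible. Below is a minimal sketch of the same call pattern that propagates the error instead of panicking; the import path, the builder entry point, and the exact error type are assumptions rather than code from the commit:

use linfa_preprocessing::CountVectorizerParams;
use std::path::PathBuf;

fn fit_transform_checked(file_names: &[PathBuf]) -> Result<(), Box<dyn std::error::Error>> {
    // Fitting from files was already fallible, so `?` applies here.
    let vectorizer = CountVectorizerParams::default().fit_files(
        file_names,
        encoding::all::ISO_8859_1,
        encoding::DecoderTrap::Strict,
    )?;
    // After this commit, transform_files returns a Result as well.
    let _counts = vectorizer.transform_files(
        file_names,
        encoding::all::ISO_8859_1,
        encoding::DecoderTrap::Strict,
    )?;
    Ok(())
}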

algorithms/linfa-preprocessing/examples/count_vectorization.rs

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,7 @@ fn main() {
     // Transforming gives a sparse dataset, we make it dense in order to be able to fit the Naive Bayes model
     let training_records = vectorizer
         .transform_files(&training_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
     // Currently linfa only allows real valued features so we have to transform the integer counts to floats
     let training_records = training_records.mapv(|c| c as f32);
@@ -164,6 +165,7 @@ fn main() {
     );
     let test_records = vectorizer
         .transform_files(&test_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
     let test_records = test_records.mapv(|c| c as f32);
     let test_dataset: Dataset<f32, usize, Ix1> = (test_records, test_targets).into();

algorithms/linfa-preprocessing/examples/tfidf_vectorization.rs

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,7 @@ fn main() {
     // Transforming gives a sparse dataset, we make it dense in order to be able to fit the Naive Bayes model
     let training_records = vectorizer
         .transform_files(&training_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
 
     println!(
@@ -162,6 +163,7 @@ fn main() {
     );
     let test_records = vectorizer
         .transform_files(&test_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
     let test_dataset: Dataset<f64, usize, Ix1> = (test_records, test_targets).into();
     // Let's predict the test data targets

algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs

Lines changed: 41 additions & 6 deletions
@@ -7,6 +7,8 @@ use std::collections::HashSet;
 #[cfg(feature = "serde")]
 use serde_crate::{Deserialize, Serialize};
 
+use super::{Tokenizer, Tokenizerfp};
+
 #[derive(Clone, Debug)]
 #[cfg(not(feature = "serde"))]
 struct SerdeRegex(Regex);
@@ -71,9 +73,21 @@ pub struct CountVectorizerValidParams {
     normalize: bool,
     document_frequency: (f32, f32),
     stopwords: Option<HashSet<String>>,
+    max_features: Option<usize>,
+    #[cfg_attr(feature = "serde", serde(skip))]
+    pub(crate) tokenizer_function: Option<Tokenizerfp>,
+    pub(crate) tokenizer_deserialization_guard: bool,
 }
 
 impl CountVectorizerValidParams {
+    pub fn tokenizer_function(&self) -> Option<Tokenizerfp> {
+        self.tokenizer_function
+    }
+
+    pub fn max_features(&self) -> Option<usize> {
+        self.max_features
+    }
+
     pub fn convert_to_lowercase(&self) -> bool {
         self.convert_to_lowercase
     }
@@ -117,20 +131,41 @@ impl std::default::Default for CountVectorizerParams {
             normalize: true,
             document_frequency: (0., 1.),
             stopwords: None,
+            max_features: None,
+            tokenizer_function: None,
+            tokenizer_deserialization_guard: false,
         })
     }
 }
 
 impl CountVectorizerParams {
-    ///If true, all documents used for fitting will be converted to lowercase.
-    pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
-        self.0.convert_to_lowercase = convert_to_lowercase;
+    // Set the tokenizer as either a function pointer or a regex
+    // If this method is not called, the default is to use regex "\b\w\w+\b"
+    pub fn tokenizer(mut self, tokenizer: Tokenizer) -> Self {
+        match tokenizer {
+            Tokenizer::Function(fp) => {
+                self.0.tokenizer_function = Some(fp);
+                self.0.tokenizer_deserialization_guard = true;
+            }
+            Tokenizer::Regex(regex_str) => {
+                self.0.split_regex_expr = regex_str.to_string();
+                self.0.tokenizer_deserialization_guard = false;
+            }
+        }
+
+        self
+    }
+
+    /// When building the vocabulary, only consider the top max_features (by term frequency).
+    /// If None, all features are used.
+    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
+        self.0.max_features = max_features;
        self
     }
 
-    /// Sets the regex espression used to split decuments into tokens
-    pub fn split_regex(mut self, regex_str: &str) -> Self {
-        self.0.split_regex_expr = regex_str.to_string();
+    ///If true, all documents used for fitting will be converted to lowercase.
+    pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
+        self.0.convert_to_lowercase = convert_to_lowercase;
         self
     }
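
Taken together, the two new knobs can be used roughly as below. This is a usage sketch, not code from the commit: it assumes CountVectorizerParams and Tokenizer are re-exported from the crate root, that Tokenizerfp is a fn(&str) -> Vec<&str> function pointer, and that Tokenizer::Regex takes a string slice (consistent with the .to_string() call in the diff above):

use linfa_preprocessing::{CountVectorizerParams, Tokenizer};

// Custom tokenizer supplied as a plain function pointer (assumed signature).
fn whitespace_tokens(text: &str) -> Vec<&str> {
    text.split_whitespace().collect()
}

fn build_example_params() {
    // Function-pointer tokenizer; the pointer is skipped by serde, which is
    // presumably what `tokenizer_deserialization_guard` is meant to catch
    // after deserialization.
    let _with_fn = CountVectorizerParams::default()
        .tokenizer(Tokenizer::Function(whitespace_tokens))
        .max_features(Some(1000)); // keep only the 1000 most frequent tokens

    // Regex tokenizer, equivalent to the pre-existing default pattern "\b\w\w+\b".
    let _with_regex = CountVectorizerParams::default()
        .tokenizer(Tokenizer::Regex(r"\b\w\w+\b"))
        .max_features(None); // None keeps every feature
}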