Skip to content

Commit d3d14fc

Browse files
committed
linfa-preprocessing::countgrams: Multi-threading made possible
1 parent 5272ad1 commit d3d14fc

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use crate::PreprocessingError;
22
use linfa::ParamGuard;
33
use regex::Regex;
4-
use std::cell::{Ref, RefCell};
54
use std::collections::HashSet;
5+
use std::sync::OnceLock;
66

77
#[cfg(feature = "serde")]
88
use serde_crate::{Deserialize, Serialize};
@@ -68,7 +68,7 @@ impl SerdeRegex {
6868
pub struct CountVectorizerValidParams {
6969
convert_to_lowercase: bool,
7070
split_regex_expr: String,
71-
split_regex: RefCell<Option<SerdeRegex>>,
71+
split_regex: OnceLock<SerdeRegex>,
7272
n_gram_range: (usize, usize),
7373
normalize: bool,
7474
document_frequency: (f32, f32),
@@ -92,8 +92,11 @@ impl CountVectorizerValidParams {
9292
self.convert_to_lowercase
9393
}
9494

95-
pub fn split_regex(&self) -> Ref<'_, Regex> {
96-
Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re())
95+
pub fn split_regex(&self) -> &Regex {
96+
self.split_regex
97+
.get()
98+
.expect("Regex not initialized")
99+
.as_re()
97100
}
98101

99102
pub fn n_gram_range(&self) -> (usize, usize) {
@@ -126,7 +129,7 @@ impl std::default::Default for CountVectorizerParams {
126129
Self(CountVectorizerValidParams {
127130
convert_to_lowercase: true,
128131
split_regex_expr: r"\b\w\w+\b".to_string(),
129-
split_regex: RefCell::new(None),
132+
split_regex: OnceLock::new(),
130133
n_gram_range: (1, 1),
131134
normalize: true,
132135
document_frequency: (0., 1.),
@@ -224,7 +227,8 @@ impl ParamGuard for CountVectorizerParams {
224227
min_freq, max_freq,
225228
))
226229
} else {
227-
*self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?);
230+
let regex = SerdeRegex::new(&self.0.split_regex_expr)?;
231+
let _ = self.0.split_regex.set(regex);
228232

229233
Ok(&self.0)
230234
}

0 commit comments

Comments
 (0)