11use crate :: PreprocessingError ;
22use linfa:: ParamGuard ;
33use regex:: Regex ;
4- use std:: cell:: { Ref , RefCell } ;
54use std:: collections:: HashSet ;
5+ use std:: sync:: OnceLock ;
66
77#[ cfg( feature = "serde" ) ]
88use serde_crate:: { Deserialize , Serialize } ;
@@ -68,7 +68,7 @@ impl SerdeRegex {
6868pub struct CountVectorizerValidParams {
6969 convert_to_lowercase : bool ,
7070 split_regex_expr : String ,
71- split_regex : RefCell < Option < SerdeRegex > > ,
71+ split_regex : OnceLock < SerdeRegex > ,
7272 n_gram_range : ( usize , usize ) ,
7373 normalize : bool ,
7474 document_frequency : ( f32 , f32 ) ,
@@ -92,8 +92,11 @@ impl CountVectorizerValidParams {
9292 self . convert_to_lowercase
9393 }
9494
95- pub fn split_regex ( & self ) -> Ref < ' _ , Regex > {
96- Ref :: map ( self . split_regex . borrow ( ) , |x| x. as_ref ( ) . unwrap ( ) . as_re ( ) )
95+ pub fn split_regex ( & self ) -> & Regex {
96+ self . split_regex
97+ . get ( )
98+ . expect ( "Regex not initialized" )
99+ . as_re ( )
97100 }
98101
99102 pub fn n_gram_range ( & self ) -> ( usize , usize ) {
@@ -126,7 +129,7 @@ impl std::default::Default for CountVectorizerParams {
126129 Self ( CountVectorizerValidParams {
127130 convert_to_lowercase : true ,
128131 split_regex_expr : r"\b\w\w+\b" . to_string ( ) ,
129- split_regex : RefCell :: new ( None ) ,
132+ split_regex : OnceLock :: new ( ) ,
130133 n_gram_range : ( 1 , 1 ) ,
131134 normalize : true ,
132135 document_frequency : ( 0. , 1. ) ,
@@ -224,7 +227,8 @@ impl ParamGuard for CountVectorizerParams {
224227 min_freq, max_freq,
225228 ) )
226229 } else {
227- * self . 0 . split_regex . borrow_mut ( ) = Some ( SerdeRegex :: new ( & self . 0 . split_regex_expr ) ?) ;
230+ let regex = SerdeRegex :: new ( & self . 0 . split_regex_expr ) ?;
231+ let _ = self . 0 . split_regex . set ( regex) ;
228232
229233 Ok ( & self . 0 )
230234 }
0 commit comments