
Commit 40ae9f6

perf(rust): update tokenizers storage, significantly increasing performance
1 parent 7ac5a21 commit 40ae9f6

File tree

1 file changed (+15, -14 lines)

rust/src/lib.rs

Lines changed: 15 additions & 14 deletions
@@ -108,9 +108,9 @@ pub unsafe extern "C" fn csharp_to_rust_u32_array(buffer: *const u32, len: i32)
 }
 
 // Tokenizer stuff starts here
-use once_cell::unsync::Lazy;
 use std::collections::HashMap;
 use std::fmt;
+use std::sync::{LazyLock, RwLock};
 use tokenizers::tokenizer::Tokenizer;
 use uuid::Uuid;
 
@@ -151,9 +151,11 @@ pub struct TokenizerResult {
 
 static mut LAST_ERROR_MESSAGE: String = String::new();
 
-static mut TOKENIZER_SESSION: Lazy<HashMap<String, TokenizerInfo>> = Lazy::new(|| HashMap::new());
+static TOKENIZER_SESSION: LazyLock<RwLock<HashMap<String, TokenizerInfo>>> =
+    LazyLock::new(|| RwLock::new(HashMap::new()));
 
-static mut TOKENIZER_DB: Lazy<HashMap<String, Tokenizer>> = Lazy::new(|| HashMap::new());
+static TOKENIZER_DB: LazyLock<RwLock<HashMap<String, Tokenizer>>> =
+    LazyLock::new(|| RwLock::new(HashMap::new()));
 
 #[no_mangle]
 pub unsafe extern "C" fn get_last_error_message() -> *mut ByteBuffer {
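The key change is in the two static declarations above: the static mut globals built on once_cell::unsync::Lazy become std::sync::LazyLock statics that wrap each map in an RwLock, so the maps can be shared across threads without raw mutable statics. A minimal standalone sketch of that pattern (a placeholder u32 value stands in for the crate's TokenizerInfo/Tokenizer types; LazyLock needs Rust 1.80+):

use std::collections::HashMap;
use std::sync::{LazyLock, RwLock};

// Lazily initialized, thread-safe global map; every access goes through the
// RwLock, so no unsafe code is needed to read or modify it.
static SESSIONS: LazyLock<RwLock<HashMap<String, u32>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));

fn main() {
    SESSIONS.write().unwrap().insert("session-1".to_string(), 42);
    assert_eq!(SESSIONS.read().unwrap().get("session-1").copied(), Some(42));
}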
@@ -181,7 +183,7 @@ pub unsafe extern "C" fn tokenizer_initialize(
 };
 
 // Add to TOKENIZER_SESSION
-TOKENIZER_SESSION.insert(id.clone(), info);
+TOKENIZER_SESSION.write().unwrap().insert(id.clone(), info);
 
 // Initialize TOKENIZER_DB if not already initialized
 let tokenizer = match Tokenizer::from_file(&utf8_path) {
@@ -194,7 +196,7 @@ pub unsafe extern "C" fn tokenizer_initialize(
 };
 }
 };
-TOKENIZER_DB.insert(id.clone(), tokenizer);
+TOKENIZER_DB.write().unwrap().insert(id.clone(), tokenizer);
 
 let session_id = id.clone();
 TokenizerResult {
@@ -254,7 +256,8 @@ pub unsafe extern "C" fn tokenizer_encode(
 };
 
 // Retrieve the tokenizer associated with the session ID
-let tokenizer = match TOKENIZER_DB.get(&session_id).cloned() {
+let lock = TOKENIZER_DB.read().unwrap();
+let tokenizer = match lock.get(&session_id) {
 Some(t) => t,
 None => {
 LAST_ERROR_MESSAGE = format!("Tokenizer for session ID '{}' not found", session_id);
@@ -332,7 +335,8 @@ pub unsafe extern "C" fn tokenizer_decode(
 };
 
 // Retrieve the tokenizer associated with the session ID
-let tokenizer = match TOKENIZER_DB.get(&session_id).cloned() {
+let lock = TOKENIZER_DB.read().unwrap();
+let tokenizer = match lock.get(&session_id) {
 Some(t) => t,
 None => {
 LAST_ERROR_MESSAGE = format!("Tokenizer for session ID '{}' not found", session_id);
@@ -388,8 +392,8 @@ pub unsafe extern "C" fn get_version(
 };
 
 // Retrieve the TokenizerInfo associated with the session ID
-let session_info = match TOKENIZER_SESSION.get(&session_id) {
-Some(info) => info,
+let version = match TOKENIZER_SESSION.read().unwrap().get(&session_id) {
+Some(info) => info.library_version.clone(),
 None => {
 LAST_ERROR_MESSAGE = format!("Session info for session ID '{}' not found", session_id);
 return TokenizerResult {
@@ -399,9 +403,6 @@ pub unsafe extern "C" fn get_version(
 }
 };
 
-// Get the library version from the TokenizerInfo
-let version = session_info.library_version.clone();
-
 // Return success with version as ByteBuffer
 TokenizerResult {
 error_code: TokenizerErrorCode::Success,
@@ -430,8 +431,8 @@ pub unsafe extern "C" fn tokenizer_cleanup(
 };
 
 // Remove tokenizer and session info
-let removed_tokenizer = TOKENIZER_DB.remove(&session_id);
-let removed_session = TOKENIZER_SESSION.remove(&session_id);
+let removed_tokenizer = TOKENIZER_DB.write().unwrap().remove(&session_id);
+let removed_session = TOKENIZER_SESSION.write().unwrap().remove(&session_id);
 
 if removed_tokenizer.is_none() || removed_session.is_none() {
 LAST_ERROR_MESSAGE = format!("Session ID '{}' not found", session_id);

0 commit comments
