4 changes: 3 additions & 1 deletion sgl-model-gateway/Cargo.toml
@@ -149,7 +149,9 @@ tonic-v12 = { version = "0.12.3", package = "tonic" }
serial_test = "3.0"
rsa = { version = "0.9", features = ["sha2"] }


[[bench]]
name = "l1_cache_benchmark"
harness = false
[[bench]]
name = "request_processing"
harness = false
64 changes: 64 additions & 0 deletions sgl-model-gateway/benches/l1_cache_benchmark.rs
@@ -0,0 +1,64 @@

use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use sgl_model_gateway::tokenizer::cache::L1Cache;
use sgl_model_gateway::tokenizer::mock::MockTokenizer;

fn generate_prompt(turns: usize) -> String {
let mut prompt = String::new();
prompt.push_str("<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>");
for i in 0..turns {
prompt.push_str(&format!(
"<|im_start|>user\nIteration {} prompt text to test hashing and tokenization performance.<|im_end|>",
i
));
prompt.push_str("<|im_start|>assistant\nI am processing your request and generating a valid response.<|im_end|>");
}
prompt
}

fn bench_l1_cache(c: &mut Criterion) {
let special_tokens = vec!["<|im_start|>", "<|im_end|>"];
let tokenizer = MockTokenizer::new();

// We test with exponentially increasing turns to see the O(N^2) impact vs O(N)
for turns in [2, 10, 50].iter() {
let input = generate_prompt(*turns);
let mut group = c.benchmark_group(format!("L1-Cache-Turns-{}", turns));

// Measure throughput in terms of input bytes processed per second
group.throughput(Throughput::Elements(input.len() as u64));

// Insertion Benchmark
// Current code re-hashes and re-tokenizes the prefix at every boundary.
// Optimization targets this method specifically.
group.bench_function("insert_at_boundaries", |b| {
b.iter(|| {
// We create a new cache per iteration to ensure we are benchmarking
// the full insertion logic and not a "no-op" on existing entries.
let cache = L1Cache::new(100 * 1024 * 1024);
let _ = cache.insert_at_boundaries(
black_box(&input),
black_box(&tokenizer),
black_box(&special_tokens),
black_box(false),
);
})
});

// Lookup Benchmark
// This measures the efficiency of the backward search.
group.bench_function("longest_prefix_match", |b| {
let cache = L1Cache::new(100 * 1024 * 1024);
let _ = cache.insert_at_boundaries(&input, &tokenizer, &special_tokens, false);

b.iter(|| {
let _ = cache.longest_prefix_match(black_box(&input), black_box(&special_tokens));
})
});

group.finish();
}
}

criterion_group!(benches, bench_l1_cache);
criterion_main!(benches);
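
The speedup these benchmarks are meant to expose comes from replacing per-boundary full-prefix hashing in `longest_prefix_match` and `insert_at_boundaries` with a single running `blake3::Hasher` that is fed only each delta and snapshotted via `clone()` + `finalize()`. Below is a minimal standalone sketch of that equivalence, assuming only the `blake3` crate as a dependency; the input string and boundary offsets are illustrative values, not taken from the gateway code.

```rust
// Standalone sketch: snapshotting a running hasher at each boundary produces
// the same digest as hashing the full prefix from scratch.
// The input and boundary offsets below are arbitrary illustrative values.
fn main() {
    let input = "<|im_start|>system\nhi<|im_end|><|im_start|>user\nhello<|im_end|>";
    let boundaries = [31, input.len()]; // hypothetical special-token boundary offsets (bytes)

    let mut hasher = blake3::Hasher::new();
    let mut last = 0;
    for &b in &boundaries {
        // O(delta): feed only the bytes added since the previous boundary.
        hasher.update(input[last..b].as_bytes());
        let incremental = *hasher.clone().finalize().as_bytes();

        // O(prefix): what the pre-optimization code did at every boundary.
        let from_scratch = *blake3::hash(input[..b].as_bytes()).as_bytes();

        assert_eq!(incremental, from_scratch);
        last = b;
    }
    println!("incremental prefix hashes match full re-hashes");
}
```

With the `[[bench]]` entry added to `Cargo.toml`, the benchmark itself runs via `cargo bench --bench l1_cache_benchmark`.
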
55 changes: 34 additions & 21 deletions sgl-model-gateway/src/tokenizer/cache/l1.rs
@@ -141,13 +141,19 @@ impl L1Cache {
return None;
}

// Search backwards from the longest boundary to find the best match
for &boundary_pos in boundaries.iter().rev() {
let prefix = &input[0..boundary_pos];
let prefix_bytes = prefix.as_bytes();
let hash = blake3::hash(prefix_bytes);
let hash_bytes: Blake3Hash = *hash.as_bytes();
// Build all prefix hashes incrementally in O(N)
let mut hasher = blake3::Hasher::new();
let mut prefix_hashes = Vec::with_capacity(boundaries.len());
let mut last_pos = 0;

for &boundary_pos in &boundaries {
hasher.update(input[last_pos..boundary_pos].as_bytes());
prefix_hashes.push((boundary_pos, *hasher.clone().finalize().as_bytes()));
last_pos = boundary_pos;
}

// Search from the longest boundary to find the best match
for (boundary_pos, hash_bytes) in prefix_hashes.into_iter().rev() {
let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;

if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
@@ -167,8 +173,7 @@

/// Insert prefix entries at ALL special token boundaries
///
/// Re-tokenizes each prefix to ensure correctness (BPE tokenization is not prefix-stable).
/// This is more expensive on cache misses but provides correct tokens for cache hits.
/// Uses incremental hashing and tokenization for O(N) performance.
///
/// Optimized for workloads with high prefix reuse (e.g., chat templates with repeated system prompts).
pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
Expand All @@ -184,25 +189,33 @@ impl L1Cache {
return Ok(());
}

// Calculate how much memory we need and tokenize each prefix
let mut hasher = blake3::Hasher::new();
let mut running_tokens = Vec::new();
let mut last_pos = 0;
let mut entries_to_insert = Vec::with_capacity(boundaries.len());
for &boundary_pos in &boundaries {
// Extract prefix up to this special token boundary
let prefix = &input[0..boundary_pos];
let prefix_bytes = prefix.as_bytes();
let hash = blake3::hash(prefix_bytes);
let hash_bytes: Blake3Hash = *hash.as_bytes();

// Re-tokenize the prefix for guaranteed correctness
// This is the only way to know the exact token boundaries
let prefix_encoding = tokenizer.encode(prefix, add_special_tokens)?;
// Convert to Arc<[TokenIdType]> for zero-copy sharing
let prefix_tokens: Arc<[TokenIdType]> = prefix_encoding.token_ids().into();

for (i, &boundary_pos) in boundaries.iter().enumerate() {
let delta_text = &input[last_pos..boundary_pos];

// 1. Incremental Hash update
hasher.update(delta_text.as_bytes());
let hash_bytes: Blake3Hash = *hasher.clone().finalize().as_bytes();

// 2. Incremental Tokenization
// Only add special tokens (like BOS) for the very first segment to avoid duplicates
let segment_encoding = tokenizer.encode(delta_text, (i == 0) && add_special_tokens)?;
running_tokens.extend_from_slice(segment_encoding.token_ids());

// 3. Prepare entry
// Convert current tokens to Arc<[TokenIdType]> for sharing
let prefix_tokens: Arc<[TokenIdType]> = running_tokens.as_slice().into();

// Size = text bytes + token storage
let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();

entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));

last_pos = boundary_pos;
}

if entries_to_insert.is_empty() {
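
The other half of the optimization is the incremental tokenization in `insert_at_boundaries`: each delta segment is encoded once and the token IDs are concatenated, instead of re-encoding the full prefix at every boundary (which the previous doc comment justified on the grounds that BPE is not prefix-stable in general). That trade relies on special-token boundaries being safe split points for the tokenizer in use. The following is a hypothetical debug-time check, not part of this PR, sketching how a test could validate that assumption; it assumes it lives next to `insert_at_boundaries` in `l1.rs` (so the `Encoder` trait and `TokenIdType` resolve as in the diff) and simplifies error handling with `expect`.

```rust
// Hypothetical validation helper (not part of this PR): checks that per-segment
// tokenization concatenates to the same IDs as encoding the full prefix, i.e.
// that special-token boundaries are safe split points for this tokenizer.
// Uses the same encode/token_ids calls as the diff above; errors are unwrapped
// with expect() purely to keep the sketch short.
fn assert_segments_match_full_prefix<E: super::super::traits::Encoder + ?Sized>(
    tokenizer: &E,
    input: &str,
    boundaries: &[usize],
    add_special_tokens: bool,
) {
    let mut running: Vec<TokenIdType> = Vec::new();
    let mut last = 0;
    for (i, &boundary_pos) in boundaries.iter().enumerate() {
        // Same call pattern as insert_at_boundaries: special tokens only on the first segment.
        let seg = tokenizer
            .encode(&input[last..boundary_pos], (i == 0) && add_special_tokens)
            .expect("segment encode failed");
        running.extend_from_slice(seg.token_ids());

        // Reference result: encode the entire prefix from scratch.
        let full = tokenizer
            .encode(&input[..boundary_pos], add_special_tokens)
            .expect("full-prefix encode failed");
        assert_eq!(
            running.as_slice(),
            full.token_ids(),
            "delta tokenization diverged from full-prefix tokenization at byte {boundary_pos}"
        );
        last = boundary_pos;
    }
}
```

In chat-template workloads the split points are exactly the special tokens, which are typically atomic entries in the vocabulary that BPE does not merge across, so segment-wise and full-prefix encoding should agree; the check above simply makes that assumption explicit for a given tokenizer.
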