-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkv_cache_demo.rs
More file actions
145 lines (114 loc) · 4.7 KB
/
kv_cache_demo.rs
File metadata and controls
145 lines (114 loc) · 4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
//! Demonstrates using TurboQuant as a quantized KV cache with attention.
//!
//! Creates a single-layer cache, pushes 1024 key-value pairs, computes
//! attention scores for a query, and reports memory statistics.
//!
//! No external dependencies (no `rand`). Uses deterministic LCG + sine
//! patterns for reproducible data.
use turboquant::attention::QuantizedKVCache;
use turboquant::packed::TurboQuantConfig;
// ---------------------------------------------------------------------------
// Named constants
// ---------------------------------------------------------------------------
/// Quantization bit width handed to `TurboQuantConfig::new`; with the value 3
/// the banner prints it as "TQ3".
const BITS: u8 = 3;
/// Head dimension (typical for LLMs like LLaMA, Mistral). Passed to
/// `TurboQuantConfig::new` and used as the length of every generated key,
/// value, and query vector.
const DIM: usize = 128;
/// Seed for the PolarQuant rotation sign pattern, applied through
/// `TurboQuantConfig::with_seed`.
const ROTATION_SEED: u64 = 42;
/// Seed for the QJL Rademacher matrix, passed as the final argument of
/// `QuantizedKVCache::new`.
const QJL_SEED: u64 = 12345;
/// Number of layers the cache is constructed with.
const NUM_LAYERS: usize = 1;
/// Index of the layer that every cache call in this demo targets.
const LAYER: usize = 0;
/// Number of key/value pairs generated and pushed into the cache.
const NUM_ENTRIES: usize = 1024;
/// Upper bound on how many attention scores are printed at the end.
const DISPLAY_SCORES: usize = 8;
// `pseudo_random_vec` backs `lcg_vec`; `LCG_MULTIPLIER` spreads the per-entry
// seeds derived in `main`.
use turboquant::test_utils::{pseudo_random_vec, LCG_MULTIPLIER};
/// Scale factor applied to the pseudo-random key vectors.
const KEY_AMPLITUDE: f32 = 1.0;
/// Scale factor applied to the pseudo-random value vectors.
const VALUE_AMPLITUDE: f32 = 0.5;
/// Step between consecutive query elements: element `i` of the query is
/// `sin(i * QUERY_FREQUENCY)`.
const QUERY_FREQUENCY: f64 = 0.25;
/// Size of one FP16 element in bytes; used only for the expected-memory
/// sanity check.
const BYTES_PER_FP16: usize = 2;
/// Number of vectors counted per cache entry (one key plus one value) in the
/// expected-memory sanity check.
const KV_PAIR_COUNT: usize = 2;
/// Divisor that converts a byte count to kilobytes (1 KB = 1024 bytes) for
/// display.
const BYTES_PER_KB: f64 = 1024.0;
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Produces a reproducible vector for `seed` and multiplies every element by
/// `amplitude`.
///
/// The unscaled samples come from the shared `test_utils::pseudo_random_vec`;
/// this wrapper only applies the scaling, so the same `(dim, seed, amplitude)`
/// triple always yields the same output.
fn lcg_vec(dim: usize, seed: u64, amplitude: f32) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(dim);
    for sample in pseudo_random_vec(dim, seed) {
        scaled.push(amplitude * sample);
    }
    scaled
}
/// Runs the demo end to end: builds a quantized KV cache, fills it with
/// deterministic key/value vectors, scores a sine-pattern query against the
/// stored keys, and prints memory statistics followed by the first few scores.
///
/// # Panics
///
/// Panics if the configuration is rejected, if `push_batch` or
/// `attention_scores` returns an error, or if the cache's reported FP16
/// footprint differs from `NUM_ENTRIES * DIM * BYTES_PER_FP16 * KV_PAIR_COUNT`.
fn main() {
    // -- 1. Create cache ------------------------------------------------------
    let config = TurboQuantConfig::new(BITS, DIM)
        .expect("valid config")
        .with_seed(ROTATION_SEED);
    let mut cache = QuantizedKVCache::new(config, NUM_LAYERS, QJL_SEED);

    // -- 2. Push KV pairs (batch) ---------------------------------------------
    println!("Pushing {NUM_ENTRIES} key-value pairs via push_batch (d={DIM}, TQ{BITS})...");
    let keys = synthetic_vectors(KEY_SEED_OFFSET, KEY_AMPLITUDE);
    let values = synthetic_vectors(VALUE_SEED_OFFSET, VALUE_AMPLITUDE);
    // `push_batch` is handed borrowed `&[f32]` views, so the owned vectors
    // above must stay alive until the call returns.
    let key_refs = as_slices(&keys);
    let val_refs = as_slices(&values);
    cache
        .push_batch(LAYER, &key_refs, &val_refs)
        .expect("push_batch succeeded");

    // -- 3. Compute attention scores ------------------------------------------
    // The query is a plain sine sweep, computed in f64 and narrowed to f32.
    let query: Vec<f32> = (0..DIM)
        .map(|i| (i as f64 * QUERY_FREQUENCY).sin() as f32)
        .collect();
    let scores = cache
        .attention_scores(LAYER, &query)
        .expect("attention succeeded");

    // -- 4. Print cache stats -------------------------------------------------
    let quantized_bytes = cache.memory_usage();
    let fp16_bytes = cache.fp16_equivalent_memory();
    let compression_ratio = fp16_bytes as f64 / quantized_bytes as f64;
    println!();
    println!("Cache Statistics");
    println!("================");
    println!("Entries: {}", cache.entry_count(LAYER));
    println!(
        "Quantized memory: {:.1} KB",
        quantized_bytes as f64 / BYTES_PER_KB
    );
    println!(
        "FP16 equivalent: {:.1} KB",
        fp16_bytes as f64 / BYTES_PER_KB
    );
    println!("Compression ratio: {compression_ratio:.2}x");
    println!();

    // Sanity check: FP16 equivalent should match manual calculation
    let expected_fp16 = NUM_ENTRIES * DIM * BYTES_PER_FP16 * KV_PAIR_COUNT;
    assert_eq!(fp16_bytes, expected_fp16);

    // -- 5. Print attention scores --------------------------------------------
    println!(
        "Attention scores (first {DISPLAY_SCORES} of {}):",
        scores.len()
    );
    for (i, &score) in scores.iter().take(DISPLAY_SCORES).enumerate() {
        println!(" key[{i:4}]: {score:+.6}");
    }
}

/// Offset added to each entry's seed when generating key vectors.
const KEY_SEED_OFFSET: u64 = 1000;

/// Offset added to each entry's seed when generating value vectors. It differs
/// from `KEY_SEED_OFFSET` so that entry `i` gets distinct key and value seeds.
const VALUE_SEED_OFFSET: u64 = 2000;

/// Builds `NUM_ENTRIES` deterministic vectors of length `DIM`, scaled by
/// `amplitude`.
///
/// Entry `i` is seeded with `i * LCG_MULTIPLIER + seed_offset` using wrapping
/// arithmetic, so the result depends only on the two arguments.
fn synthetic_vectors(seed_offset: u64, amplitude: f32) -> Vec<Vec<f32>> {
    (0..NUM_ENTRIES)
        .map(|i| {
            let seed = (i as u64)
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(seed_offset);
            lcg_vec(DIM, seed, amplitude)
        })
        .collect()
}

/// Borrows every owned vector as a slice without copying any element data.
fn as_slices(vectors: &[Vec<f32>]) -> Vec<&[f32]> {
    vectors.iter().map(Vec::as_slice).collect()
}