3 changes: 2 additions & 1 deletion candle-examples/examples/quantized-gemma/main.rs
@@ -4,6 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_nn::kv_cache::DefaultKvCache;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
@@ -175,7 +176,7 @@ fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;

let mut model = {
let mut model: ModelWeights<DefaultKvCache> = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
3 changes: 2 additions & 1 deletion candle-examples/examples/quantized-qwen3/main.rs
@@ -4,6 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_nn::kv_cache::DefaultKvCache;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
@@ -189,7 +190,7 @@ fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;

let mut model = {
let mut model: Qwen3<DefaultKvCache> = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
223 changes: 170 additions & 53 deletions candle-nn/src/kv_cache.rs
@@ -2,26 +2,43 @@
//!
use candle::{DType, Device, Result, Tensor};

pub type DefaultKvCache = ConcatKvCache;

pub trait KvCache {
type Mask;
fn new(dim: usize, max_seq_len: usize) -> Self;
fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
Member: Does append insert into the cache?

Member: Probably worth adding some doc strings here.

Member Author: I guess it can depend on the cache implementation, but typically you append to the KV cache. It's not like your typical cache à la hashmap where you insert by key.

fn append_with_mask(
&mut self,
k: &Tensor,
v: &Tensor,
_mask: Option<&Self::Mask>,
) -> Result<(Tensor, Tensor)> {
Member (on lines +11 to +16): What's the point of having mask being an Option here when we already have append[_without_mask]?

Member Author: True. I initially added the optional mask to append and later realized it is better expressed as a separate fn.

self.append(k, v)
}
fn reset(&mut self);
}
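As a concrete illustration of the append semantics discussed above, here is a minimal sketch — not part of this PR, and the `CachedAttention` wrapper is hypothetical — of a layer written against the trait so that the caller picks the cache implementation via a type parameter:

```rust
use candle::{Result, Tensor};
use candle_nn::kv_cache::{DefaultKvCache, KvCache};

// Hypothetical wrapper, generic over any KvCache implementation.
struct CachedAttention<C: KvCache> {
    cache: C,
}

impl<C: KvCache> CachedAttention<C> {
    fn new(seq_dim: usize, max_seq_len: usize) -> Self {
        Self { cache: C::new(seq_dim, max_seq_len) }
    }

    // Append this step's keys/values and get back the full cached k/v so far.
    fn step(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.cache.append(k, v)
    }
}

// Selecting the default (concatenation-based) cache is just a type choice:
type DefaultAttention = CachedAttention<DefaultKvCache>;
```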

#[derive(Debug, Clone)]
pub struct Cache {
pub struct InnerCache {
Member: I think Cache was ok as a name, although I get the sentiment.

Member Author: Yeah, it's only used in one KV cache variant, and not even the default one, so it feels wrong that it should be the Cache.

// all_data is an option on a Tensor, this makes it possible to only create the actual tensor
// on the first call where the batch size is easily known.
// Also this makes it safe to clone a KvCache that has been reset (as in it will not share
// its internal state with the cloned instance).
all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
grow_by: usize,
increment: usize,
Member: Why the change?

max_seq_len: usize,
}

impl Cache {
impl InnerCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
current_seq_len: 0,
grow_by: max_seq_len,
increment: max_seq_len,
max_seq_len,
}
}
@@ -68,10 +85,10 @@ impl Cache {
let ad = self.all_data.as_mut().unwrap();
while self.current_seq_len + seq_len > self.max_seq_len {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.grow_by;
shape[self.dim] = self.increment;
let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
*ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
self.max_seq_len += self.grow_by;
self.max_seq_len += self.increment;
}
ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
@@ -80,31 +97,46 @@
}

#[derive(Debug, Clone)]
pub struct KvCache {
k: Cache,
v: Cache,
pub struct IncrementalKvCache {
k: InnerCache,
v: InnerCache,
}

impl KvCache {
impl KvCache for IncrementalKvCache {
type Mask = ();
fn new(dim: usize, max_seq_len: usize) -> Self {
IncrementalKvCache::new(dim, max_seq_len)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append(k, v)
}

fn reset(&mut self) {
self.reset()
}
}

impl IncrementalKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = Cache::new(dim, max_seq_len);
let v = Cache::new(dim, max_seq_len);
let k = InnerCache::new(dim, max_seq_len);
let v = InnerCache::new(dim, max_seq_len);
Self { k, v }
}

pub fn k_cache(&self) -> &Cache {
pub fn k_cache(&self) -> &InnerCache {
&self.k
}

pub fn v_cache(&self) -> &Cache {
pub fn v_cache(&self) -> &InnerCache {
&self.v
}

pub fn k_cache_mut(&mut self) -> &mut Cache {
pub fn k_cache_mut(&mut self) -> &mut InnerCache {
&mut self.k
}

pub fn v_cache_mut(&mut self) -> &mut Cache {
pub fn v_cache_mut(&mut self) -> &mut InnerCache {
&mut self.v
}

@@ -117,8 +149,8 @@ impl KvCache {
}

pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
self.k.append(&k.contiguous()?)?;
Member: Curious to know what the rationale for adding contiguous here is 👀

Member Author: It's the default behaviour in the model code I used as reference. Is there any reason it shouldn't call contiguous? I could omit it, but it improves performance for the cache impls I tested, and calling contiguous multiple times has almost zero cost (if the tensor is already contiguous it just clones an Arc).
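To make that last point concrete, a small sketch (not part of the PR), assuming candle's usual `transpose`/`contiguous` semantics: the first `contiguous()` on a strided view copies once, and calling it again is a cheap no-op:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim): a strided view.
    let k = Tensor::zeros((1, 4, 8, 64), DType::F32, &dev)?.transpose(1, 2)?;
    println!("after transpose: contiguous = {}", k.is_contiguous()); // typically false
    let k = k.contiguous()?; // materializes one copy
    let k = k.contiguous()?; // already contiguous: no data copy, essentially an Arc clone
    println!("after contiguous: contiguous = {}", k.is_contiguous()); // true
    Ok(())
}
```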

self.v.append(&v.contiguous()?)?;
let out_k = self.k.current_data()?;
let out_v = self.v.current_data()?;
let k = match out_k {
@@ -338,6 +370,21 @@ pub struct RotatingKvCache {
v: RotatingCache,
}

impl KvCache for RotatingKvCache {
type Mask = ();
fn new(dim: usize, max_seq_len: usize) -> Self {
RotatingKvCache::new(dim, max_seq_len)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append(k, v)
}

fn reset(&mut self) {
self.reset()
}
}

impl RotatingKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, max_seq_len);
@@ -414,34 +461,90 @@ impl IndicesAndMask {

#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
k: Tensor,
v: Tensor,
k: Option<Tensor>,
v: Option<Tensor>,
dim: usize,
context: usize,
}
impl KvCache for ScatteredKvCache {
type Mask = IndicesAndMask;

fn new(dim: usize, max_seq_len: usize) -> Self {
ScatteredKvCache::new(dim, max_seq_len)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append_with_mask(k, v, None)
Member: See, here we could just call candle::bail! instead of having to do that in append_with_mask. Feels a bit clunky.

}

fn append_with_mask(
&mut self,
k: &Tensor,
v: &Tensor,
mask: Option<&Self::Mask>,
) -> Result<(Tensor, Tensor)> {
if let Some(mask) = mask {
self.scattered_append(k, v, mask)
} else {
candle::bail!("ScatteredKvCache requires InidicesAndMask")
}
}

fn reset(&mut self) {
self.reset()
}
}

impl ScatteredKvCache {
pub fn append(
pub fn new(dim: usize, context: usize) -> Self {
Self {
k: None,
v: None,
dim,
context,
}
}

pub fn scattered_append(
&mut self,
k: &Tensor,
v: &Tensor,
iam: &IndicesAndMask,
) -> Result<(Tensor, Tensor)> {
if self.context <= k.dim(2)? {
if self.context <= k.dim(self.dim)? {
return Ok((k.clone(), v.clone()));
}
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
if self.k.is_none() {
let mut k_shape = k.dims().to_vec();
k_shape[self.dim] = self.context;
self.k = Some(Tensor::zeros(k_shape.clone(), k.dtype(), k.device())?);
}
if self.v.is_none() {
let mut v_shape = v.dims().to_vec();
v_shape[self.dim] = self.context;
self.v = Some(Tensor::zeros(v_shape.clone(), v.dtype(), v.device())?);
}

let indices = iam.indices.unsqueeze(self.dim)?.unsqueeze(1)?;
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
self.k.scatter_set(&indices, k, 2)?;
self.v.scatter_set(&indices, v, 2)?;
Ok((self.k.clone(), self.v.clone()))
let new_k = self.k.as_mut().unwrap();
let new_v = self.v.as_mut().unwrap();
new_k.scatter_set(&indices, k, self.dim)?;
new_v.scatter_set(&indices, v, self.dim)?;
Ok((new_k.clone(), new_v.clone()))
}

pub fn k(&self) -> &Tensor {
&self.k
pub fn k(&self) -> Option<&Tensor> {
self.k.as_ref()
}

pub fn v(&self) -> &Tensor {
&self.v
pub fn v(&self) -> Option<&Tensor> {
self.v.as_ref()
}

pub fn reset(&mut self) {
self.k = None;
self.v = None;
}
}
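A minimal sketch (not in the PR) of the lazy-allocation behaviour introduced here: the cache no longer needs the batch size or head count at construction time, and `k()`/`v()` stay `None` until the first `scattered_append` provides the shapes:

```rust
use candle_nn::kv_cache::ScatteredKvCache;

fn main() {
    // dim = 2 (sequence axis), context window of 512 positions.
    let cache = ScatteredKvCache::new(2, 512);
    // Nothing has been allocated yet; backing tensors are created on the first append.
    assert!(cache.k().is_none() && cache.v().is_none());
}
```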

@@ -469,16 +572,8 @@ impl ScatteredCacheBuilder {
})
}

pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
let batch_size = self.batch_size();
let shape = (batch_size, num_heads, self.context, head_dim);
let k = Tensor::zeros(shape, self.dtype, self.device())?;
let v = Tensor::zeros(shape, self.dtype, self.device())?;
Ok(ScatteredKvCache {
k,
v,
context: self.context,
})
pub fn make_cache(&self, head_dim: usize) -> ScatteredKvCache {
ScatteredKvCache::new(head_dim, self.context)
}

pub fn positions(&self) -> &[usize] {
Expand All @@ -499,7 +594,6 @@ impl ScatteredCacheBuilder {
self.indices[batch_index] = 0;
}

#[allow(clippy::needless_range_loop)]
pub fn indices_and_mask(
&mut self,
seq_len: usize,
@@ -525,18 +619,26 @@
let mut indices = Vec::with_capacity(seq_len);
let mut all_pos = vec![usize::MAX; context];
if start_pos < context {
for i in 0..start_pos {
all_pos[i] = i;
}
all_pos
.iter_mut()
.enumerate()
.take(start_pos)
.for_each(|(i, p)| {
*p = i;
});
} else {
let offset = start_pos - start_index;
for i in 0..context {
all_pos[i] = if i < start_index {
i + offset
} else {
i + offset - context
};
}
all_pos
.iter_mut()
.enumerate()
.take(context)
.for_each(|(i, p)| {
*p = if i < start_index {
i + offset
} else {
i + offset - context
};
});
}
for seq_i in 0..seq_len {
let index = self.indices[batch_i];
@@ -584,7 +686,6 @@ impl ScatteredCacheBuilder {
&self.device
}

#[allow(clippy::needless_range_loop)]
fn indices_and_mask_abs(
&mut self,
seq_len: usize,
@@ -642,7 +743,7 @@ impl ScatteredCacheBuilder {
/// - GPU inference (CUDA, Metal)
/// - Autoregressive generation (token-by-token decoding)
///
/// **Use `KvCache` instead for:**
/// **Use `IncrementalKvCache` instead for:**
/// - CPU-only inference
/// - When you need fixed memory allocation upfront
///
@@ -670,6 +771,22 @@ pub struct ConcatKvCache {
dim: usize,
}

impl KvCache for ConcatKvCache {
type Mask = ();

fn new(dim: usize, _: usize) -> Self {
ConcatKvCache::new(dim)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append(k, v)
}

fn reset(&mut self) {
self.reset()
}
}
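For reference, a short usage sketch (not part of this PR) of `ConcatKvCache` in a token-by-token decode loop, assuming `(batch, heads, seq, head_dim)` tensors with the sequence axis at dim 2; the shapes are illustrative:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::ConcatKvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut cache = ConcatKvCache::new(2); // concatenate along the sequence axis
    for step in 0..3 {
        // One new token's worth of keys/values per step.
        let k = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        // `append` returns the full cached sequence accumulated so far.
        let (k_all, _v_all) = cache.append(&k, &v)?;
        assert_eq!(k_all.dim(2)?, step + 1);
    }
    Ok(())
}
```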

impl ConcatKvCache {
/// Create a new empty concatenation-based KV-cache
///