Conversation

@ivarflakstad (Member) commented Nov 14, 2025

This should let us (and users of candle) easily switch out KvCache implementations in models.

I explored using dyn KvCache instead, but it quickly became a pain, so generics it is.
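
For illustration, here is a minimal, hypothetical sketch of the generics approach (not the actual diff). `ToyConcatCache` and `Attention` are placeholder names, and candle-core is assumed to be importable as `candle`, as in the repo's crates:

```rust
use candle::{Result, Tensor};

// Simplified stand-in for the trait introduced in this PR.
pub trait KvCache {
    fn new(dim: usize, max_seq_len: usize) -> Self;
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
}

// Toy variant that concatenates K/V along `dim` (e.g. the sequence dimension).
pub struct ToyConcatCache {
    dim: usize,
    k: Option<Tensor>,
    v: Option<Tensor>,
}

impl KvCache for ToyConcatCache {
    fn new(dim: usize, _max_seq_len: usize) -> Self {
        Self { dim, k: None, v: None }
    }

    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        let k = match &self.k {
            Some(prev) => Tensor::cat(&[prev, k], self.dim)?,
            None => k.clone(),
        };
        let v = match &self.v {
            Some(prev) => Tensor::cat(&[prev, v], self.dim)?,
            None => v.clone(),
        };
        self.k = Some(k.clone());
        self.v = Some(v.clone());
        Ok((k, v))
    }
}

// The attention layer is generic over the cache, so the variant is chosen at
// the type level. Note that `fn new(...) -> Self` already makes the trait
// non-object-safe, which is one reason `dyn KvCache` gets painful.
pub struct Attention<C: KvCache> {
    cache: C,
}

impl<C: KvCache> Attention<C> {
    pub fn new(dim: usize, max_seq_len: usize) -> Self {
        Self { cache: C::new(dim, max_seq_len) }
    }

    pub fn forward_kv(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.cache.append(k, v)
    }
}
```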

@ivarflakstad ivarflakstad marked this pull request as ready for review November 20, 2025 11:33
pub trait KvCache {
type Mask;
fn new(dim: usize, max_seq_len: usize) -> Self;
fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
Member:

Does append insert into the cache?

Member:

Probably worth adding some doc strings here

Member Author:

I guess it can depend on the cache implementation, but typically you append to the kv cache. It's not like your typical cache à la HashMap, where you insert by key.
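
To make that concrete, a small hedged sketch of how `append` is typically driven during decoding; `decode_step` is an illustrative helper, and the `KvCache` trait plus `candle::{Result, Tensor}` from the sketch above are assumed to be in scope:

```rust
// Each decode step hands the newest K/V slices to the cache and gets back the
// full accumulated keys/values; there is no key-based lookup involved.
fn decode_step<C: KvCache>(
    cache: &mut C,
    k_new: &Tensor,
    v_new: &Tensor,
) -> Result<(Tensor, Tensor)> {
    cache.append(k_new, v_new)
}
```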

dim: usize,
current_seq_len: usize,
grow_by: usize,
increment: usize,
Member:

Why the change?

pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
self.k.append(&k.contiguous()?)?;
Member:

Curious to know what the rationale for adding contiguous here is 👀

Member Author:

It's the default behaviour in the model code I used as a reference. Is there any reason it shouldn't call contiguous?
I could omit it, but it improves performance for the cache impls I tested, and calling contiguous multiple times has almost zero cost (if the tensor is already contiguous it just clones an Arc).
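
A standalone sketch of that cost argument, using candle_core directly (shapes and dims are arbitrary):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 4, 8), DType::F32, &Device::Cpu)?;

    // Already contiguous: `contiguous()` is essentially free, it just hands
    // back a tensor sharing the same underlying storage (the Arc clone
    // mentioned above); no data is copied.
    assert!(t.is_contiguous());
    let _cheap = t.contiguous()?;

    // After a transpose the layout is strided, so `contiguous()` materializes
    // a copy; paying that once before appending is what gets traded for
    // faster concatenation later.
    let strided = t.transpose(1, 2)?;
    assert!(!strided.is_contiguous());
    let _copied = strided.contiguous()?;

    Ok(())
}
```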

Comment on lines +11 to +16
fn append_with_mask(
&mut self,
k: &Tensor,
v: &Tensor,
_mask: Option<&Self::Mask>,
) -> Result<(Tensor, Tensor)> {
Member:

What's the point of having mask be an Option here when we already have append[_without_mask]?

Member Author:

True.
I initially added an optional mask to append and later realized it is better expressed as a separate fn.
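
Putting the visible pieces of the diff together, one possible shape of what the author describes looks roughly like this; it is a sketch, not the verbatim trait, so the exact defaults may differ:

```rust
use candle::{Result, Tensor};

pub trait KvCache {
    type Mask;

    fn new(dim: usize, max_seq_len: usize) -> Self;

    /// Append the new K/V, taking an attention mask into account if given.
    fn append_with_mask(
        &mut self,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Self::Mask>,
    ) -> Result<(Tensor, Tensor)>;

    /// Mask-less append; by default just forwards with no mask.
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.append_with_mask(k, v, None)
    }
}
```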

}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append_with_mask(k, v, None)
Member:

See, here we could just call candle::bail! instead of having to do that in append_with_mask. Feels a bit clunky.
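
For concreteness, a hypothetical sketch of one reading of that suggestion, reusing the trait shape sketched above and candle's bail! macro: a variant that genuinely needs a mask just bails in its plain `append`, instead of special-casing a missing mask inside `append_with_mask`. `MaskRequiredCache` is made up for the example:

```rust
use candle::{bail, Result, Tensor};

// Assumes the `KvCache` trait from the sketch above is in scope.
pub struct MaskRequiredCache {
    dim: usize,
    k: Option<Tensor>,
    v: Option<Tensor>,
}

impl KvCache for MaskRequiredCache {
    type Mask = Tensor;

    fn new(dim: usize, _max_seq_len: usize) -> Self {
        Self { dim, k: None, v: None }
    }

    // Overrides the default: this variant cannot meaningfully append without
    // a mask, so it errors out early and points the caller at the right API.
    fn append(&mut self, _k: &Tensor, _v: &Tensor) -> Result<(Tensor, Tensor)> {
        bail!("this cache requires a mask, use append_with_mask instead")
    }

    fn append_with_mask(
        &mut self,
        k: &Tensor,
        v: &Tensor,
        _mask: Option<&Self::Mask>,
    ) -> Result<(Tensor, Tensor)> {
        // Greatly simplified: a real variant would actually apply the mask.
        let k = match &self.k {
            Some(prev) => Tensor::cat(&[prev, k], self.dim)?,
            None => k.clone(),
        };
        let v = match &self.v {
            Some(prev) => Tensor::cat(&[prev, v], self.dim)?,
            None => v.clone(),
        };
        self.k = Some(k.clone());
        self.v = Some(v.clone());
        Ok((k, v))
    }
}
```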


#[derive(Debug, Clone)]
pub struct Cache {
pub struct InnerCache {
Member:

I think Cache was ok as a name, although I get the sentiment

Member Author:

Yeah, it's only used in one kv cache variant, and not even the default one, so it feels wrong that it should be the Cache.

}

impl LayerWeights {
impl<C: KvCache> LayerWeights<C> {
Member:

Does it need to be generic? It feels like we're expecting a specific behaviour from the cache

Member Author:

Sometimes you want the KvCache variant that strictly has the highest throughput. Sometimes you want one that is more careful about how it consumes memory. Etc.


#[derive(Debug, Clone)]
pub struct ModelWeights {
pub struct ModelWeights<C: KvCache = DefaultKvCache> {
Member:

I assume this means "if not specified, use DefaultKvCache"?

Member Author:

Yes!
By default you get ConcatKvCache, which has the highest throughput but grows indefinitely (until you call reset() in the model impl).
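
As a plain-Rust illustration of how the default type parameter behaves; `ConcatLike` and `RotatingLike` are stand-ins, not the actual candle types:

```rust
// Minimal demo of `struct Foo<C = Default>` semantics, mirroring
// `ModelWeights<C: KvCache = DefaultKvCache>` from the diff above.
struct ConcatLike;   // stand-in for the fast, unbounded default
struct RotatingLike; // stand-in for a memory-bounded variant

struct ModelWeights<C = ConcatLike> {
    cache: C,
}

fn main() {
    // No parameter given: you get the default cache variant.
    let _fast: ModelWeights = ModelWeights { cache: ConcatLike };
    // Explicit parameter: opt into a different throughput/memory trade-off.
    let _bounded: ModelWeights<RotatingLike> = ModelWeights { cache: RotatingLike };
}
```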
