Generic KvCache #3188
base: main
Conversation
```rust
pub trait KvCache {
    type Mask;
    fn new(dim: usize, max_seq_len: usize) -> Self;
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
```
Does append insert into the cache?
Probably worth adding some doc strings here
I guess it can depend on the cache implementation, but typically you append to the kv cache. It's not like a typical cache à la HashMap, where you insert by key.
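For the doc strings, a rough sketch of what they could say, given the append semantics above (the wording is a suggestion, not taken from the PR):

```rust
use candle::{Result, Tensor};

pub trait KvCache {
    /// Mask type consumed by the mask-aware append variant.
    type Mask;

    /// Creates an empty cache. `dim` is the dimension along which keys and
    /// values grow; `max_seq_len` is a capacity hint for implementations
    /// that preallocate.
    fn new(dim: usize, max_seq_len: usize) -> Self;

    /// Appends `k` and `v` after the previously cached entries and returns
    /// the full cached keys and values, including the new ones.
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
}
```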
```diff
      dim: usize,
      current_seq_len: usize,
-     grow_by: usize,
+     increment: usize,
```
why the change?
```diff
  pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
-     self.k.append(k)?;
-     self.v.append(v)?;
+     self.k.append(&k.contiguous()?)?;
```
Curious to know what the rationale is for adding contiguous here 👀
It's the default behaviour in the model code I used as reference. Is there any reason it shouldn't call contiguous?
I could omit it, but it improves performance for the cache impls I tested, and calling contiguous multiple times has almost zero cost (if the tensor is already contiguous it just clones an Arc).
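To illustrate the "almost zero cost" claim, a minimal sketch against candle's public Tensor API (the shapes are arbitrary):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((4, 8), DType::F32, &Device::Cpu)?;

    // Already contiguous: contiguous() just bumps the Arc on the existing
    // storage instead of copying.
    assert!(t.is_contiguous());
    let _cheap = t.contiguous()?;

    // A transpose yields a non-contiguous view; here contiguous() actually
    // materializes a new buffer.
    let tt = t.t()?;
    assert!(!tt.is_contiguous());
    let _copied = tt.contiguous()?;
    Ok(())
}
```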
```rust
fn append_with_mask(
    &mut self,
    k: &Tensor,
    v: &Tensor,
    _mask: Option<&Self::Mask>,
) -> Result<(Tensor, Tensor)> {
```
What's the point of having mask be an Option here when we already have append[_without_mask]?
True.
I initially added an optional mask to append and later realized it is better expressed as a separate fn.
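Expressed as a separate fn, the trait could look roughly like this (a sketch; the non-optional mask signature is the suggestion under discussion, not what the diff above shows):

```rust
use candle::{Result, Tensor};

pub trait KvCache {
    type Mask;

    fn new(dim: usize, max_seq_len: usize) -> Self;

    /// Mask-free append.
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;

    /// Mask-aware append; the mask is required rather than an Option, so
    /// callers without a mask simply use `append`.
    fn append_with_mask(
        &mut self,
        k: &Tensor,
        v: &Tensor,
        mask: &Self::Mask,
    ) -> Result<(Tensor, Tensor)>;
}
```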
```rust
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
    self.append_with_mask(k, v, None)
```
See, here we could just call candle::bail! instead of having to do that in append_with_mask. Feels a bit clunky.
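For instance, a cache that only supports masked appends could fail fast in append itself (a sketch; the type and error message are made up):

```rust
use candle::{Result, Tensor};

struct MaskedOnlyCache {}

impl KvCache for MaskedOnlyCache {
    type Mask = Tensor;

    fn new(_dim: usize, _max_seq_len: usize) -> Self {
        Self {}
    }

    fn append(&mut self, _k: &Tensor, _v: &Tensor) -> Result<(Tensor, Tensor)> {
        // No need to round-trip through append_with_mask(k, v, None).
        candle::bail!("this cache requires a mask, use append_with_mask")
    }

    fn append_with_mask(
        &mut self,
        _k: &Tensor,
        _v: &Tensor,
        _mask: &Self::Mask,
    ) -> Result<(Tensor, Tensor)> {
        // Actual masked append elided in this sketch.
        candle::bail!("elided")
    }
}
```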
```diff
  #[derive(Debug, Clone)]
- pub struct Cache {
+ pub struct InnerCache {
```
I think Cache was ok as a name, although I get the sentiment
Yeah, it's only used in one kv cache variant, and not even the default one, so it feels wrong that it should be the Cache.
```diff
- impl LayerWeights {
+ impl<C: KvCache> LayerWeights<C> {
```
Does it need to be generic? It feels like we're expecting a specific behaviour from the cache
Sometimes you want the KvCache variant that strictly has the highest throughput; sometimes you want one that is more careful about how it consumes memory; etc.
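A sketch of what the generic buys at the call site (the field layout is illustrative, not the PR's exact struct):

```rust
use candle::{Result, Tensor};

pub struct LayerWeights<C: KvCache> {
    kv_cache: C,
    // ... attention weights omitted ...
}

impl<C: KvCache> LayerWeights<C> {
    fn cached_kv(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        // The call site is identical for every implementation; the caller
        // picks the throughput/memory trade-off by choosing C.
        self.kv_cache.append(k, v)
    }
}
```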
```diff
  #[derive(Debug, Clone)]
- pub struct ModelWeights {
+ pub struct ModelWeights<C: KvCache = DefaultKvCache> {
```
I assume this means "if not specified, use DefaultKvCache"?
Yes!
By default you get ConcatKvCache, which has the highest throughput but grows indefinitely (until you call reset() in the model impl).
This should let us (and users of candle) easily switch out KvCache implementations in models.
I explored using dyn KvCache instead, but it quickly became a pain, so generics it is.
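As a usage sketch, assuming this lands in a model like quantized_llama and that candle-nn's RotatingKvCache implements the new trait (the call site below is hypothetical):

```rust
use candle::quantized::gguf_file;
use candle::{Device, Result};
use candle_transformers::models::quantized_llama::ModelWeights;

fn load(path: &str, device: &Device) -> Result<()> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;

    // Default type parameter: C = DefaultKvCache (ConcatKvCache in this PR).
    let _fast = ModelWeights::from_gguf(content, &mut file, device)?;

    // Or explicitly pick a memory-bounded cache instead:
    // let _lean = ModelWeights::<RotatingKvCache>::from_gguf(content, &mut file, device)?;
    Ok(())
}
```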