Generic KvCache #3188
Conversation
pub trait KvCache {
    type Mask;
    fn new(dim: usize, max_seq_len: usize) -> Self;
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
Does append insert into the cache?
Probably worth adding some doc strings here
I guess it can depend on the cache implementation, but typically you append to the kv cache. It's not like your typical cache à la hashmap where you insert by key.
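For reference, here is a doc-commented sketch of the hunk above; the signatures are copied from the diff, but the doc text is only a suggestion for the doc strings requested in the review:

use candle::{Result, Tensor};

pub trait KvCache {
    /// Mask type consumed by the masked append variant; implementations that
    /// never need a mask can pick a trivial type here.
    type Mask;

    /// Create an empty cache. `dim` is the sequence dimension that keys and
    /// values are appended along; `max_seq_len` is an upper bound used by
    /// pre-allocating implementations.
    fn new(dim: usize, max_seq_len: usize) -> Self;

    /// Append this step's keys/values to the cache and return the full
    /// keys/values accumulated so far. New entries always go at the end of
    /// the sequence dimension; this is not a keyed insert like a hashmap.
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
}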
    dim: usize,
    current_seq_len: usize,
-   grow_by: usize,
+   increment: usize,
why the change?
Aligns better with the new name
    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
-       self.k.append(k)?;
-       self.v.append(v)?;
+       self.k.append(&k.contiguous()?)?;
Curious to know what the rationale for adding contiguous here is 👀
It's the default behaviour in the model code I used as reference. Is there any reason it shouldn't call contiguous?
I could omit it, but it improves performance for the cache impls I tested, and calling contiguous multiple times has almost zero cost (if the tensor is already contiguous it just clones an Arc).
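To back up the "almost zero cost" point, a tiny standalone check (my own sketch, not part of the PR): calling contiguous on an already-contiguous tensor is essentially a handle clone, and only a non-contiguous view (e.g. a transpose) triggers an actual copy:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((3, 4))?;

    // Freshly created tensors are contiguous, so this is basically a cheap
    // clone of the underlying storage handle.
    assert!(t.is_contiguous());
    let t2 = t.contiguous()?;
    assert!(t2.is_contiguous());

    // A transpose only changes strides; here contiguous actually has to
    // materialize a copy.
    let tt = t.t()?;
    assert!(!tt.is_contiguous());
    let tt2 = tt.contiguous()?;
    assert!(tt2.is_contiguous());
    Ok(())
}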
    fn append_with_mask(
        &mut self,
        k: &Tensor,
        v: &Tensor,
        _mask: Option<&Self::Mask>,
    ) -> Result<(Tensor, Tensor)> {
What's the point of having mask being an Option here when we already have append[_without_mask]?
True.
I initially added optional mask to append and later realized it is better expressed as a separate fn.
    }

    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.append_with_mask(k, v, None)
See, here we could just call candle::bail! instead of having to do that in append_with_mask. Feels a bit clunky.
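To make the two comments above concrete, here is one possible reading, sketched (this is not the code in the PR; the reshaped trait and NoMaskCache are hypothetical): drop the Option, keep append and append_with_mask as separate methods, and let variants without mask support bail inside the masked one.

use candle::{Result, Tensor};

pub trait KvCache {
    type Mask;
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
    fn append_with_mask(
        &mut self,
        k: &Tensor,
        v: &Tensor,
        mask: &Self::Mask,
    ) -> Result<(Tensor, Tensor)>;
}

// Stand-in for a cache variant that has no notion of masks.
struct NoMaskCache;

impl KvCache for NoMaskCache {
    type Mask = Tensor;

    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        // ... the normal, unmasked append logic would go here ...
        Ok((k.clone(), v.clone()))
    }

    fn append_with_mask(
        &mut self,
        _k: &Tensor,
        _v: &Tensor,
        _mask: &Self::Mask,
    ) -> Result<(Tensor, Tensor)> {
        // Bail directly here instead of threading an Option through append.
        candle::bail!("this cache variant does not support masked appends")
    }
}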
#[derive(Debug, Clone)]
- pub struct Cache {
+ pub struct InnerCache {
I think Cache was ok as a name, although I get the sentiment
Yeah, it's only used in one kv cache variant, and not even the default one, so it feels wrong that it should be the Cache.
}

- impl LayerWeights {
+ impl<C: KvCache> LayerWeights<C> {
Does it need to be generic? It feels like we're expecting a specific behaviour from the cache
Sometimes you want the KvCache variant that strictly has the highest throughput; sometimes you want one that is more careful about how it consumes memory; etc.
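To illustrate the call-site choice this enables, a sketch (it assumes the PR's generic ModelWeights and a quantized-llama-style forward(x, index_pos) signature; BoundedKvCache is a hypothetical name, not a type in the PR):

use candle::{Result, Tensor};

// ConcatKvCache is the PR's default; BoundedKvCache stands in for a
// hypothetical pre-allocated / memory-bounded variant.
type ThroughputModel = ModelWeights<ConcatKvCache>;   // fastest, grows until reset
type MemoryBoundModel = ModelWeights<BoundedKvCache>; // fixed memory budget

// The model code itself is written once over any C: KvCache; callers pick
// the caching strategy via the type parameter or a type alias.
fn decode_step<C: KvCache>(
    model: &mut ModelWeights<C>,
    x: &Tensor,
    index_pos: usize,
) -> Result<Tensor> {
    model.forward(x, index_pos)
}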
#[derive(Debug, Clone)]
- pub struct ModelWeights {
+ pub struct ModelWeights<C: KvCache = DefaultKvCache> {
I assume this means "if not specified, use DefaultKvCache"?
Yes!
By default you get ConcatKvCache, which has the highest throughput but grows indefinitely (until you reset() in the model impl)
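For readers who haven't seen one, here is a minimal sketch of what a concat-style cache does: grow by Tensor::cat along the sequence dimension, never pre-allocate, and drop everything on reset. This is my own illustration of the idea, not the PR's actual ConcatKvCache.

use candle::{Result, Tensor};

#[derive(Debug, Clone)]
pub struct ConcatCacheSketch {
    k: Option<Tensor>,
    v: Option<Tensor>,
    dim: usize, // sequence dimension to concatenate along
}

impl ConcatCacheSketch {
    pub fn new(dim: usize) -> Self {
        Self { k: None, v: None, dim }
    }

    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        // Every append rebuilds the cached tensor, so memory grows without
        // bound until reset is called.
        self.k = Some(match self.k.take() {
            None => k.clone(),
            Some(prev) => Tensor::cat(&[&prev, k], self.dim)?,
        });
        self.v = Some(match self.v.take() {
            None => v.clone(),
            Some(prev) => Tensor::cat(&[&prev, v], self.dim)?,
        });
        // Both fields were just set above, so the unwraps are safe.
        Ok((self.k.clone().unwrap(), self.v.clone().unwrap()))
    }

    pub fn reset(&mut self) {
        self.k = None;
        self.v = None;
    }
}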
Don't know how visible attention.rs might be to the larger ecosystem, but the KVCache implementation therein appears to work quite well for older and modern (NV at least) hardware, and it already has variable data type and offloaded scaling support. Given the tight coupling between attention and kv cache, is it worth considering the entire library as a candle ecosystem candidate? (It would need some "genericization" for ease of adoption by other projects.)
Yep! I actually mentioned this PR to @guoqingbao just the other day for this very reason.
Thanks for including me. I'm wondering whether we can pre-allocate a large KV-cache tensor and copy each new KV-cache slice into the existing tensor. This would avoid repeatedly concatenating and destroying tensors, especially as the context length grows significantly. A skeleton:

use candle::{DType, Device, Result, Tensor};

pub struct KvCache {
    pub k: Tensor, // [max_seq_len, num_heads, head_dim]
    pub v: Tensor, // [max_seq_len, num_heads, head_dim]
    pub max_seq_len: usize,
    pub cur_len: usize,
}

impl KvCache {
    /// Preallocate an empty KV cache for one layer.
    pub fn new(
        max_seq_len: usize,
        num_heads: usize,
        head_dim: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        let shape = (max_seq_len, num_heads, head_dim);
        Ok(Self {
            k: Tensor::zeros(shape, dtype, device)?,
            v: Tensor::zeros(shape, dtype, device)?,
            max_seq_len,
            cur_len: 0,
        })
    }

    /// Insert the new KV slice for the next token.
    /// new_k/new_v shape: [num_heads, head_dim]
    pub fn insert(&mut self, new_k: &Tensor, new_v: &Tensor) -> Result<()> {
        let pos = self.cur_len;
        if pos >= self.max_seq_len {
            candle::bail!("KV cache full: need a resize or a rotating window");
        }
        // In-place copies into the pre-allocated buffers, i.e. k[pos] = new_k
        // and v[pos] = new_v (via Tensor::slice_set along the sequence dim).
        self.k.slice_set(&new_k.unsqueeze(0)?, 0, pos)?;
        self.v.slice_set(&new_v.unsqueeze(0)?, 0, pos)?;
        self.cur_len += 1;
        Ok(())
    }
}

I think we may need to update the current SDPA attention to support this layout, or use the existing FlashAttention implementation, which accepts tensors shaped as
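As a quick usage sketch of that skeleton (my own addition; it assumes the KvCache struct above and reads the valid prefix back with Tensor::narrow):

use candle::{DType, Device, Result, Tensor};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    let (max_seq_len, num_heads, head_dim) = (16, 8, 64);
    let mut cache = KvCache::new(max_seq_len, num_heads, head_dim, &dev, DType::F32)?;

    // One decoding step: insert the new token's keys/values.
    let new_k = Tensor::zeros((num_heads, head_dim), DType::F32, &dev)?;
    let new_v = Tensor::zeros((num_heads, head_dim), DType::F32, &dev)?;
    cache.insert(&new_k, &new_v)?;

    // Attention should only see the filled prefix, not the zero padding.
    let k = cache.k.narrow(0, 0, cache.cur_len)?; // [cur_len, num_heads, head_dim]
    let v = cache.v.narrow(0, 0, cache.cur_len)?;
    assert_eq!(k.dims(), &[1, num_heads, head_dim]);
    assert_eq!(v.dims(), &[1, num_heads, head_dim]);
    Ok(())
}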
Just to add some context here: some of the impetus for this generic KvCache trait came from my benchmarks showing that ConcatKvCache (and the existing custom cache for Llama3; neither pre-allocates) was 2-5x faster on CUDA compared to the previous pre-allocated + slice approach, with the gap growing as sequence length increased (tested at 300-2000 tokens). On CPU and WASM32, concat matched pre-allocation performance.
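For anyone who wants to poke at the concat vs pre-allocation trade-off locally, a rough CPU-only micro-benchmark sketch (my own, not the benchmark used above; it assumes Tensor::slice_set for the in-place copy, the shapes are arbitrary, and on CPU the two should come out close, as noted above, while a fair CUDA comparison would also need device synchronization):

use std::time::Instant;

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // One decode step's keys: [batch, num_heads, 1, head_dim] (arbitrary sizes).
    let step = Tensor::zeros((1, 32, 1, 128), DType::F32, &dev)?;
    let n_tokens = 1024;

    // Concat-style cache: rebuild the cached tensor at every step.
    let t0 = Instant::now();
    let mut cache = step.clone();
    for _ in 1..n_tokens {
        cache = Tensor::cat(&[&cache, &step], 2)?;
    }
    println!("concat:    {:?} -> {:?}", t0.elapsed(), cache.shape());

    // Pre-allocated cache: copy each step into a fixed buffer in place.
    let t0 = Instant::now();
    let buf = Tensor::zeros((1, 32, n_tokens, 128), DType::F32, &dev)?;
    for pos in 0..n_tokens {
        buf.slice_set(&step, 2, pos)?;
    }
    println!("slice_set: {:?} -> {:?}", t0.elapsed(), buf.shape());
    Ok(())
}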
Interesting - I wonder how that concat performance pans out on older HW like V100s where random/async access isn't the same.

At a high level (with software architect hat on sideways while I grok the machinery of this ecosystem), my hope is that attention implementations would reside within attention.rs, and the various KV cache implementations, being nearly integral to attention mechanics, would live there or in a tightly coupled sister crate, providing a common abstraction to consumers in the ecosystem. CUDA and BlackHole RISC-V accelerators don't necessarily handle data in the same layouts or blocks, so having an interoperable KV construct that maps to attention implementations suited to the layout of different hardware classes (or their compiler/input abstractions) should let the ecosystem matrix-test, benchmark, and qualify accuracy across those as new hardware classes become supported and attention implementations and cache layouts are added. Since @guoqingbao works with all sorts of hardware, that library is less likely (IMO) to suffer from the "new and shiny causes performance regressions on actually available" problem we often see with CUDA itself, while wider adoption would provide the FOSS focus needed to help keep it in sync with upstream.

Reasoning behind that train of thought is that attention efficacy is absolutely critical to output quality, and no amount of performance benefit adds up to the time loss (or in the worst case, loss of life - critical paths exist and they are adopting) resulting from "workslop." We just did a panel with NV and one of the CSPs in Boston about what it will take for adoption in healthcare, and everyone at the event with industry background concluded that scientific-grade reproducibility and accuracy are prerequisite to any formal adoption of language models and attention-based architectures in the decision-making paths of research or healthcare in lifesci. Some use cases effectively don't care about performance when the criticality of results mandates getting it right the first time.
Thanks for the additional context; that's really helpful.
Just for full transparency: we have already had a concat-based kv cache in several models for a while (ref), it just wasn't moved into its own reusable struct. Which was very much needed! 👍 Most of the latency of the current
This should let us (and users of candle) easily switch out KvCache implementations in models.
I explored using dyn KvCache instead, but it quickly became a pain, so generics it is.
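On the dyn KvCache point: the trait as proposed has a constructor returning Self and an associated Mask type, which is exactly what makes trait objects awkward. A minimal illustration (my sketch, using just the items from the diff at the top of this PR):

use candle::{Result, Tensor};

// Same shape as the trait in this PR: a constructor that returns Self.
pub trait KvCache {
    type Mask;
    fn new(dim: usize, max_seq_len: usize) -> Self;
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
}

// This does not compile: `new` returns `Self` without a `where Self: Sized`
// bound, so the trait is not dyn-compatible, and the associated type would
// also have to be pinned down (e.g. `dyn KvCache<Mask = Tensor>`).
// fn boxed_cache(dim: usize, max_seq_len: usize) -> Box<dyn KvCache<Mask = Tensor>> { ... }

// With generics the constructor is unproblematic and dispatch is static:
fn make_cache<C: KvCache>(dim: usize, max_seq_len: usize) -> C {
    C::new(dim, max_seq_len)
}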