Generic KvCache #3188
base: main
Changes from all commits
b68f8cb
62353eb
572c09f
459539f
60e1a01
e8c6563
6aa56be
459f191
e50656d
e3feb24
1f0cae4
9cee4bd
@@ -2,26 +2,43 @@
 //!
 use candle::{DType, Device, Result, Tensor};

+pub type DefaultKvCache = ConcatKvCache;
+
+pub trait KvCache {
+    type Mask;
+    fn new(dim: usize, max_seq_len: usize) -> Self;
+    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
+    fn append_with_mask(
+        &mut self,
+        k: &Tensor,
+        v: &Tensor,
+        _mask: Option<&Self::Mask>,
+    ) -> Result<(Tensor, Tensor)> {
Comment on lines +11 to +16

Member: What's the point of having …

Member (Author): True.
+        self.append(k, v)
+    }
+    fn reset(&mut self);
+}
+
 #[derive(Debug, Clone)]
-pub struct Cache {
+pub struct InnerCache {
Member: I think …

Member (Author): Yeah, it's only used in one kv cache variant, and not even the default one, so it feels wrong that it should be the …
     // all_data is an option on a Tensor, this makes it possible to only create the actual tensor
     // on the first call where the batch size is easily known.
     // Also this makes it safe to clone a KvCache that has been reset (as in it will not share
     // its internal state with the cloned instance).
     all_data: Option<Tensor>,
     dim: usize,
     current_seq_len: usize,
-    grow_by: usize,
+    increment: usize,
Member: Why the change?
     max_seq_len: usize,
 }

-impl Cache {
+impl InnerCache {
     pub fn new(dim: usize, max_seq_len: usize) -> Self {
         Self {
             all_data: None,
             dim,
             current_seq_len: 0,
-            grow_by: max_seq_len,
+            increment: max_seq_len,
             max_seq_len,
         }
     }
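Because appending is now behind a trait, model code can be written once and instantiated with whichever cache strategy fits. A minimal sketch of that pattern, assuming the types end up exported from `candle_nn::kv_cache` as in this diff; the `GenericAttention` struct and its methods are purely illustrative:

```rust
use candle::{Result, Tensor};
use candle_nn::kv_cache::{ConcatKvCache, IncrementalKvCache, KvCache};

// Hypothetical attention state, generic over the cache strategy.
struct GenericAttention<C: KvCache> {
    cache: C,
}

impl<C: KvCache> GenericAttention<C> {
    fn new(dim: usize, max_seq_len: usize) -> Self {
        Self {
            cache: C::new(dim, max_seq_len),
        }
    }

    // Mask-free callers only need `append`; implementations that require a mask
    // (ScatteredKvCache) are reached through `append_with_mask` instead.
    fn extend_cache(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.cache.append(k, v)
    }
}

// The concrete strategy is chosen at the type level:
type ConcatAttention = GenericAttention<ConcatKvCache>;
type PreallocatedAttention = GenericAttention<IncrementalKvCache>;
```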
@@ -68,10 +85,10 @@ impl Cache {
         let ad = self.all_data.as_mut().unwrap();
         while self.current_seq_len + seq_len > self.max_seq_len {
             let mut shape = src.dims().to_vec();
-            shape[self.dim] = self.grow_by;
+            shape[self.dim] = self.increment;
             let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
             *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
-            self.max_seq_len += self.grow_by;
+            self.max_seq_len += self.increment;
         }
         ad.slice_set(src, self.dim, self.current_seq_len)?;
         self.current_seq_len += seq_len;
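For the renamed field, a rough illustration of what `increment` controls: when an append would overflow the current buffer, the backing tensor is extended by `increment` positions at a time. A sketch only; the sizes are arbitrary and the `IncrementalKvCache` name assumes this diff lands as-is:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::IncrementalKvCache;

fn growth_demo() -> Result<()> {
    // Concatenate along dim 2 (the sequence axis), pre-allocating 4 positions.
    let mut cache = IncrementalKvCache::new(2, 4);
    let chunk = Tensor::zeros((1, 1, 3, 8), DType::F32, &Device::Cpu)?;

    // The first 3 positions fit in the initial allocation.
    let (k, _v) = cache.append(&chunk, &chunk)?;
    assert_eq!(k.dim(2)?, 3);

    // The next append would need 6 positions, so the buffer grows by
    // `increment` (= the initial max_seq_len) to 8 before writing.
    let (k, _v) = cache.append(&chunk, &chunk)?;
    assert_eq!(k.dim(2)?, 6);
    Ok(())
}
```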
@@ -80,31 +97,46 @@ impl Cache {
 }

 #[derive(Debug, Clone)]
-pub struct KvCache {
-    k: Cache,
-    v: Cache,
+pub struct IncrementalKvCache {
+    k: InnerCache,
+    v: InnerCache,
 }

-impl KvCache {
+impl KvCache for IncrementalKvCache {
+    type Mask = ();
+    fn new(dim: usize, max_seq_len: usize) -> Self {
+        IncrementalKvCache::new(dim, max_seq_len)
+    }
+
+    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+        self.append(k, v)
+    }
+
+    fn reset(&mut self) {
+        self.reset()
+    }
+}
+
+impl IncrementalKvCache {
     pub fn new(dim: usize, max_seq_len: usize) -> Self {
-        let k = Cache::new(dim, max_seq_len);
-        let v = Cache::new(dim, max_seq_len);
+        let k = InnerCache::new(dim, max_seq_len);
+        let v = InnerCache::new(dim, max_seq_len);
         Self { k, v }
     }

-    pub fn k_cache(&self) -> &Cache {
+    pub fn k_cache(&self) -> &InnerCache {
         &self.k
     }

-    pub fn v_cache(&self) -> &Cache {
+    pub fn v_cache(&self) -> &InnerCache {
         &self.v
     }

-    pub fn k_cache_mut(&mut self) -> &mut Cache {
+    pub fn k_cache_mut(&mut self) -> &mut InnerCache {
         &mut self.k
     }

-    pub fn v_cache_mut(&mut self) -> &mut Cache {
+    pub fn v_cache_mut(&mut self) -> &mut InnerCache {
         &mut self.v
     }

@@ -117,8 +149,8 @@ impl KvCache {
     }

     pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
-        self.k.append(k)?;
-        self.v.append(v)?;
+        self.k.append(&k.contiguous()?)?;
Member: Curious to know what the rationale for adding …

Member (Author): It's the default behaviour in the model code I used as reference. Is there any reason it shouldn't call contiguous?
+        self.v.append(&v.contiguous()?)?;
         let out_k = self.k.current_data()?;
         let out_v = self.v.current_data()?;
         let k = match out_k {
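On the contiguous question: in most attention code the k/v chunks arrive as transposed views, so they are not laid out contiguously by the time they reach the cache, which then writes them into a pre-allocated buffer with `slice_set`. A small sketch of that layout situation, illustrative only; whether `slice_set` strictly requires contiguous input is not settled by this diff:

```rust
use candle::{DType, Device, Result, Tensor};

fn layout_demo() -> Result<()> {
    // (batch, seq, heads, head_dim), as it comes out of a projection + reshape.
    let x = Tensor::zeros((1, 5, 4, 16), DType::F32, &Device::Cpu)?;

    // Attention code usually moves the head axis forward before caching.
    let k = x.transpose(1, 2)?; // (batch, heads, seq, head_dim)
    assert!(!k.is_contiguous()); // a transpose is only a strided view

    // Materialising the layout up front means the cache's slice_set/cat calls
    // never see a strided view.
    let k = k.contiguous()?;
    assert!(k.is_contiguous());
    Ok(())
}
```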
@@ -338,6 +370,21 @@ pub struct RotatingKvCache {
     v: RotatingCache,
 }

+impl KvCache for RotatingKvCache {
+    type Mask = ();
+    fn new(dim: usize, max_seq_len: usize) -> Self {
+        RotatingKvCache::new(dim, max_seq_len)
+    }
+
+    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+        self.append(k, v)
+    }
+
+    fn reset(&mut self) {
+        self.reset()
+    }
+}
+
 impl RotatingKvCache {
     pub fn new(dim: usize, max_seq_len: usize) -> Self {
         let k = RotatingCache::new(dim, max_seq_len);
@@ -414,34 +461,90 @@ impl IndicesAndMask {

 #[derive(Debug, Clone)]
 pub struct ScatteredKvCache {
-    k: Tensor,
-    v: Tensor,
+    k: Option<Tensor>,
+    v: Option<Tensor>,
+    dim: usize,
     context: usize,
 }
+
+impl KvCache for ScatteredKvCache {
+    type Mask = IndicesAndMask;
+
+    fn new(dim: usize, max_seq_len: usize) -> Self {
+        ScatteredKvCache::new(dim, max_seq_len)
+    }
+
+    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+        self.append_with_mask(k, v, None)
Member: see here we could just call …
+    }
+
+    fn append_with_mask(
+        &mut self,
+        k: &Tensor,
+        v: &Tensor,
+        mask: Option<&Self::Mask>,
+    ) -> Result<(Tensor, Tensor)> {
+        if let Some(mask) = mask {
+            self.scattered_append(k, v, mask)
+        } else {
+            candle::bail!("ScatteredKvCache requires IndicesAndMask")
+        }
+    }
+
+    fn reset(&mut self) {
+        self.reset()
+    }
+}

 impl ScatteredKvCache {
-    pub fn append(
+    pub fn new(dim: usize, context: usize) -> Self {
+        Self {
+            k: None,
+            v: None,
+            dim,
+            context,
+        }
+    }
+
+    pub fn scattered_append(
         &mut self,
         k: &Tensor,
         v: &Tensor,
         iam: &IndicesAndMask,
     ) -> Result<(Tensor, Tensor)> {
-        if self.context <= k.dim(2)? {
+        if self.context <= k.dim(self.dim)? {
             return Ok((k.clone(), v.clone()));
         }
-        let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
+        if self.k.is_none() {
+            let mut k_shape = k.dims().to_vec();
+            k_shape[self.dim] = self.context;
+            self.k = Some(Tensor::zeros(k_shape.clone(), k.dtype(), k.device())?);
+        }
+        if self.v.is_none() {
+            let mut v_shape = v.dims().to_vec();
+            v_shape[self.dim] = self.context;
+            self.v = Some(Tensor::zeros(v_shape.clone(), v.dtype(), v.device())?);
+        }
+
+        let indices = iam.indices.unsqueeze(self.dim)?.unsqueeze(1)?;
         let indices = indices.broadcast_as(k.shape())?.contiguous()?;
-        self.k.scatter_set(&indices, k, 2)?;
-        self.v.scatter_set(&indices, v, 2)?;
-        Ok((self.k.clone(), self.v.clone()))
+        let new_k = self.k.as_mut().unwrap();
+        let new_v = self.v.as_mut().unwrap();
+        new_k.scatter_set(&indices, k, self.dim)?;
+        new_v.scatter_set(&indices, v, self.dim)?;
+        Ok((new_k.clone(), new_v.clone()))
     }

-    pub fn k(&self) -> &Tensor {
-        &self.k
+    pub fn k(&self) -> Option<&Tensor> {
+        self.k.as_ref()
     }

-    pub fn v(&self) -> &Tensor {
-        &self.v
+    pub fn v(&self) -> Option<&Tensor> {
+        self.v.as_ref()
     }

+    pub fn reset(&mut self) {
+        self.k = None;
+        self.v = None;
+    }
 }

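Callers that do have an `IndicesAndMask` go through `append_with_mask`, which is the only path that reaches `scattered_append` via the trait. A rough usage sketch; the `scattered_step` helper and the way the mask is obtained are illustrative, only `append_with_mask` and its bail-on-`None` behaviour come from this diff:

```rust
use candle::{Result, Tensor};
use candle_nn::kv_cache::{IndicesAndMask, KvCache, ScatteredKvCache};

// Hypothetical helper: run one decoding step against a scattered cache.
fn scattered_step(
    cache: &mut ScatteredKvCache,
    k: &Tensor,
    v: &Tensor,
    iam: &IndicesAndMask, // produced elsewhere, e.g. by ScatteredCacheBuilder::indices_and_mask
) -> Result<(Tensor, Tensor)> {
    // Goes through the trait method; passing None instead would hit the
    // "ScatteredKvCache requires IndicesAndMask" bail above.
    cache.append_with_mask(k, v, Some(iam))
}
```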
@@ -469,16 +572,8 @@ impl ScatteredCacheBuilder {
         })
     }

-    pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
-        let batch_size = self.batch_size();
-        let shape = (batch_size, num_heads, self.context, head_dim);
-        let k = Tensor::zeros(shape, self.dtype, self.device())?;
-        let v = Tensor::zeros(shape, self.dtype, self.device())?;
-        Ok(ScatteredKvCache {
-            k,
-            v,
-            context: self.context,
-        })
+    pub fn make_cache(&self, head_dim: usize) -> ScatteredKvCache {
+        ScatteredKvCache::new(head_dim, self.context)
     }

     pub fn positions(&self) -> &[usize] {
@@ -499,7 +594,6 @@ impl ScatteredCacheBuilder {
         self.indices[batch_index] = 0;
     }

-    #[allow(clippy::needless_range_loop)]
     pub fn indices_and_mask(
         &mut self,
         seq_len: usize,
@@ -525,18 +619,26 @@ impl ScatteredCacheBuilder {
         let mut indices = Vec::with_capacity(seq_len);
         let mut all_pos = vec![usize::MAX; context];
         if start_pos < context {
-            for i in 0..start_pos {
-                all_pos[i] = i;
-            }
+            all_pos
+                .iter_mut()
+                .enumerate()
+                .take(start_pos)
+                .for_each(|(i, p)| {
+                    *p = i;
+                });
         } else {
             let offset = start_pos - start_index;
-            for i in 0..context {
-                all_pos[i] = if i < start_index {
-                    i + offset
-                } else {
-                    i + offset - context
-                };
-            }
+            all_pos
+                .iter_mut()
+                .enumerate()
+                .take(context)
+                .for_each(|(i, p)| {
+                    *p = if i < start_index {
+                        i + offset
+                    } else {
+                        i + offset - context
+                    };
+                });
         }
         for seq_i in 0..seq_len {
             let index = self.indices[batch_i];
@@ -584,7 +686,6 @@ impl ScatteredCacheBuilder {
         &self.device
     }

-    #[allow(clippy::needless_range_loop)]
     fn indices_and_mask_abs(
         &mut self,
         seq_len: usize,
@@ -642,7 +743,7 @@ impl ScatteredCacheBuilder {
 /// - GPU inference (CUDA, Metal)
 /// - Autoregressive generation (token-by-token decoding)
 ///
-/// **Use `KvCache` instead for:**
+/// **Use `IncrementalKvCache` instead for:**
 /// - CPU-only inference
 /// - When you need fixed memory allocation upfront
 ///
@@ -670,6 +771,22 @@ pub struct ConcatKvCache {
     dim: usize,
 }

+impl KvCache for ConcatKvCache {
+    type Mask = ();
+
+    fn new(dim: usize, _: usize) -> Self {
+        ConcatKvCache::new(dim)
+    }
+
+    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+        self.append(k, v)
+    }
+
+    fn reset(&mut self) {
+        self.reset()
+    }
+}
+
 impl ConcatKvCache {
     /// Create a new empty concatenation-based KV-cache
     ///
Does `append` insert into the cache?

Probably worth adding some doc strings here.

I guess it can depend on the cache implementation, but typically you append to the kv cache. It's not like your typical cache à la hashmap where you insert by key.
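Following up on the doc-string request, one possible wording for the trait; illustrative only and not part of the diff:

```rust
use candle::{Result, Tensor};

/// A key/value cache that accumulates keys and values across decoding steps.
pub trait KvCache {
    /// Extra per-step information some implementations need
    /// (`IndicesAndMask` for `ScatteredKvCache`); `()` when unused.
    type Mask;

    /// Create a cache that concatenates along dimension `dim` and targets a
    /// capacity of `max_seq_len` positions (implementations may grow past it
    /// or ignore it entirely).
    fn new(dim: usize, max_seq_len: usize) -> Self;

    /// Append the latest `k`/`v` chunks and return the full cached keys and
    /// values accumulated so far. Appending extends the sequence axis; there
    /// is no keyed insertion as in a map-style cache.
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;

    /// Like `append`, but accepts an optional mask for implementations that
    /// need one; the default implementation simply ignores it.
    fn append_with_mask(
        &mut self,
        k: &Tensor,
        v: &Tensor,
        _mask: Option<&Self::Mask>,
    ) -> Result<(Tensor, Tensor)> {
        self.append(k, v)
    }

    /// Clear the cache so the next `append` starts from an empty state.
    fn reset(&mut self);
}
```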