
Commit bfdb834

add kvcache
Signed-off-by: Alex Chi Z <[email protected]>
1 parent 16d2aa9 commit bfdb834

11 files changed, +142 -50 lines

README.md (1 addition, 1 deletion)

@@ -29,7 +29,7 @@ You may join skyzh's Discord server and study with the tiny-llm community.
 | 1.5 | Transformer Block || 🚧 | 🚧 |
 | 1.6 | Load the Model || 🚧 | 🚧 |
 | 1.7 | Generate Responses ||| 🚧 |
-| 2.1 | KV Cache | 🚧 | 🚧 | 🚧 |
+| 2.1 | KV Cache | | 🚧 | 🚧 |
 | 2.2 | Quantized Matmul and Linear (CPU) | 🚧 | 🚧 | 🚧 |
 | 2.3 | Quantized Matmul and Linear (Metal) | 🚧 | 🚧 | 🚧 |
 | 2.4 | Attention Kernel | 🚧 | 🚧 | 🚧 |

book/src/week1-01-attention.md (1 addition, 0 deletions)

@@ -83,6 +83,7 @@ src/tiny_llm/attention.py
 * [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
 * [PyTorch MultiHeadAttention API](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
 * [MLX MultiHeadAttention API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
+* [The Illustrated GPT-2 (Visualizing Transformer Language Models)](https://jalammar.github.io/illustrated-gpt2) helps you better understand what key, value, and query are.
 
 Implement `MultiHeadAttention`. The layer takes a batch of vectors `x`, maps it through the K, V, Q weight matrixes, and
 use the attention function we implemented in day 1 to compute the result. The output needs to be mapped using the O

main_ref_impl_week2.py (new file: 21 additions, 0 deletions)

@@ -0,0 +1,21 @@
+from mlx_lm import load
+from tiny_llm_week2_ref import Qwen2Model, simple_generate
+import mlx.core as mx
+
+with mx.stream(mx.gpu):
+    mlx_model, tokenizer = load(
+        "Qwen/Qwen2-7B-Instruct-MLX",
+        tokenizer_config={"eos_token": "<|im_end|>"},
+        model_config={"tie_word_embeddings": False, "rope_traditional": True},
+    )
+    tiny_llm_model = Qwen2Model(mlx_model)
+
+    prompt = "Give me a short introduction to large language model."
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": prompt},
+    ]
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    simple_generate(tiny_llm_model, tokenizer, prompt)

src/tiny_llm_week1_ref/attention.py (2 additions, 1 deletion)

@@ -23,7 +23,8 @@ def scaled_dot_product_attention_grouped(
     scale: float | None = None,
     mask: mx.array | None = None,
 ) -> mx.array:
-    factor = mx.rsqrt(query.shape[-1]) if scale is None else scale
+    factor = mx.rsqrt(query.shape[-1]) if scale is None else mx.array(scale)
+    factor = factor.astype(query.dtype)
     expected_shape = query.shape
     query = query.reshape(-1, query.shape[-3], query.shape[-2], query.shape[-1])
     key = key.reshape(-1, key.shape[-3], key.shape[-2], key.shape[-1])
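The cast matters for mixed-precision inputs: `mx.rsqrt` of a Python int produces a float32 scalar, so multiplying a float16 query by it would silently promote the attention scores to float32. A minimal sketch of the effect, not part of the commit (the shapes are made up for illustration):

```python
import mlx.core as mx

# Hypothetical float16 query: (batch=1, heads=2, seq=4, head_dim=8).
query = mx.random.normal((1, 2, 4, 8)).astype(mx.float16)

# Without the fix: mx.rsqrt(int) yields a float32 scalar, and the product
# promotes the scores to float32.
promoted = query * mx.rsqrt(query.shape[-1])
print(promoted.dtype)  # float32

# With the fix: the scaling factor is cast to the query dtype first, so the
# whole attention computation stays in the original precision.
factor = mx.rsqrt(query.shape[-1]).astype(query.dtype)
kept = query * factor
print(kept.dtype)  # float16
```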

src/tiny_llm_week1_ref/generate.py (1 addition, 2 deletions)

@@ -14,12 +14,11 @@ def _step(model, y, offset):
 
     # prefill with the prompt
     tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False))
-    offset = tokens.size
     detokenizer = tokenizer.detokenizer
     detokenizer.reset()
     # generate/decode
     while True:
-        token, _ = _step(model, tokens, offset)
+        token, _ = _step(model, tokens, tokens.size)
         tokens = mx.concat([tokens, token])
         if token.item() == tokenizer.eos_token_id:
             break

src/tiny_llm_week1_ref/positional_encoding.py (1 addition, 1 deletion)

@@ -54,4 +54,4 @@ def __call__(
         else:
             y = mx.concat([real, imag], axis=-1)
         y = y.reshape(N, S, H, D)
-        return y
+        return y.astype(x.dtype)

src/tiny_llm_week1_ref/qwen2.py (10 additions, 21 deletions)

@@ -50,37 +50,26 @@ def __call__(
         offset: int,
     ) -> mx.array:
         B, L, _ = x.shape
-        orig_dtype = x.dtype
-        projection_q = (
-            linear(x, self.wq, bias=self.bq)
-            .reshape(B, L, self.num_heads, self.head_dim)
-            .astype(mx.float32)
+        projection_q = linear(x, self.wq, bias=self.bq).reshape(
+            B, L, self.num_heads, self.head_dim
         )
-        projection_k = (
-            linear(x, self.wk, bias=self.bk)
-            .reshape(B, L, self.num_kv_heads, self.head_dim)
-            .astype(mx.float32)
+        projection_k = linear(x, self.wk, bias=self.bk).reshape(
+            B, L, self.num_kv_heads, self.head_dim
         )
-        projection_v = (
-            linear(x, self.wv, bias=self.bv)
-            .reshape(B, L, self.num_kv_heads, self.head_dim)
-            .astype(mx.float32)
+        projection_v = linear(x, self.wv, bias=self.bv).reshape(
+            B, L, self.num_kv_heads, self.head_dim
         )
         projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
         projection_k = self.rope(projection_k, offset=slice(offset, offset + L))
         projection_q = projection_q.transpose(0, 2, 1, 3)
         projection_k = projection_k.transpose(0, 2, 1, 3)
         projection_v = projection_v.transpose(0, 2, 1, 3)
-        assert (
-            projection_k.dtype == mx.float32
-        ) # TODO: can we use float16? also a test framework to ensure all data types are casted correctly.
-        assert projection_v.dtype == mx.float32
         x = scaled_dot_product_attention_grouped(
-            projection_q,
-            projection_k,
-            projection_v,
+            projection_q.astype(mx.float32),
+            projection_k.astype(mx.float32),
+            projection_v.astype(mx.float32),
             scale=self.scale,
-        ).astype(orig_dtype)
+        ).astype(x.dtype)
         x = x.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size)
         return linear(x, self.wo)
 
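The net effect of this change: the Q/K/V projections keep the model dtype (for example float16), are upcast to float32 only for the attention computation itself, and the result is cast back to the input's dtype before the output projection. A minimal sketch of that "upcast, compute, cast back" pattern (hypothetical standalone helper, not the commit's code):

```python
import mlx.core as mx

def attention_in_float32(q: mx.array, k: mx.array, v: mx.array, scale: float) -> mx.array:
    # Hypothetical helper: run the numerically sensitive softmax(Q K^T) V step
    # in float32, then return to the activations' original (e.g. float16) dtype.
    orig_dtype = q.dtype
    scores = (q.astype(mx.float32) @ k.astype(mx.float32).swapaxes(-2, -1)) * scale
    out = mx.softmax(scores, axis=-1) @ v.astype(mx.float32)
    return out.astype(orig_dtype)
```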

src/tiny_llm_week2_ref/attention.py (2 additions, 1 deletion)

@@ -23,7 +23,8 @@ def scaled_dot_product_attention_grouped(
     scale: float | None = None,
     mask: mx.array | None = None,
 ) -> mx.array:
-    factor = mx.rsqrt(query.shape[-1]) if scale is None else scale
+    factor = mx.rsqrt(query.shape[-1]) if scale is None else mx.array(scale)
+    factor = factor.astype(query.dtype)
     expected_shape = query.shape
     query = query.reshape(-1, query.shape[-3], query.shape[-2], query.shape[-1])
     key = key.reshape(-1, key.shape[-3], key.shape[-2], key.shape[-1])

src/tiny_llm_week2_ref/generate.py (5 additions, 4 deletions)

@@ -1,11 +1,11 @@
 import mlx.core as mx
 from .qwen2 import Qwen2Model
 from mlx_lm.tokenizer_utils import TokenizerWrapper
-from .kv_cache import TinyKvCache
+from .kv_cache import *
 
 
 def simple_generate(model: Qwen2Model, tokenizer: TokenizerWrapper, prompt: str) -> str:
-    kv_cache = [TinyKvCache() for _ in range(model.num_hidden_layers)]
+    kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)]
 
     def _step(model, y, offset):
         logits = model(y[None], offset, kv_cache)
@@ -17,13 +17,14 @@ def _step(model, y, offset):
 
     # prefill with the prompt
     tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False))
-    offset = tokens.size
+    offset = 0
     detokenizer = tokenizer.detokenizer
     detokenizer.reset()
     # generate/decode
     while True:
         token, _ = _step(model, tokens, offset)
-        tokens = mx.concat([tokens, token])
+        offset += tokens.size
+        tokens = token
         if token.item() == tokenizer.eos_token_id:
             break
         detokenizer.add_token(token.item())
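With the cache in place, only the first step feeds the whole prompt; every later step feeds just the newly sampled token and advances `offset` by however many tokens the previous step consumed. A dry run of the new bookkeeping with the model call stubbed out (the 5-token prompt length is made up for illustration):

```python
# Illustrative only: mirrors the `offset += tokens.size; tokens = token`
# bookkeeping from simple_generate, without calling a real model.
prompt_len = 5               # assume the prompt encodes to 5 tokens
tokens_fed, offset = prompt_len, 0
for step in range(4):
    print(f"step {step}: feed {tokens_fed} token(s) at offset {offset}")
    offset += tokens_fed     # the KV cache now holds this many positions
    tokens_fed = 1           # from here on, only the newest token is fed
# step 0: feed 5 token(s) at offset 0   (prefill)
# step 1: feed 1 token(s) at offset 5
# step 2: feed 1 token(s) at offset 6
# step 3: feed 1 token(s) at offset 7
```

Contrast with the week 1 loop, which re-feeds the whole growing token sequence on every step because there is no cache to remember past keys and values.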

src/tiny_llm_week2_ref/kv_cache.py (85 additions, 1 deletion)

@@ -4,5 +4,89 @@
 
 
 class TinyKvCache:
-    def update_and_fetch(self, key: mx.array, value: mx.array, offset: int) -> mx.array:
+    def update_and_fetch(
+        self, key: mx.array, value: mx.array, offset: int
+    ) -> tuple[mx.array, mx.array]:
         pass
+
+
+class TinyKvFullCache(TinyKvCache):
+    def __init__(self):
+        self.key_values = None
+
+    def update_and_fetch(
+        self, key: mx.array, value: mx.array, offset: int
+    ) -> tuple[mx.array, mx.array]:
+        if self.key_values is None:
+            assert offset == 0
+            self.key_values = (key, value)
+            return key, value
+        else:
+            B, H, _, D = key.shape
+            assert key.shape == value.shape
+            prev_keys, prev_values = self.key_values
+            assert prev_keys.shape == (B, H, offset, D)
+            assert prev_values.shape == (B, H, offset, D)
+            new_keys = mx.concat([prev_keys, key], axis=2)
+            new_values = mx.concat([prev_values, value], axis=2)
+            self.key_values = (new_keys, new_values)
+            return new_keys, new_values
+
+
+class TinyKvRotatingCache(TinyKvCache):
+    def __init__(self, max_seq_len: int):
+        self.max_seq_len = max_seq_len
+        self.key_values = None
+        self.head = 0
+        self.head_offset = 0
+
+    def update_and_fetch(
+        self, key: mx.array, value: mx.array, offset: int
+    ) -> tuple[mx.array, mx.array]:
+        if self.key_values is None:
+            assert offset == 0
+            B, H, L, D = key.shape
+            assert L <= self.max_seq_len
+            keys = mx.zeros((B, H, self.max_seq_len, D))
+            values = mx.zeros((B, H, self.max_seq_len, D))
+            keys[:, :, :L, :] = key
+            values[:, :, :L, :] = value
+            self.key_values = (keys, values)
+            self.head = L
+            self.head_offset = L
+            return keys[:, :, :L, :], values[:, :, :L, :]
+        else:
+            B, H, L, D = key.shape
+            assert key.shape == value.shape
+            assert offset == self.head_offset
+            assert L <= self.max_seq_len
+            keys, values = self.key_values
+            if self.head + L <= self.max_seq_len:
+                keys[:, :, self.head : self.head + L, :] = key
+                values[:, :, self.head : self.head + L, :] = value
+                self.head += L
+                self.head_offset += L
+            else:
+                fill_size = self.max_seq_len - self.head
+                keys[:, :, self.head : self.max_seq_len, :] = key[:, :, :fill_size, :]
+                values[:, :, self.head : self.max_seq_len, :] = value[
+                    :, :, :fill_size, :
+                ]
+                remaining_size = L - fill_size
+                keys[:, :, :remaining_size, :] = key[:, :, fill_size:, :]
+                values[:, :, :remaining_size, :] = value[:, :, fill_size:, :]
+                self.head = remaining_size
+                self.head_offset += L
+            self.key_values = (keys, values)
+            if self.head_offset < self.max_seq_len:
+                return keys[:, :, : self.head_offset, :], values[
+                    :, :, : self.head_offset, :
+                ]
+            else:
+                before_keys = keys[:, :, self.head_offset :, :]
+                before_values = values[:, :, self.head_offset :, :]
+                after_keys = keys[:, :, : self.head_offset, :]
+                after_values = values[:, :, : self.head_offset, :]
+                keys = mx.concat([after_keys, before_keys], axis=2)
+                values = mx.concat([after_values, before_values], axis=2)
+                return keys, values
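Only `TinyKvFullCache` is wired into `simple_generate` in this commit; `TinyKvRotatingCache` is a ring-buffer variant capped at `max_seq_len`. A quick sketch of how the full cache grows along the sequence axis across a prefill step and a decode step (the shapes are made up for illustration):

```python
import mlx.core as mx
from tiny_llm_week2_ref.kv_cache import TinyKvFullCache

cache = TinyKvFullCache()

# Prefill: 5 prompt positions enter the cache at offset 0.
# Toy shapes: (batch=1, kv_heads=2, seq_len, head_dim=4).
k0 = mx.random.normal((1, 2, 5, 4))
v0 = mx.random.normal((1, 2, 5, 4))
keys, values = cache.update_and_fetch(k0, v0, offset=0)
print(keys.shape)  # shape: (1, 2, 5, 4)

# Decode: one new position per step, appended along axis 2.
k1 = mx.random.normal((1, 2, 1, 4))
v1 = mx.random.normal((1, 2, 1, 4))
keys, values = cache.update_and_fetch(k1, v1, offset=5)
print(keys.shape)  # shape: (1, 2, 6, 4)
```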
