
Commit 16d2aa9 (parent 632721a)

add kv cache skeleton

Signed-off-by: Alex Chi <[email protected]>

File tree: 7 files changed, +22 -31 lines


src/tiny_llm_week1_ref/attention.py (+1 -8)

@@ -23,11 +23,6 @@ def scaled_dot_product_attention_grouped(
     scale: float | None = None,
     mask: mx.array | None = None,
 ) -> mx.array:
-    """
-    Compute scaled dot-product attention.
-
-    query: batch_size x
-    """
     factor = mx.rsqrt(query.shape[-1]) if scale is None else scale
     expected_shape = query.shape
     query = query.reshape(-1, query.shape[-3], query.shape[-2], query.shape[-1])
@@ -44,9 +39,7 @@ def scaled_dot_product_attention_grouped(
     if mask is not None:
         mask = mask.reshape(-1, H, n_repeats, mask.shape[-2], mask.shape[-1])
         scores = scores + mask
-    result = mx.matmul(
-        softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype), value
-    )
+    result = mx.matmul(softmax(scores, axis=-1), value)
     return result.reshape(expected_shape)

src/tiny_llm_week1_ref/qwen2.py (-4)

@@ -66,15 +66,11 @@ def __call__(
             .reshape(B, L, self.num_kv_heads, self.head_dim)
             .astype(mx.float32)
         )
-        # offset = cache.offset
         projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
         projection_k = self.rope(projection_k, offset=slice(offset, offset + L))
         projection_q = projection_q.transpose(0, 2, 1, 3)
         projection_k = projection_k.transpose(0, 2, 1, 3)
         projection_v = projection_v.transpose(0, 2, 1, 3)
-        # TODO: it is possible to get a sensible result without using a kv-cache? Otherwise we have to include kv-cache in week 1.
-        # mlx-lm's KvCache seems to do more than just caching, we could extract something out of it.
-        # projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
         assert (
             projection_k.dtype == mx.float32
         ) # TODO: can we use float16? also a test framework to ensure all data types are casted correctly.

src/tiny_llm_week2_ref/__init__.py (+1)

@@ -6,3 +6,4 @@
 from .quantize import *
 from .qwen2 import *
 from .generate import *
+from .kv_cache import *

src/tiny_llm_week2_ref/attention.py (+1 -8)

@@ -23,11 +23,6 @@ def scaled_dot_product_attention_grouped(
     scale: float | None = None,
     mask: mx.array | None = None,
 ) -> mx.array:
-    """
-    Compute scaled dot-product attention.
-
-    query: batch_size x
-    """
     factor = mx.rsqrt(query.shape[-1]) if scale is None else scale
     expected_shape = query.shape
     query = query.reshape(-1, query.shape[-3], query.shape[-2], query.shape[-1])
@@ -44,9 +39,7 @@ def scaled_dot_product_attention_grouped(
     if mask is not None:
         mask = mask.reshape(-1, H, n_repeats, mask.shape[-2], mask.shape[-1])
         scores = scores + mask
-    result = mx.matmul(
-        softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype), value
-    )
+    result = mx.matmul(softmax(scores, axis=-1), value)
     return result.reshape(expected_shape)

src/tiny_llm_week2_ref/generate.py (+4 -1)

@@ -1,11 +1,14 @@
 import mlx.core as mx
 from .qwen2 import Qwen2Model
 from mlx_lm.tokenizer_utils import TokenizerWrapper
+from .kv_cache import TinyKvCache
 
 
 def simple_generate(model: Qwen2Model, tokenizer: TokenizerWrapper, prompt: str) -> str:
+    kv_cache = [TinyKvCache() for _ in range(model.num_hidden_layers)]
+
     def _step(model, y, offset):
-        logits = model(y[None], offset)
+        logits = model(y[None], offset, kv_cache)
         logits = logits[:, -1, :]
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
         sampler = lambda x: mx.argmax(x, axis=-1)
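The hunk only shows the per-layer cache list being created and threaded into _step; the rest of simple_generate lies outside it. As a rough illustration of how such caches are typically reused across decode steps, here is a hedged sketch, not code from this commit: the helper name generate_with_cache, the token budget, and the tokenizer methods encode/decode/eos_token_id are assumptions, and it only works once TinyKvCache.update_and_fetch is actually implemented rather than the stub below.

import mlx.core as mx

from .kv_cache import TinyKvCache


def generate_with_cache(model, tokenizer, prompt: str, max_tokens: int = 32) -> str:
    # One cache per transformer layer, created once and reused for every step.
    kv_cache = [TinyKvCache() for _ in range(model.num_hidden_layers)]
    tokens = mx.array(tokenizer.encode(prompt))

    offset = 0
    generated: list[int] = []
    y = tokens  # first step: the full prompt; later steps: a single token
    for _ in range(max_tokens):
        logits = model(y[None], offset, kv_cache)[:, -1, :]
        offset += y.size  # positions consumed so far; their K/V now live in the caches
        next_token = mx.argmax(logits, axis=-1)
        token_id = next_token.item()
        if token_id == tokenizer.eos_token_id:
            break
        generated.append(token_id)
        y = next_token  # only the newly sampled token is fed in the next step
    return tokenizer.decode(generated)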

src/tiny_llm_week2_ref/kv_cache.py (+8)

@@ -0,0 +1,8 @@
+from typing import Optional
+
+import mlx.core as mx
+
+
+class TinyKvCache:
+    def update_and_fetch(self, key: mx.array, value: mx.array, offset: int) -> mx.array:
+        pass
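The class is only a stub in this commit: update_and_fetch has no body, its annotated return type is a single mx.array, and it takes an offset argument even though the call site in qwen2.py passes just the key and value projections and unpacks two results. One plausible way to fill in the skeleton, following the call-site shape rather than the stub's annotation, is to concatenate new keys and values onto the cached ones along the sequence axis. This is a hedged sketch under those assumptions, not the commit's implementation:

import mlx.core as mx


class TinyKvCache:
    """Sketch: grow K/V along the sequence axis and hand back the full arrays."""

    def __init__(self):
        self.keys: mx.array | None = None
        self.values: mx.array | None = None

    def update_and_fetch(self, key: mx.array, value: mx.array) -> tuple[mx.array, mx.array]:
        # key/value arrive as B x num_kv_heads x L x head_dim after the transpose
        # in Qwen2MultiHeadAttention.
        if self.keys is None:
            self.keys, self.values = key, value
        else:
            self.keys = mx.concatenate([self.keys, key], axis=2)
            self.values = mx.concatenate([self.values, value], axis=2)
        return self.keys, self.values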

src/tiny_llm_week2_ref/qwen2.py (+7 -10)

@@ -6,6 +6,7 @@
 from typing import Any
 from .embedding import Embedding
 from .quantize import dequantize_linear
+from .kv_cache import TinyKvCache
 
 
 class Qwen2MultiHeadAttention:
@@ -48,6 +49,7 @@ def __call__(
         self,
         x: mx.array,
         offset: int,
+        cache: TinyKvCache,
     ) -> mx.array:
         B, L, _ = x.shape
         orig_dtype = x.dtype
@@ -66,19 +68,12 @@ def __call__(
             .reshape(B, L, self.num_kv_heads, self.head_dim)
             .astype(mx.float32)
         )
-        # offset = cache.offset
         projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
         projection_k = self.rope(projection_k, offset=slice(offset, offset + L))
         projection_q = projection_q.transpose(0, 2, 1, 3)
         projection_k = projection_k.transpose(0, 2, 1, 3)
         projection_v = projection_v.transpose(0, 2, 1, 3)
-        # TODO: it is possible to get a sensible result without using a kv-cache? Otherwise we have to include kv-cache in week 1.
-        # mlx-lm's KvCache seems to do more than just caching, we could extract something out of it.
-        # projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
-        assert (
-            projection_k.dtype == mx.float32
-        ) # TODO: can we use float16? also a test framework to ensure all data types are casted correctly.
-        assert projection_v.dtype == mx.float32
+        projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
         x = scaled_dot_product_attention_grouped(
             projection_q,
             projection_k,
@@ -157,8 +152,9 @@ def __call__(
         self,
         x: mx.array,
         offset: int,
+        cache: TinyKvCache,
     ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), offset)
+        r = self.self_attn(self.input_layernorm(x), offset, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
@@ -230,9 +226,10 @@ def __call__(
         self,
         inputs: mx.array,
         offset: int,
+        cache: list[TinyKvCache],
     ) -> mx.array:
         h = self.embedding(inputs)
         for layer in range(self.num_hidden_layers):
-            h = self.layers_inner[layer](h, offset)
+            h = self.layers_inner[layer](h, offset, cache[layer])
         h = self.norm(h)
         return linear(h, self.w_lm_head)
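With the cache threaded through Qwen2Model, the transformer block, and Qwen2MultiHeadAttention, each layer keeps its own K/V history (cache[layer]) while queries only cover the tokens of the current call. Using the concat-based TinyKvCache sketch shown after kv_cache.py above (an assumption; the toy shapes below are purely illustrative), the growth looks like this:

import mlx.core as mx

cache = TinyKvCache()
B, num_kv_heads, head_dim = 1, 2, 4

# Prefill: a 5-token prompt produces K/V for 5 positions.
k, v = cache.update_and_fetch(mx.zeros((B, num_kv_heads, 5, head_dim)),
                              mx.zeros((B, num_kv_heads, 5, head_dim)))
print(k.shape)  # (1, 2, 5, 4)

# Decode: one new token is appended; the fetched K/V now cover 6 positions,
# while the query passed alongside them still covers only the single new token.
k, v = cache.update_and_fetch(mx.zeros((B, num_kv_heads, 1, head_dim)),
                              mx.zeros((B, num_kv_heads, 1, head_dim)))
print(k.shape)  # (1, 2, 6, 4)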
